Perform gradient clipping on global batch when using gradient accumulation #6

Status: Open. Wants to merge 9 commits into base branch `main`.
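For context, the two knobs threaded through here reuse the names of Praxis's existing optimizer hparams (`clip_gradient_norm_to_value`, `clip_gradient_single_norm_to_value`), which this diff reads off `self._hparams`. A hypothetical configuration sketch of how the combination would be enabled; the exact HParams field spellings below are assumptions based on the names in this diff, not something the PR itself adds:

```python
# Hypothetical usage sketch: accumulate over 4 sub-batches, then clip the
# averaged global-batch gradient to total norm 1.0 before the base update.
from praxis import optimizers

opt_p = optimizers.ShardedStaticAccumulator.HParams(
    optimizer_tpl=optimizers.Adam.HParams(learning_rate=1e-3),
    num_sub_batches=4,
    clip_gradient_norm_to_value=1.0,         # global-norm clipping on
    clip_gradient_single_norm_to_value=0.0,  # per-tensor clipping off
)
```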
49 changes: 47 additions & 2 deletions praxis/optimizers.py
@@ -2691,6 +2691,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):
 
 def sharded_static_accumulation(
     num_sub_batches: int,
+    clip_gradient_norm_to_value: float,
+    clip_gradient_single_norm_to_value: float,
     base_tx: ShardedGradientTransformation,
 ) -> ShardedGradientTransformation:
   """Gradient transformation for ShardedStaticAccumulator optimizer."""
@@ -2759,10 +2761,52 @@ def update_fn(updates: NestedJTensor,
                              lambda: new_count)
 
     def _run_base_tx():
+
+      def _compute_grad_norm(grads: NestedMap) -> JTensor:
+        """Computes total grad norm."""
+        grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
+        grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
+        return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))
+
+
+      def scale_gradients(
+          raw_grads: NestedMap,
+          clip_grad_norm_to_value: float = 0.0,
+          clip_grad_single_norm_to_value: float = 0.0):
+
+        def clip_grads(grads):
+          assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value)
+          if clip_grad_norm_to_value:
+            grad_norm = _compute_grad_norm(raw_grads)
+
+            grad_scale = jnp.minimum(
+                jnp.array(1, grad_norm.dtype),
+                jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
+                / grad_norm)
+            grads = jax.tree_map(lambda g: g * grad_scale, grads)
+          elif clip_grad_single_norm_to_value:
+            grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
+                                            grads)
+
+            def scale_gradient(grad, norm):
+              return grad * jnp.minimum(
+                  jnp.array(1, norm.dtype),
+                  jnp.array(clip_grad_single_norm_to_value,
+                            norm.dtype) / norm)
+            grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
+
+          return grads
+
+        grads = clip_grads(raw_grads)
+        return grads
+
       averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
                                       new_accumulated_update)
+      scaled_updated = scale_gradients(averaged_updated,
+                                       clip_gradient_norm_to_value,
+                                       clip_gradient_single_norm_to_value)
       emission_updates, emission_base_state = base_tx.update(
-          averaged_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
+          scaled_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
       return (emission_updates,
               jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32),
                            updates), emission_base_state)

Inline review comment from a Member on the `if clip_grad_norm_to_value:` line: "maybe assert only one of them is true?" (the `assert` at the top of `clip_grads` enforces exactly this).
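To make the new `_run_base_tx` path concrete, here is a minimal standalone sketch of the global-norm branch in plain JAX, with toy values (this is illustrative, not PR code):

```python
import jax
import jax.numpy as jnp

# Toy averaged gradient tree whose global norm is sqrt(3^2 + 4^2) = 5.
grads = {"w": jnp.array([3.0]), "b": jnp.array([4.0])}
clip_to = 1.0

# Same math as _compute_grad_norm plus the global-norm branch of clip_grads.
sq = jax.tree_util.tree_map(lambda x: jnp.sum(x * x), grads)
global_norm = jnp.sqrt(sum(jax.tree_util.tree_leaves(sq)))   # 5.0
scale = jnp.minimum(1.0, clip_to / global_norm)              # 0.2
clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
# clipped == {"w": [0.6], "b": [0.8]}; its global norm is exactly 1.0.
```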
Expand Down Expand Up @@ -2830,4 +2874,5 @@ def _get_raw_grad_transformation(
self, lr: optax.Schedule) -> GeneralGradientTransformation:
p = self._hparams
base_tx = self.base_optimizer._get_raw_grad_transformation(lr) # pylint: disable=protected-access
return sharded_static_accumulation(p.num_sub_batches, base_tx)
return sharded_static_accumulation(p.num_sub_batches, p.clip_gradient_norm_to_value,
p.clip_gradient_single_norm_to_value, base_tx)
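Why the clip has to happen on the accumulated gradient rather than per sub-batch: under the `min(1, clip/norm)` rule used above, clipping each sub-batch gradient before averaging generally yields a different update than clipping the average. A toy scalar example with made-up numbers:

```python
import jax.numpy as jnp

def clip_by_norm(g, c):
  # min(1, c / ||g||) scaling, the same rule the diff applies.
  return g * jnp.minimum(1.0, c / jnp.abs(g))

g1, g2, c = jnp.array(4.0), jnp.array(0.5), 1.0

per_sub_batch = (clip_by_norm(g1, c) + clip_by_norm(g2, c)) / 2  # 0.75
global_batch = clip_by_norm((g1 + g2) / 2, c)                    # 1.0
```

Clipping after accumulation, as this PR does, makes a run with `num_sub_batches=N` behave like a single large-batch step with clipping, which is the point of the change.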