Perform gradient clipping on global batch when using gradient accumulation #6

Open

ashors1 wants to merge 9 commits into main

Conversation

@ashors1 (Contributor) commented Feb 14, 2023

This PR refactors gradient clipping so that it is performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that the refactor maintains support for enable_skip_step_on_gradient_anomalies. When using ShardedStaticAccumulator with x subbatches, it requires x+1 grad norm calculations per global batch (one per subbatch to determine whether the step should be skipped, plus one when applying gradient clipping in the base optimizer update) and one grad clip per global batch.

This PR should be taken together with the corresponding Paxml PR.
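
To make the intent concrete, here is a minimal sketch of the global-batch clipping scheme in plain JAX. It is illustrative only: the helper names, the anomaly check, and the epsilon are assumptions rather than Praxis/Paxml APIs, and the real change lives in ShardedStaticAccumulator and the base optimizer update.

import jax
import jax.numpy as jnp

def _global_norm(grads):
  # L2 norm over all gradient leaves.
  leaves = jax.tree_util.tree_leaves(grads)
  return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

def accumulate_then_clip(params, subbatches, loss_fn, clip_norm=1.0,
                         anomaly_threshold=1e3):
  # x grad norm computations, one per subbatch, used only to decide whether
  # the step should be skipped (enable_skip_step_on_gradient_anomalies).
  accum = jax.tree_util.tree_map(jnp.zeros_like, params)
  skip_step = jnp.array(False)
  for subbatch in subbatches:
    g = jax.grad(loss_fn)(params, subbatch)
    subbatch_norm = _global_norm(g)
    skip_step = skip_step | jnp.isnan(subbatch_norm) | (subbatch_norm > anomaly_threshold)
    accum = jax.tree_util.tree_map(jnp.add, accum, g)
  grads = jax.tree_util.tree_map(lambda g: g / len(subbatches), accum)

  # One more grad norm computation (x + 1 in total) and a single clip applied
  # to the accumulated global-batch gradient rather than to each subbatch.
  global_norm = _global_norm(grads)
  scale = jnp.minimum(1.0, clip_norm / (global_norm + 1e-6))
  grads = jax.tree_util.tree_map(lambda g: g * scale, grads)
  return grads, skip_step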

@zhangqiaorjc self-assigned this Mar 3, 2023
@zhangqiaorjc self-requested a review Mar 3, 2023 04:17
@zhangqiaorjc (Member) left a comment:

thanks Anna!

praxis/optimizers.py (two outdated review threads, marked resolved)

raw_grad_norm = _compute_grad_norm(raw_grads)

grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)

@zhangqiaorjc (Member):

do we need to compute and return grad_scale?

@ashors1 (Contributor, Author):

This is not needed. I no longer return grad_scale with the latest commit.

grad_scale = jnp.array(1.0)
return grads, grad_scale
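
For contrast with the outdated snippet above, here is a sketch of the simplified helper after dropping grad_scale from the return value. It is illustrative only: the threshold is passed as an argument here, whereas in the later snippet below clip_grads takes only grads and reads the clipping values from the enclosing scope, and the epsilon is an assumption.

import jax
import jax.numpy as jnp

def clip_grads(grads, grad_norm, clip_grad_norm_to_value=1.0):
  # Scale gradients so their global norm does not exceed
  # clip_grad_norm_to_value; the epsilon guards against division by zero.
  scale = jnp.minimum(1.0, clip_grad_norm_to_value / (grad_norm + 1e-6))
  # Only the clipped gradients are returned; grad_scale is no longer computed.
  return jax.tree_util.tree_map(lambda g: g * scale, grads)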

raw_grad_norm = _compute_grad_norm(raw_grads)

@zhangqiaorjc (Member):

IIUC, if clip_grad_single_norm_to_value is set, then raw_grad_norm is not used and we have to compute grad_single_norm separately anyway?

Can we move the if-elif-else statement inside out and avoid the redundant computation?

@ashors1 (Contributor, Author):

Definitely. I have addressed this with my latest commit.
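
As an illustration of turning the branching inside out (a sketch under assumed semantics, not the code from the commit), each branch computes only the norm it actually needs, so no global grad norm is computed when per-tensor clipping is selected:

import jax
import jax.numpy as jnp

def scale_gradients(raw_grads, clip_grad_norm_to_value=0.0,
                    clip_grad_single_norm_to_value=0.0):
  if clip_grad_norm_to_value:
    # Global-norm clipping: the global grad norm is computed only in this branch.
    grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g))
                             for g in jax.tree_util.tree_leaves(raw_grads)))
    scale = jnp.minimum(1.0, clip_grad_norm_to_value / (grad_norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, raw_grads)
  elif clip_grad_single_norm_to_value:
    # Per-tensor clipping: the global grad norm is never needed here.
    def clip_one(g):
      single_norm = jnp.sqrt(jnp.sum(jnp.square(g)))
      return g * jnp.minimum(
          1.0, clip_grad_single_norm_to_value / (single_norm + 1e-6))
    return jax.tree_util.tree_map(clip_one, raw_grads)
  else:
    return raw_grads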


def scale_gradients(
raw_grads: NestedMap,
clip_grad_norm_to_value: Optional[float] = None,

@zhangqiaorjc (Member):

Looking at the Praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default to 0.0, not None, right?

So perhaps the types here should be float with a default of 0.0 instead of Optional?

clip_grad_single_norm_to_value: Optional[float] = None):

def clip_grads(grads):
if clip_grad_norm_to_value:

@zhangqiaorjc (Member):

maybe assert only one of them is true?
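
Combining this suggestion with the earlier comment about defaults, the signature might end up looking like the sketch below: plain float parameters defaulting to 0.0 and an assertion that at most one clipping mode is enabled. This is illustrative only; the merged code may differ.

def scale_gradients(raw_grads,  # NestedMap of gradients, as in the snippet above
                    clip_grad_norm_to_value: float = 0.0,
                    clip_grad_single_norm_to_value: float = 0.0):
  # The two clipping modes are mutually exclusive; at most one may be non-zero.
  assert not (clip_grad_norm_to_value and clip_grad_single_norm_to_value), (
      'Set at most one of clip_grad_norm_to_value and '
      'clip_grad_single_norm_to_value.')
  # ... clipping logic as sketched earlier ...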
