Normalize CE loss by total number of (non-padding) tokens (#1875)
ebsmothers authored Oct 25, 2024
1 parent 8e013c2 commit 23c8829
Showing 13 changed files with 207 additions and 181 deletions.
20 changes: 13 additions & 7 deletions recipes/full_finetune_distributed.py
@@ -665,7 +665,13 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
@@ -683,17 +689,17 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
@@ -710,7 +716,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
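The same pattern repeats in every recipe touched by this commit: count the unmasked tokens per micro-batch, accumulate the summed (un-normalized) loss, and divide once by the total token count at the optimizer step. Below is a minimal, self-contained sketch of that pattern; the tiny linear model, random data, and IGNORE_INDEX constant are placeholders standing in for the recipe's self._model, self._loss_fn, and dataloader.

```python
import torch
import torch.nn.functional as F

# Sketch of token-normalized gradient accumulation (placeholder model/data).
torch.manual_seed(0)
IGNORE_INDEX = -100
vocab_size, hidden, grad_accum_steps = 32, 16, 4

model = torch.nn.Linear(hidden, vocab_size)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

def fake_batch(bsz=2, seqlen=8):
    feats = torch.randn(bsz, seqlen, hidden)
    labels = torch.randint(0, vocab_size, (bsz, seqlen))
    labels[:, -3:] = IGNORE_INDEX  # pretend the tail is padding
    return {"tokens": feats, "labels": labels}

running_loss, num_tokens = 0.0, 0
for idx in range(8):
    batch = fake_batch()
    labels = batch["labels"]

    # Count only unmasked (non-padding) tokens in this micro-batch
    current_num_tokens = (labels != IGNORE_INDEX).sum()
    num_tokens += current_num_tokens

    logits = model(batch["tokens"])  # [b, s, v]
    # reduction="mean" averages over unmasked tokens, so multiplying by
    # their count converts the per-token mean back into a sum
    loss = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=IGNORE_INDEX,
    )
    running_loss = running_loss + loss * current_num_tokens

    if (idx + 1) % grad_accum_steps == 0:
        # Normalize once by the total unmasked tokens in the window,
        # then do a single backward + optimizer step
        step_loss = running_loss / num_tokens
        step_loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        print(f"step loss: {step_loss.item():.4f}")
        running_loss, num_tokens = 0.0, 0
```

Note that, as in the diff above, backward() now runs once per accumulation window rather than once per micro-batch.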
27 changes: 18 additions & 9 deletions recipes/full_finetune_single_device.py
@@ -625,15 +625,22 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -648,7 +655,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -662,9 +669,11 @@ def train(self) -> None:
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": get_lr(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper,
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
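A quick way to see why multiplying the loss by current_num_tokens works: with an ignore_index, PyTorch's cross-entropy under reduction="mean" averages only over the unmasked targets, so multiplying that mean by the unmasked count recovers the summed loss. The toy tensors below are illustrative only.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
logits = torch.randn(6, 10)
labels = torch.tensor([1, 4, IGNORE_INDEX, 7, IGNORE_INDEX, 2])

# Mean over the 4 unmasked positions only
mean_loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX)
# Sum over the same 4 positions
sum_loss = F.cross_entropy(
    logits, labels, ignore_index=IGNORE_INDEX, reduction="sum"
)
num_unmasked = (labels != IGNORE_INDEX).sum()

# mean * count == sum, which is why the recipes can re-normalize later
assert torch.allclose(mean_loss * num_unmasked, sum_loss)
```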
25 changes: 17 additions & 8 deletions recipes/knowledge_distillation_single_device.py
@@ -687,17 +687,26 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

batch = {k: v.to(self._device) for k, v in batch.items()}
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

class_loss, kd_loss = self._loss_step(batch)
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss = loss / self._gradient_accumulation_steps
running_class_loss += class_loss / self._gradient_accumulation_steps
running_kd_loss += kd_loss / self._gradient_accumulation_steps
loss.backward()
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -709,8 +718,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = running_class_loss.item()
kd_loss_to_log = running_kd_loss.item()
class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
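The knowledge-distillation recipe applies the same normalization to two losses at once: the class (CE) loss and the KD loss are accumulated separately, each weighted by the unmasked token count, and only combined with kd_ratio at the optimizer step. A rough sketch follows; the per-batch loss values and token counts are fabricated placeholders, whereas the real recipe computes them in _loss_step.

```python
import torch

# Placeholder inputs: (class_loss, kd_loss, unmasked_tokens) per micro-batch
fake_steps = [
    (torch.tensor(2.3), torch.tensor(1.1), 120),
    (torch.tensor(2.0), torch.tensor(0.9), 80),
]
kd_ratio, grad_accum_steps = 0.5, 2
running_class_loss, running_kd_loss, num_tokens = 0.0, 0.0, 0

for idx, (class_loss, kd_loss, current_num_tokens) in enumerate(fake_steps):
    # Accumulate each loss as a token-weighted sum
    running_class_loss += class_loss * current_num_tokens
    running_kd_loss += kd_loss * current_num_tokens
    num_tokens += current_num_tokens

    if (idx + 1) % grad_accum_steps == 0:
        # Normalize each loss by the total unmasked tokens, then combine
        class_loss = running_class_loss / num_tokens
        kd_loss = running_kd_loss / num_tokens
        loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
        # loss.backward() would go here in the real recipe
        print(
            f"class: {class_loss.item():.4f}, "
            f"kd: {kd_loss.item():.4f}, total: {loss.item():.4f}"
        )
```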
20 changes: 13 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -764,7 +764,13 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
@@ -783,17 +789,17 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -806,7 +812,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
20 changes: 13 additions & 7 deletions recipes/lora_finetune_single_device.py
@@ -631,7 +631,6 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)

# free logits otherwise it peaks backward memory
@@ -679,15 +678,22 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -699,7 +705,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
20 changes: 13 additions & 7 deletions recipes/qat_distributed.py
@@ -659,7 +659,14 @@ def train(self) -> None:
self._model.apply(enable_fq)

tokens = tokens.to(self._device)
num_tokens += tokens.numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
@@ -679,23 +686,22 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
running_loss += self._loss_fn(logits, labels) * current_num_tokens
# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
16 changes: 8 additions & 8 deletions tests/recipes/test_full_finetune_distributed.py
@@ -31,7 +31,6 @@
class TestFullFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
return [
"batch_size=4",
"dtype=fp32",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
@@ -52,21 +51,22 @@ def _fetch_expected_loss_values(self, model_type):

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, fsdp_sharding_strategy",
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
[
("llama2/7B_full", "llama2", "hf", None),
("llama3/8B_full", "llama3", "tune", None),
("llama3/8B_full", "llama3", "tune", "NO_SHARD"),
("llama2/7B_full", "llama2", "hf", 1, 4),
("llama3/8B_full", "llama3", "tune", 1, 4),
("llama3/8B_full", "llama3", "tune", 4, 1),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@gpu_test(gpu_count=2)
def test_loss(
self,
micro_batch_size,
gradient_accumulation_steps,
config,
model_type,
ckpt_type,
fsdp_sharding_strategy,
optim_in_bwd,
tmpdir,
monkeypatch,
@@ -84,6 +84,8 @@ def test_loss(
cmd = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -94,8 +96,6 @@ def test_loss(
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
""".split()
if fsdp_sharding_strategy:
cmd.append(f"fsdp_sharding_strategy={fsdp_sharding_strategy}")
model_config = MODEL_TEST_CONFIGS[model_type]
cmd = cmd + self._get_test_config_overrides() + model_config
# "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
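The updated test matrix pairs (micro_batch_size=1, gradient_accumulation_steps=4) with (micro_batch_size=4, gradient_accumulation_steps=1) and expects matching losses. Under the new normalization this should hold because both configurations divide the same summed token losses by the same total token count. A toy check of that property, assuming a stand-in linear model rather than the actual test harness:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
IGNORE_INDEX, vocab, hidden = -100, 16, 8
model = torch.nn.Linear(hidden, vocab)

feats = torch.randn(4, 12, hidden)
labels = torch.randint(0, vocab, (4, 12))
labels[:, -4:] = IGNORE_INDEX  # simulate padding

def mean_loss(f, l):
    logits = model(f)
    return F.cross_entropy(
        logits.reshape(-1, vocab), l.reshape(-1), ignore_index=IGNORE_INDEX
    )

# One step with batch_size=4 and no accumulation
big = mean_loss(feats, labels)

# Four accumulation steps with batch_size=1, normalized by total tokens
running, num_tokens = 0.0, 0
for i in range(4):
    f, l = feats[i : i + 1], labels[i : i + 1]
    n = (l != IGNORE_INDEX).sum()
    running = running + mean_loss(f, l) * n
    num_tokens += n
accumulated = running / num_tokens

# Both paths give the same per-step loss (up to floating-point error)
assert torch.allclose(big, accumulated)
```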