Add KD distributed recipe #1631

Open
lindawangg wants to merge 48 commits into main from add-kd-distributed
Changes from 38 commits
Commits (48)
All 48 commits are by lindawangg.

2933cd6  sft recipes to eval kd (Sep 5, 2024)
bff065a  setup kd files (Sep 5, 2024)
9dd7b47  delete test config (Sep 5, 2024)
a39e99c  added student config (Sep 6, 2024)
6dbcd38  Merge branch 'main' into add-initial-kd-recipe (Sep 6, 2024)
0c4e4f9  added teacher model loading (Sep 6, 2024)
380f267  added loss (Sep 7, 2024)
da2b4bb  kd initial experiment config (Sep 10, 2024)
8beaca0  Merge branch 'main' into add-initial-kd-recipe (Sep 10, 2024)
b54929a  separated out loss func and added test (Sep 11, 2024)
b31c56d  added documentation (Sep 11, 2024)
fe5ed97  added prereq command to config (Sep 11, 2024)
8b9ea41  Merge branch 'main' into add-initial-kd-recipe (Sep 11, 2024)
3f7fe70  re-add 8B config (Sep 11, 2024)
a87aa0c  added kd ratio (Sep 11, 2024)
f5feac4  revert 8b config (Sep 11, 2024)
8c3c42a  add kd recipe test (Sep 12, 2024)
6ba0514  mark as integration test (Sep 12, 2024)
04ea649  add save and load weights test (Sep 12, 2024)
62faa1d  fix comments 1 (Sep 13, 2024)
bf15406  address kd loss test comments (Sep 13, 2024)
ac9eb0e  change to qwen2 (Sep 13, 2024)
87a80b6  addressing recipe comments (Sep 15, 2024)
1fc3f64  Merge branch 'main' into add-initial-kd-recipe (Sep 15, 2024)
106aa3e  distributed recipe (Sep 16, 2024)
0f4e922  remove todo comment and test activation checkpointing (Sep 16, 2024)
22fddca  Merge branch 'main' into add-initial-kd-recipe (Sep 16, 2024)
526a4dc  Merge branch 'add-initial-kd-recipe' into add-kd-distributed (Sep 16, 2024)
0bb49dc  qwen2 distributed recipe (Sep 17, 2024)
c73857d  added to recipe registry (Sep 17, 2024)
59eff44  fdsp teacher model (Sep 17, 2024)
85d76bb  added kd distributed test (Sep 18, 2024)
dba57c4  fixed command (Sep 18, 2024)
04e2282  Merge branch 'main' into add-kd-distributed (Sep 19, 2024)
44123b9  changed to knowledge_distillation (Sep 20, 2024)
a04244d  cleaned up tests (Sep 20, 2024)
1ff9934  added gpu test (Sep 20, 2024)
703e7dc  Merge branch 'main' into add-kd-distributed (Sep 24, 2024)
0031bfb  Merge branch 'main' into add-kd-distributed (Oct 15, 2024)
307791d  added llama3 config and addressed comments (Oct 15, 2024)
fefc24d  added custom sharding layers (Oct 15, 2024)
15c5be2  Merge branch 'main' into add-kd-distributed (Oct 21, 2024)
46473ee  add test_loss back in (Oct 22, 2024)
2e212ec  Merge branch 'main' into add-kd-distributed (Oct 24, 2024)
557396e  rebase (Oct 24, 2024)
4d376e3  Merge branch 'main' into add-kd-distributed (Oct 25, 2024)
53c47ba  grad accumulation changes (Oct 25, 2024)
f193d02  remove extra num_tokens (Oct 26, 2024)
recipes/configs/qwen2/knowledge_distillation_distributed.yaml — 123 additions, 0 deletions
@@ -0,0 +1,123 @@
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None
# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works only for distilling on a single device.
Review comment (Contributor), on the line above: nit: update this

# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct-kd
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  checkpoint_files: [
    hf_model_0001_0.pt
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  model_type: QWEN2

resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 8

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
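
A note on how the two loss components interact: the recipe is expected to weight the forward-KL distillation term against the standard cross-entropy term via kd_ratio, roughly loss = (1 - kd_ratio) * CE + kd_ratio * KD. The sketch below is a minimal plain-PyTorch illustration under that assumption; the function name, masking, and reductions are illustrative, not the recipe's or the chunked loss classes' actual implementation.

import torch
import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, labels, kd_ratio=0.5, ignore_index=-100):
    """Combine cross-entropy on the labels with a forward-KL term against the teacher."""
    vocab = student_logits.size(-1)
    # Standard next-token cross-entropy on the ground-truth labels (the `loss` component).
    ce_loss = F.cross_entropy(
        student_logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=ignore_index
    )
    # Forward KL against the teacher's token distribution (the `kd_loss` component);
    # the teacher-entropy term is dropped since it is constant w.r.t. the student.
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    token_loss = -(teacher_probs * student_logprobs).sum(-1)  # [batch, seq_len]
    mask = labels != ignore_index
    kd_loss = (token_loss * mask).sum() / mask.sum().clamp(min=1)
    # kd_ratio = 0.0 is ordinary fine-tuning, 1.0 is pure distillation; 0.5 weights both equally.
    return (1 - kd_ratio) * ce_loss + kd_ratio * kd_loss

The *WithChunkedOutputLoss variants configured above compute these quantities over chunks of the output logits, which keeps peak memory lower for large vocabularies.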

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
Review comment (Contributor), on enable_activation_checkpointing: Not sure what peak memory you're seeing, but with distributed training you may be able to get away without this (especially for such small models) and get faster training.

Reply (Contributor, Author): Changed to False. It isn't needed for Qwen2, and training time also went from 1 hour to 20 minutes.

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler

  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
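
To make the mapping above concrete, here is a minimal sketch in plain torch.profiler terms. It is an assumption about what setup_torch_profiler roughly translates these fields into, not torchtune's actual wrapper code, and the trace directory simply reuses the path configured above.

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
prof_schedule = schedule(wait=5, warmup=5, active=2, repeat=1)

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=prof_schedule,
    on_trace_ready=tensorboard_trace_handler("/tmp/qwen_kd/profiling_outputs"),
    record_shapes=True,     # record_shapes: True
    profile_memory=False,   # profile_memory: False
    with_stack=False,       # with_stack: False
    with_flops=False,       # with_flops: False
)

with prof:
    for _ in range(20):
        # ... one training iteration would run here ...
        prof.step()  # advances the wait/warmup/active window

The 20-iteration loop is just a placeholder so that one full cycle of the schedule completes (5 wait + 5 warmup + 2 active = 12 steps per cycle).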