Add KD distributed recipe #1631
base: main
Changes from 38 commits
@@ -0,0 +1,123 @@
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD.
# First download the student and teacher models:
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None
# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works only for distilling on 2+ devices.

# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64
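A quick sketch of what `lora_rank` and `lora_alpha` control. This is not torchtune's implementation, just the standard LoRA forward pass written with plain Python lists so it runs anywhere: the frozen weight's output is combined with a low-rank update scaled by `alpha / rank`.

```python
# Illustrative LoRA forward pass (not torchtune's code): the frozen weight W is
# combined with the trainable low-rank update B @ A, scaled by alpha / rank.

def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]

def lora_forward(w, a, b, x, rank, alpha):
    """y = W x + (alpha / rank) * B (A x); A is rank x d_in, B is d_out x rank."""
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    scale = alpha / rank
    return [bi + scale * ui for bi, ui in zip(base, update)]

# Toy shapes: d_in = d_out = 2, rank = 1. The alpha/rank = 2 scaling mirrors
# this config's lora_alpha: 64 over lora_rank: 32.
w = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight (identity, for clarity)
a = [[1.0, 1.0]]               # trainable down-projection (rank x d_in)
b = [[0.5], [0.0]]             # trainable up-projection (d_out x rank)
print(lora_forward(w, a, b, [1.0, 2.0], rank=1, alpha=2))  # [4.0, 2.0]
```

With this config, only `q_proj`, `k_proj`, and `v_proj` get such adapters (`lora_attn_modules`), and the MLP is left untouched (`apply_lora_to_mlp: False`).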

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-0.5B-Instruct-kd
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  checkpoint_files: [
    hf_model_0001_0.pt
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
  model_type: QWEN2

resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
seed: null
shuffle: True
batch_size: 8

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
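The shape of a warmup-plus-cosine schedule, sketched below for reference. This is an illustrative reimplementation of the usual formula, not torchtune's code: the LR rises linearly over `num_warmup_steps`, then follows a half-cosine down to zero over the remaining steps (the 1000-step horizon is a made-up example).

```python
import math

# Illustrative warmup + cosine-decay multiplier: linear ramp for the first
# num_warmup_steps, then 0.5 * (1 + cos(pi * progress)) decaying to 0.

def lr_multiplier(step, num_warmup_steps, num_training_steps):
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)            # linear warmup
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))     # cosine decay to 0

# With this config's lr: 3e-4 and num_warmup_steps: 100 (1000 total steps assumed):
print(3e-4 * lr_multiplier(50, 100, 1000))    # halfway through warmup: 1.5e-4
print(3e-4 * lr_multiplier(100, 100, 1000))   # warmup done: full 3e-4
```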

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
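How `kd_ratio` blends the two losses, sketched without torch. In the recipe the total loss is `(1 - kd_ratio) * ce_loss + kd_ratio * kd_loss`, where the KD term is the forward KL divergence KL(teacher || student) over next-token distributions; the hand-rolled softmax/KL below is illustrative only.

```python
import math

# Dependency-free sketch of the KD objective:
# total = (1 - kd_ratio) * cross_entropy(student, label)
#         + kd_ratio * forward_kl(teacher, student)

def softmax(logits):
    m = max(logits)                      # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def cross_entropy(student_logits, label):
    return -math.log(softmax(student_logits)[label])

def forward_kl(teacher_logits, student_logits):
    """KL(p_teacher || p_student): teacher probabilities weight the log-ratio."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

def kd_total_loss(student_logits, teacher_logits, label, kd_ratio=0.5):
    ce = cross_entropy(student_logits, label)
    kd = forward_kl(teacher_logits, student_logits)
    return (1 - kd_ratio) * ce + kd_ratio * kd

# When the student matches the teacher exactly, the KL term vanishes
# and only the (weighted) cross-entropy against the labels remains.
print(forward_kl([1.0, 2.0], [1.0, 2.0]))  # 0.0
```

With `kd_ratio: 0.5` the ground-truth CE and the distillation KL contribute equally; `kd_ratio: 0` would reduce to plain fine-tuning.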

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

Review comment: Not sure what the peak memory you're seeing is, but with distributed you may be able to get away without activation checkpointing (especially for such small models) and get faster training.

Reply: Changed to False. It isn't needed for Qwen2, and training time also went from 1h to 20 mins.

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
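The schedule knobs above map onto a simple step cycle. A rough sketch of the phase sequence that `torch.profiler.schedule(wait=5, warmup=5, active=2, repeat=1)` produces (illustrative, not the torch source; torch additionally distinguishes the last active step for saving):

```python
# Sketch of the profiler step cycle: each cycle is wait -> warmup -> active,
# and repeat=1 means a single cycle, after which the profiler stays idle.

def profiler_phase(step, wait=5, warmup=5, active=2, repeat=1):
    cycle = wait + warmup + active
    if repeat > 0 and step >= cycle * repeat:
        return "NONE"                  # all requested cycles are done
    pos = step % cycle
    if pos < wait:
        return "NONE"                  # skipped steps
    if pos < wait + warmup:
        return "WARMUP"                # profiler running, results discarded
    return "RECORD"                    # steps actually traced

print([profiler_phase(s) for s in range(14)])
# 5x NONE, 5x WARMUP, 2x RECORD, then NONE from step 12 on
```

So with this config only steps 10-11 are traced, which keeps the overhead negligible when `enabled` is flipped to True for debugging.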

Review comment: nit: update this