QLoRA with bias + Llama 3.2 Vision QLoRA configs (#1726)
ebsmothers authored Oct 25, 2024
1 parent bc486d4 commit e030626
Showing 14 changed files with 429 additions and 175 deletions.
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2_vision/11B_lora.yaml
@@ -81,7 +81,7 @@ enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/full-llama3.2-vision-finetune
output_dir: /tmp/lora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
@@ -80,7 +80,7 @@ enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/full-llama3.2-vision-finetune
output_dir: /tmp/lora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
88 changes: 88 additions & 0 deletions recipes/configs/llama3_2_vision/11B_qlora.yaml
@@ -0,0 +1,88 @@
# Config for multi-device QLoRA finetuning in lora_finetune_distributed.py
# using a Llama3.2 11B Vision Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training:
# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single-device QLoRA finetuning, please use 11B_qlora_single_device.yaml

# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
decoder_trainable: "frozen"
encoder_trainable: "lora"
fusion_trainable: "lora"
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.0
image_size: 560 # Make sure this matches the image_size in tokenizer

# Transform
tokenizer:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
image_size: 560
max_seq_len: 8192

# Checkpointer
checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
checkpoint_files: [consolidated.pth]
recipe_checkpoint: null
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False

# Dataset
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset
subset: ocrvqa
seed: null
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 2
gradient_accumulation_steps: 4
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 2e-5
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
compile: False # set it to True for better memory and performance

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: False
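
For reference, the model section of this config maps onto a direct builder call. The following is a minimal sketch, assuming the qlora_llama3_2_vision_11b builder accepts the keyword arguments exactly as listed above; in a recipe, torchtune instantiates this component from the YAML for you.

from torchtune.models.llama3_2_vision import qlora_llama3_2_vision_11b

# Mirror the "model" section of 11B_qlora.yaml (sketch only).
model = qlora_llama3_2_vision_11b(
    decoder_trainable="frozen",              # keep the language decoder frozen
    encoder_trainable="lora",                # LoRA adapters on the vision encoder
    fusion_trainable="lora",                 # LoRA adapters on the fusion layers
    lora_attn_modules=["q_proj", "v_proj"],  # attention projections that get LoRA
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.0,
    image_size=560,                          # must match the tokenizer/transform image_size
)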
113 changes: 113 additions & 0 deletions recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -0,0 +1,113 @@
# Config for single device QLoRA finetuning in lora_finetune_single_device.py
# using a Llama3.2 11B Vision Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training:
# tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
decoder_trainable: "frozen"
encoder_trainable: "lora"
fusion_trainable: "lora"
lora_attn_modules: ['q_proj', 'v_proj']
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.0
image_size: 560 # Make sure this matches the image_size in tokenizer

# Transform
tokenizer:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
image_size: 560
max_seq_len: 8192

# Checkpointer
checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/original/
checkpoint_files: [consolidated.pth]
recipe_checkpoint: null
output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
model_type: LLAMA3_VISION
resume_from_checkpoint: False

# Dataset
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset
subset: ocrvqa
seed: null
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 2
gradient_accumulation_steps: 16
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 2e-5
optimizer_in_bwd: False
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
compile: False # set it to True for better memory and performance

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: False

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

# Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

# `torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

# trace options passed to `torch.profiler.profile`
profile_memory: True
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 1
warmup_steps: 2
active_steps: 1
num_cycles: 1
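
When the profiler is enabled, the fields above correspond to the standard torch.profiler API. A rough sketch of that mapping follows; it is an illustration only, not the recipe's actual setup code (torchtune.training.setup_torch_profiler builds the profiler from this config).

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

# schedule: wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
prof_schedule = schedule(wait=1, warmup=2, active=1, repeat=1)

profiler = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=prof_schedule,
    on_trace_ready=tensorboard_trace_handler(
        "/tmp/qlora-llama3.2-vision-finetune/profiling_outputs"  # ${output_dir}/profiling_outputs
    ),
    profile_memory=True,
    with_stack=False,
    record_shapes=True,
    with_flops=False,
)

# Typical usage: wrap the training loop and step the profiler once per iteration.
with profiler:
    for step in range(8):
        # ... forward / backward / optimizer step ...
        profiler.step()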
9 changes: 3 additions & 6 deletions tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -40,10 +40,6 @@ class TestNF4Linear:
Class for testing our NF4Linear implementation.
"""

def test_bias_unsupported(self):
with pytest.raises(RuntimeError, match="does not currently support biases"):
_ = FrozenNF4Linear(1, 1, bias=True)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_parameters(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
@@ -59,9 +55,10 @@ def test_state_dict(self, dtype):
assert isinstance(state_dict["weight"], NF4Tensor)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_output_dtype(self, dtype):
@pytest.mark.parametrize("bias", [True, False])
def test_output_dtype(self, dtype, bias):
# Test to ensure W4 A16 produces A16 / W4A32 produces A32
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype, bias=bias)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
out = nf4_linear(inp)
assert out.dtype == dtype
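
The removed test_bias_unsupported case reflects the headline change in this commit: FrozenNF4Linear now accepts bias=True while still quantizing only the weight to NF4. A minimal usage sketch under that assumption, assuming FrozenNF4Linear is importable from torchtune.modules.low_precision (matching this test file's location):

import torch
from torchtune.modules.low_precision import FrozenNF4Linear

# Weight is stored as an NF4 tensor; the bias stays in the working dtype.
layer = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16, bias=True)

x = torch.randn(2, 512, dtype=torch.bfloat16)
out = layer(x)

# As the updated test asserts, the output dtype follows the activation dtype.
assert out.dtype == torch.bfloat16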
107 changes: 52 additions & 55 deletions tests/torchtune/modules/peft/test_dora.py
@@ -49,80 +49,77 @@ def inputs(self, in_dim) -> torch.Tensor:
return inputs

@pytest.fixture
def dora_linear(self, in_dim, out_dim) -> DoRALinear:
dora_linear = DoRALinear(
in_dim=in_dim,
out_dim=out_dim,
rank=RANK,
alpha=ALPHA,
use_bias=False,
)
def dora_linear(self, in_dim, out_dim):
def create_dora_linear(use_bias, dtype, in_dim=in_dim, out_dim=out_dim):
with training.set_default_dtype(dtype):
dora_linear = DoRALinear(
in_dim=in_dim,
out_dim=out_dim,
rank=RANK,
alpha=ALPHA,
use_bias=use_bias,
)

fixed_init_model(dora_linear)
return dora_linear
fixed_init_model(dora_linear)
return dora_linear

return create_dora_linear

@pytest.fixture
def qdora_linear(self, in_dim, out_dim) -> DoRALinear:
with training.set_default_dtype(torch.bfloat16):
qdora_linear = DoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
quantize_base=True,
)
fixed_init_model(qdora_linear, dtype=torch.bfloat16)
def qdora_linear(self):
def create_qdora_linear(
use_bias=False, dtype=torch.bfloat16, in_dim=512, out_dim=512
):
with training.set_default_dtype(dtype):
qdora_linear = DoRALinear(
in_dim=in_dim,
out_dim=out_dim,
rank=RANK,
alpha=ALPHA,
use_bias=use_bias,
quantize_base=True,
)
fixed_init_model(qdora_linear)
return qdora_linear

return create_qdora_linear

def test_forward(self, inputs, dora_linear, out_dim) -> None:
dora_linear = dora_linear(use_bias=False, dtype=torch.float32)
expected = torch.tensor(EXPECTED_VAL)
actual = dora_linear(inputs)
assert actual.shape == (BSZ, SEQ_LEN, out_dim)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

def test_dora_weight_nf4_when_quantized(self, qdora_linear):
@pytest.mark.parametrize("use_bias", [True, False])
def test_dora_weight_nf4_when_quantized(self, use_bias, qdora_linear):
qdora_linear = qdora_linear(use_bias=use_bias, dtype=torch.bfloat16)
assert isinstance(qdora_linear.weight, NF4Tensor)

def test_bias_raises(self):
with pytest.raises(
NotImplementedError, match="DoRALinear does not support using bias"
):
DoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=True,
quantize_base=False,
)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qdora_parity(self, dtype):
if use_bias:
assert not isinstance(qdora_linear.bias, NF4Tensor)
assert qdora_linear.bias.dtype == torch.bfloat16

# Note: with bfloat16 F.linear(x, weight, bias) != F.linear(x, weight) + bias.
# This means we would get different results (irrespective of QDoRA).
# So we leave that test case out
@pytest.mark.parametrize(
"use_bias, dtype",
[(False, torch.bfloat16), (True, torch.float32), (False, torch.float32)],
)
def test_qdora_parity(self, use_bias, dtype, dora_linear, qdora_linear):
with training.set_default_dtype(dtype):
torch.manual_seed(0)
qdora_linear = DoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
quantize_base=True,
qdora_linear = qdora_linear(
use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
)
torch.manual_seed(0)
dora_linear = DoRALinear(
in_dim=512,
out_dim=512,
rank=RANK,
alpha=ALPHA,
use_bias=False,
quantize_base=False,
dora_linear = dora_linear(
use_bias=use_bias, dtype=dtype, in_dim=512, out_dim=512
)

# set weight of dora_linear to unquantized weight of qdora_linear and check
# parity.
dora_linear.weight.data = qdora_linear.weight.to(dtype)

if use_bias:
dora_linear.bias.data = qdora_linear.bias.detach().clone()
qdora_linear.initialize_dora_magnitude()
dora_linear.initialize_dora_magnitude()
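
The bfloat16 note above is worth a quick illustration: a fused F.linear(x, weight, bias) and an unfused F.linear(x, weight) + bias can round differently at bf16 precision, which is why the (use_bias=True, torch.bfloat16) combination is excluded from the parity test. A small standalone sketch of the comparison:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 512, dtype=torch.bfloat16)
w = torch.randn(512, 512, dtype=torch.bfloat16)
b = torch.randn(512, dtype=torch.bfloat16)

fused = F.linear(x, w, b)       # bias added inside the fused kernel
unfused = F.linear(x, w) + b    # bias added as a separate bf16 op

# In float32 the two paths agree for the parity test above; in bfloat16 they
# can differ by rounding, so that case is left out of the parametrization.
print(torch.equal(fused, unfused))
print((fused - unfused).abs().max())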
