From ce315ce17ae41dadd33d2279440e79ab7e63346f Mon Sep 17 00:00:00 2001
From: ebsmothers
Date: Fri, 12 Jul 2024 10:11:55 -0700
Subject: [PATCH] Fix Gemma 7B LoRA checkpoint save (#1169)

---
 torchtune/models/convert_weights.py             | 4 +++-
 torchtune/utils/_checkpointing/_checkpointer.py | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py
index 3652813c12..68277d6dfc 100644
--- a/torchtune/models/convert_weights.py
+++ b/torchtune/models/convert_weights.py
@@ -249,6 +249,7 @@ def tune_to_peft_adapter_weights(
     num_heads: int = 32,
     num_kv_heads: int = 32,
     dim: int = 4096,
+    head_dim: int = None,
 ):
     converted_state_dict = {}
     full_mapping = {}
@@ -266,7 +267,8 @@ def tune_to_peft_adapter_weights(
         }
     )
 
-    head_dim = dim // num_heads
+    if head_dim is None:
+        head_dim = dim // num_heads
 
     def _permute_lora_matrix(t, n_heads):
         rank = t.shape[-1]
diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py
index 83ffda0255..f43ee31d9a 100644
--- a/torchtune/utils/_checkpointing/_checkpointer.py
+++ b/torchtune/utils/_checkpointing/_checkpointer.py
@@ -537,6 +537,7 @@ def save_checkpoint(
                 num_heads=self._config["num_attention_heads"],
                 num_kv_heads=self._config["num_key_value_heads"],
                 dim=self._config["hidden_size"],
+                head_dim=self._config.get("head_dim", None),
             )
             peft_output_path = Path.joinpath(
                 self._output_dir, "adapter_model"
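
The fix matters for Gemma 7B because its config carries an explicit head_dim (256) that does not equal hidden_size // num_attention_heads (3072 // 16 = 192), so deriving head_dim from dim breaks the LoRA-to-PEFT weight permutation at checkpoint save. The following is a minimal sketch of that fallback logic under those assumptions; the permutation body, function name, and toy tensor shapes are illustrative stand-ins, not copied from torchtune:

import torch


def permute_lora_b(t: torch.Tensor, n_heads: int, dim: int, head_dim: int = None) -> torch.Tensor:
    """Illustrative permutation of a q_proj/k_proj LoRA B matrix into HF layout."""
    # The fallback the patch adds: derive head_dim from dim only when the
    # model config did not provide one (Gemma 7B's config provides head_dim=256).
    if head_dim is None:
        head_dim = dim // n_heads
    rank = t.shape[-1]
    # Interleaved-head reshuffle of the output dimension; the rank dim is untouched.
    return (
        t.view(n_heads, head_dim // 2, 2, rank)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, rank)
    )


# Gemma 7B-like values (hypothetical toy tensor; rank chosen arbitrarily).
dim, n_heads, head_dim, rank = 3072, 16, 256, 8
lora_b_q = torch.randn(n_heads * head_dim, rank)

print(permute_lora_b(lora_b_q, n_heads, dim, head_dim=head_dim).shape)  # torch.Size([4096, 8])
# Omitting head_dim here falls back to 3072 // 16 = 192, and the view() above
# raises a shape mismatch: the kind of Gemma 7B save failure this patch addresses.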