diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py
index 3652813c12..68277d6dfc 100644
--- a/torchtune/models/convert_weights.py
+++ b/torchtune/models/convert_weights.py
@@ -249,6 +249,7 @@ def tune_to_peft_adapter_weights(
     num_heads: int = 32,
     num_kv_heads: int = 32,
     dim: int = 4096,
+    head_dim: int = None,
 ):
     converted_state_dict = {}
     full_mapping = {}
@@ -266,7 +267,8 @@ def tune_to_peft_adapter_weights(
             }
         )
 
-    head_dim = dim // num_heads
+    if head_dim is None:
+        head_dim = dim // num_heads
 
     def _permute_lora_matrix(t, n_heads):
         rank = t.shape[-1]
diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py
index 83ffda0255..f43ee31d9a 100644
--- a/torchtune/utils/_checkpointing/_checkpointer.py
+++ b/torchtune/utils/_checkpointing/_checkpointer.py
@@ -537,6 +537,7 @@ def save_checkpoint(
                 num_heads=self._config["num_attention_heads"],
                 num_kv_heads=self._config["num_key_value_heads"],
                 dim=self._config["hidden_size"],
+                head_dim=self._config.get("head_dim", None),
             )
             peft_output_path = Path.joinpath(
                 self._output_dir, "adapter_model"
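
Note: the change matters for configs whose attention head dimension is not hidden_size // num_attention_heads (Gemma-style models, for example), where inferring head_dim from dim produces wrongly shaped permutations. Below is a minimal sketch of the config-driven override with fallback that the checkpointer now performs; the config values are illustrative Gemma-7B-like numbers assumed here, not taken from this diff.

    # Sketch of the fallback logic added above, assuming an HF-style config dict.
    # "head_dim" is typically absent from Llama-style configs and present for
    # Gemma-style ones, where it differs from hidden_size // num_attention_heads.
    config = {"hidden_size": 3072, "num_attention_heads": 16, "head_dim": 256}

    head_dim = config.get("head_dim", None)
    if head_dim is None:
        head_dim = config["hidden_size"] // config["num_attention_heads"]

    print(head_dim)  # 256; the old inference would have used 3072 // 16 == 192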