diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index 7b299d3c4..cc7910f2a 100644 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -25,6 +25,7 @@ from llava.mm_utils import get_anyres_image_grid_shape +from transformers.integrations import is_deepspeed_zero3_enabled class LlavaMetaModel: @@ -94,7 +95,12 @@ def initialize_vision_modules(self, model_args, fsdp=None): def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} - self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + if is_deepspeed_zero3_enabled(): + import deepspeed + with deepspeed.zero.GatheredParameters(self.mm_projector.parameters(), modifier_rank=0): + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + else: + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) def unpad_image(tensor, original_size):