diff --git a/paxml/tasks/lm/params/nvidia.py b/paxml/tasks/lm/params/nvidia.py
index 718589554..7cf76204a 100644
--- a/paxml/tasks/lm/params/nvidia.py
+++ b/paxml/tasks/lm/params/nvidia.py
@@ -773,6 +773,7 @@ class Grok(NVIDIA1_3B):
   DCN_MESH_SHAPE = [1, 32, 1, 1]
 
   USE_ROPE = True
+  USE_TE_DPA = False
 
   def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     task_p = super().task()
@@ -802,6 +803,7 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
         combine_qkv=self.COMBINE_QKV,
         checkpoint_policy=self.CHECKPOINT_POLICY,
         use_fp8=self.USE_FP8,
+        use_te_dpa=self.USE_TE_DPA,
     )
     ## set sharding
     lm_cls = cast(
@@ -860,6 +862,7 @@ class Grok_Proxy(NVIDIA1_3B):
   DCN_MESH_SHAPE = [1, 8, 1, 1]
 
   USE_ROPE = True
+  USE_TE_DPA = False
 
   def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     task_p = super().task()
@@ -889,6 +892,7 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
         combine_qkv=self.COMBINE_QKV,
         checkpoint_policy=self.CHECKPOINT_POLICY,
         use_fp8=self.USE_FP8,
+        use_te_dpa=self.USE_TE_DPA,
     )
     ## set sharding
     lm_cls = cast(
@@ -953,6 +957,7 @@ class Grok_Proxy_PP(NVIDIA5B):
   MESH_AXIS_NAMES = ['stage', 'replica', 'data', 'data_expert', 'mdl']
 
   USE_ROPE = True
+  USE_TE_DPA = False
 
   def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     task_p = super().task()
@@ -984,6 +989,7 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
         num_pipeline_stages=self.NUM_STAGES,
         num_pipeline_microbatches=self.NUM_MICROBATCHES,
         use_fp8=self.USE_FP8,
+        use_te_dpa=self.USE_TE_DPA,
     )
     ## set sharding
     replica_axis='replica'
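For reviewers who want to exercise the new flag: a minimal sketch (not part of this diff) of how `USE_TE_DPA` could be flipped on in a derived experiment config. The `@experiment_registry.register` pattern follows the existing configs in `paxml/tasks/lm/params/nvidia.py`; the class name `GrokTEDPA` is hypothetical.

```python
# Hypothetical example config, not included in this PR: enables the
# Transformer Engine dot-product-attention path by overriding the new
# class attribute; task() then forwards it as use_te_dpa=self.USE_TE_DPA.
from paxml import experiment_registry
from paxml.tasks.lm.params import nvidia


@experiment_registry.register
class GrokTEDPA(nvidia.Grok):  # hypothetical name, for illustration only
  """Grok config with the TE dot-product-attention kernel enabled."""

  USE_TE_DPA = True
```

The default stays `False` in all three classes, so existing experiments keep their current attention path unless a config opts in.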