diff --git a/praxis/contrib/gpu/scripts_gpu/te_helper.py b/praxis/contrib/gpu/scripts_gpu/te_helper.py
new file mode 100644
index 00000000..6078ab6d
--- /dev/null
+++ b/praxis/contrib/gpu/scripts_gpu/te_helper.py
@@ -0,0 +1,402 @@
+import os
+
+from absl import logging
+
+from praxis import base_layer
+from praxis import pax_fiddle
+from praxis import pytypes
+from praxis import layers
+from praxis.layers.checkpoint_policy import AutodiffCheckpointType
+from praxis.layers import activations
+from praxis.layers import attentions, grouped_query_attention, multi_query_attention
+from praxis.layers import embedding_softmax
+from praxis.layers import normalizations
+from praxis.contrib.gpu.scripts_gpu.lora_layers import (
+    LoraAttentionProjection,
+    LoraCombinedQKVProjection,
+    LoraLinear,
+)
+
+try:
+    import transformer_engine.jax as te
+    import transformer_engine.jax.flax as te_flax
+    import transformer_engine.jax.praxis as te_praxis
+    _IS_TRANSFORMER_ENGINE_INSTALLED = True
+    import praxis.layers.repeats as praxis_repeat
+    # This is to make the Repeat module correctly generate the collections we need.
+    praxis_repeat.SCAN_VARIABLE_AXES.update({
+        base_layer.NON_PAX_VAR_COLLECTION[1]: 0,  # 1-idx = params_axes
+        te.fp8.FP8Helper.FP8_COLLECTION_NAME: 0,
+    })
+    TE_PIPELINE_EXTRA_VMAP_VAR_AXES = {
+        base_layer.NON_PAX_VAR_COLLECTION[1]: 0,  # 1-idx = params_axes
+        te.fp8.FP8Helper.FP8_COLLECTION_NAME: 0,
+    }
+
+    TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = [te.fp8.FP8Helper.FP8_COLLECTION_NAME]
+
+    ENABLE_TE_SP = bool(int(os.environ.get('ENABLE_TE_SP', '0')))
+
+except ModuleNotFoundError as e:
+    _IS_TRANSFORMER_ENGINE_INSTALLED = False
+    TE_PIPELINE_EXTRA_VMAP_VAR_AXES = {}
+    TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST = []
+    ENABLE_TE_SP = False
+
+LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
+JTensor = pytypes.JTensor
+
+
+class TransformerEngineHelperBase:
+
+    @staticmethod
+    def get_fprop_caller_of_stack_transformer(fprop, deterministic):
+        raise NotImplementedError
+
+    @staticmethod
+    def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id):
+        raise NotImplementedError
+
+    @staticmethod
+    def get_input_bld(original_bld, batch_axes, mdl_axis):
+        # This is used to specify the sharding pattern of inputs to the TransformerLayers.
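+        # `original_bld` is the (batch, seq_len, hidden) split-dims mapping of the
+        # activations; the TE-installed helper may override it when TE sequence
+        # parallelism is enabled (see TEInstalledHelper.get_input_bld below).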
+        raise NotImplementedError
+
+    @staticmethod
+    def get_bld_mapping_for_pipelined_transformer(xformer_layer_p):
+        raise NotImplementedError
+
+    @staticmethod
+    def check_checkpoint_policy(tpl):
+        raise NotImplementedError
+
+
+class TENotInstalledHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    def get_fprop_caller_of_stack_transformer(fprop, deterministic):
+        return fprop
+
+    @staticmethod
+    def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id):
+        layer_p.name = f'layer_{layer_id}'
+        layer_p.use_cross_attention = stacked_transformer_obj.use_cross_attention
+        layer_p.num_heads = stacked_transformer_obj.num_heads
+        layer_p.dim_per_head = stacked_transformer_obj.dim_per_head
+        layer_p.input_dims = stacked_transformer_obj.model_dims
+        layer_p.packed_input = stacked_transformer_obj.packed_input
+        layer_p.atten_dropout_prob = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob
+        layer_p.residual_dropout_prob = (
+            stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob
+        )
+        layer_p.relu_dropout_prob = stacked_transformer_obj.relu_dropout_prob or stacked_transformer_obj.dropout_prob
+        layer_p.hidden_dims = stacked_transformer_obj.hidden_dims
+        if stacked_transformer_obj.local_window_size is not None:
+            if isinstance(stacked_transformer_obj.local_window_size[0], tuple):
+                layer_p.tr_atten_tpl.local_window_size = stacked_transformer_obj.local_window_size[layer_id]
+            else:
+                layer_p.tr_atten_tpl.local_window_size = stacked_transformer_obj.local_window_size
+
+        if stacked_transformer_obj.residual_droppath_prob > 0.0:
+            layer_p.residual_droppath_prob = (
+                stacked_transformer_obj.residual_droppath_prob * layer_id / max(1, stacked_transformer_obj.num_layers)
+            )
+        return layer_p
+
+    @staticmethod
+    def get_input_bld(original_bld, *_):
+        return original_bld
+
+    @staticmethod
+    def get_bld_mapping_for_pipelined_transformer(xformer_layer_p):
+        return xformer_layer_p.tr_atten_tpl.activation_split_dims_mapping.bld
+
+    @staticmethod
+    def check_checkpoint_policy(_):
+        """Every checkpoint policy is valid without TE."""
+        pass
+
+
+class TEInstalledHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    def get_fprop_caller_of_stack_transformer(_, deterministic):
+        def _fprop(
+            transformer,
+            x_in,
+            paddings,
+            attention_mask,
+            cross_inputs,
+            cross_attention_mask,
+            segment_pos
+        ):
+            x_out = transformer(
+                inputs=x_in,
+                attention_mask=attention_mask,
+                encoded=cross_inputs,
+                encoder_decoder_mask=cross_attention_mask,
+                deterministic=deterministic)
+            return x_out
+        return _fprop
+
+    @staticmethod
+    def set_layer_params_to_stack_transformer(stacked_transformer_obj, _, layer_id):
+        te_transformer_tpl = pax_fiddle.Config(te_praxis.TransformerLayer,
+            name=f'layer_{layer_id}',
+            enable_relative_embedding=False,
+            enable_sequence_parallel=ENABLE_TE_SP,
+            transpose_batch_sequence=False
+        )
+
+        def update_ln_te_tpl(te_tpl, transformer_layer_tpl):
+            # TE requires that all normalization layers share the same config.
+            assert transformer_layer_tpl.ln_tpl == transformer_layer_tpl.tr_fflayer_tpl.ln_tpl
+            ln_tpl = transformer_layer_tpl.ln_tpl
+            if issubclass(ln_tpl.cls, normalizations.LayerNorm):
+                te_tpl.layernorm_type = 'layernorm'
+                assert ln_tpl.use_scale
+                assert ln_tpl.use_bias
+            elif issubclass(ln_tpl.cls, normalizations.RmsNorm):
+                te_tpl.layernorm_type = 'rmsnorm'
+            else:
+                raise ValueError(f'Unsupported {ln_tpl.cls=}; only LayerNorm and RmsNorm are supported.')
+            te_tpl.zero_centered_gamma = not ln_tpl.direct_scale
+            te_tpl.layernorm_epsilon = ln_tpl.epsilon
+            return te_tpl
+
+        def update_ff_te_tpl(te_tpl, ff_layer_tpl):
+            mlp_activations = ()
+            if ff_layer_tpl.use_gated_activation:
+                mlp_activations += ('linear',)
+
+            if issubclass(ff_layer_tpl.activation_tpl.cls, activations.Identity):
+                mlp_activations += ('linear',)
+            else:
+                mlp_activations += (ff_layer_tpl.activation_tpl.cls.__name__.lower(),)
+
+            te_tpl.mlp_activations = mlp_activations
+            return te_tpl
+
+        def update_attn_te_tpl(te_tpl, attn_tpl):
+            if issubclass(attn_tpl.cls, attentions.DotProductAttention):
+                # Check that the DotProductAttention parameters are aligned with TE's attention.
+                assert attn_tpl.internal_enable_query_scale or attn_tpl.scale_logits_by_head_dims
+                assert not (attn_tpl.internal_enable_query_scale and attn_tpl.scale_logits_by_head_dims)
+                assert not attn_tpl.internal_enable_per_dim_scale
+                assert not attn_tpl.scale_query_by_dim_per_head
+                assert not attn_tpl.dconv_qkv
+                assert not attn_tpl.internal_gshard_gaussian_init
+                assert attn_tpl.relative_bias_tpl is None
+                assert attn_tpl.attention_extra_logit is None
+                assert attn_tpl.ngrammer_tpl is None
+                te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb
+                if attn_tpl.rotary_position_emb_tpl.cls == embedding_softmax.RotaryPositionalEmbedding:
+                    te_tpl.rotary_pos_emb_group_method = 'alternate'
+            elif issubclass(attn_tpl.cls, grouped_query_attention.GroupedQueryAttention):
+                te_tpl.num_gqa_groups = attn_tpl.num_kv_heads
+                if attn_tpl.rope_min_max_timescales is not None:
+                    te_tpl.enable_rotary_pos_emb = True
+                    te_tpl.rotary_pos_emb_windows = attn_tpl.rope_min_max_timescales
+                assert attn_tpl.atten_temp == 1.
+            elif issubclass(attn_tpl.cls, multi_query_attention.MultiQueryDotProductAttention):
+                te_tpl.num_gqa_groups = attn_tpl.num_kv_heads
+                te_tpl.enable_rotary_pos_emb = attn_tpl.use_rotary_position_emb
+                if attn_tpl.rotary_position_emb_tpl.cls == embedding_softmax.RotaryPositionalEmbedding:
+                    te_tpl.rotary_pos_emb_group_method = 'alternate'
+            else:
+                raise ValueError(f'Unsupported {attn_tpl.cls=}')
+            assert attn_tpl.atten_logit_cap <= 0., 'atten_logit_cap > 0. is not supported in TE'
+            te_tpl.scaled_query_init = False
+            te_tpl.scale_attn_logits = True
+            return te_tpl
+
+        transformer_layer_tpl = stacked_transformer_obj.transformer_layer_params_tpl
+        # Update TE normalization layer configs.
+        te_transformer_tpl = update_ln_te_tpl(te_transformer_tpl, transformer_layer_tpl)
+        # Update TE feed-forward layer configs.
+        te_transformer_tpl = update_ff_te_tpl(te_transformer_tpl, transformer_layer_tpl.tr_fflayer_tpl)
+        # Update TE attention layer configs.
+        te_transformer_tpl = update_attn_te_tpl(te_transformer_tpl, transformer_layer_tpl.tr_atten_tpl)
+        # TE currently only allows the bias config to be the same across the feed-forward,
+        # QKV projection, and output projection layers.
+        assert (transformer_layer_tpl.tr_fflayer_tpl.has_bias ==
+                transformer_layer_tpl.tr_atten_tpl.use_bias), "TE only allows same bias settings."
+        te_transformer_tpl.use_bias = transformer_layer_tpl.tr_fflayer_tpl.has_bias
+        te_transformer_tpl.self_attn_mask_type = 'padding_causal' \
+            if stacked_transformer_obj.mask_self_attention else 'padding'
+
+        te_transformer_tpl.logical_axes_rules = te_flax.extend_logical_axis_rules(tuple())
+        te_transformer_tpl.params_init = stacked_transformer_obj.params_init
+        te_transformer_tpl.dtype = stacked_transformer_obj.fprop_dtype
+        te_transformer_tpl.layer_type = te_praxis.TransformerLayerType.DECODER if stacked_transformer_obj.use_cross_attention \
+            else te_praxis.TransformerLayerType.ENCODER
+        te_transformer_tpl.num_attention_heads = stacked_transformer_obj.num_heads
+        te_transformer_tpl.hidden_size = stacked_transformer_obj.model_dims
+        te_transformer_tpl.mlp_hidden_size = stacked_transformer_obj.hidden_dims
+        te_transformer_tpl.dropout_rng_name = base_layer.RANDOM
+        te_transformer_tpl.attention_dropout = stacked_transformer_obj.atten_dropout_prob or stacked_transformer_obj.dropout_prob
+        te_transformer_tpl.hidden_dropout = stacked_transformer_obj.residual_dropout_prob or stacked_transformer_obj.dropout_prob
+        te_transformer_tpl.intermediate_dropout = stacked_transformer_obj.relu_dropout_prob or stacked_transformer_obj.dropout_prob
+        if stacked_transformer_obj.residual_droppath_prob > 0.0:
+            te_transformer_tpl.drop_path = (
+                stacked_transformer_obj.residual_droppath_prob * layer_id / max(1, stacked_transformer_obj.num_layers)
+            )
+
+        assert stacked_transformer_obj.dim_per_head == stacked_transformer_obj.model_dims // stacked_transformer_obj.num_heads
+        assert stacked_transformer_obj.packed_input == False
+        assert len(stacked_transformer_obj.moe_layers) == 0
+        assert stacked_transformer_obj.ngrammer_tpls is None
+        assert stacked_transformer_obj.local_window_size is None
+
+        def update_lora_te_tpl(te_tpl, transformer_layer_tpl):
+            lora_enabled = False
+            te_lora_scope = "none"
+            lora_rank = None
+            mlp_included_in_lora = False
+            if (
+                transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.__fn_or_cls__
+                is LoraLinear
+            ):
+                lora_enabled = True
+                mlp_included_in_lora = True
+                current_rank = (
+                    transformer_layer_tpl.tr_fflayer_tpl.fflayer_tpl.linear_tpl.rank
+                )
+                lora_rank = (
+                    current_rank if lora_rank is None else lora_rank & current_rank
+                )
+
+            attention_included_in_lora = False
+            if (
+                hasattr(transformer_layer_tpl.tr_atten_tpl, "combined_qkv_proj_tpl")
+                and transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.__fn_or_cls__
+                is LoraCombinedQKVProjection
+            ):
+                lora_enabled = True
+                attention_included_in_lora = True
+                current_rank = (
+                    transformer_layer_tpl.tr_atten_tpl.combined_qkv_proj_tpl.rank
+                )
+                lora_rank = (
+                    current_rank if lora_rank is None else lora_rank & current_rank
+                )
+
+            if (
+                hasattr(transformer_layer_tpl.tr_atten_tpl, "proj_tpl")
+                and transformer_layer_tpl.tr_atten_tpl.proj_tpl.__fn_or_cls__
+                is LoraAttentionProjection
+            ):
+                lora_enabled = True
+                attention_included_in_lora = True
+                current_rank = transformer_layer_tpl.tr_atten_tpl.proj_tpl.rank
+                lora_rank = (
+                    current_rank if lora_rank is None else lora_rank & current_rank
+                )
+
+            if lora_enabled:
+                assert (
+                    lora_rank > 0
+                ), "LoRA rank should be the same for all layers and greater than 0."
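+                # Map which projections carry LoRA adapters onto TE's
+                # `low_rank_adaptation_scope` values ('all', 'exclude_mlp', 'mlp').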
+                if attention_included_in_lora and mlp_included_in_lora:
+                    te_lora_scope = "all"
+                elif attention_included_in_lora and not mlp_included_in_lora:
+                    te_lora_scope = "exclude_mlp"
+                elif mlp_included_in_lora and not attention_included_in_lora:
+                    te_lora_scope = "mlp"
+
+                te_tpl.low_rank_adaptation_scope = te_lora_scope
+                te_tpl.low_rank_adaptation_dim = lora_rank
+
+            return te_tpl
+
+        try:
+            te_transformer_tpl = update_lora_te_tpl(
+                te_transformer_tpl, transformer_layer_tpl
+            )
+        except Exception as e:
+            logging.warning(f"Unable to use LoRA with TE: {e}")
+
+        return te_transformer_tpl
+
+    @staticmethod
+    def get_input_bld(original_bld, batch_axes, mdl_axis):
+        if ENABLE_TE_SP:
+            return [batch_axes, mdl_axis, None]
+        return original_bld
+
+    @staticmethod
+    def get_bld_mapping_for_pipelined_transformer(_):
+        rules = te_flax.extend_logical_axis_rules(tuple())
+        # rules: [(batch_axis_name, ('replica', 'data')), ...]
+        batch_mapping = rules[0][1]
+        hidden_tp_mapping = rules[4][1]
+        # [Batch, Seqlen, Hidden]
+        bld_mapping = [batch_mapping, None, hidden_tp_mapping]
+        return bld_mapping
+
+    @staticmethod
+    def check_checkpoint_policy(tpl):
+        """Some checkpoint policies are not compatible with TE fused attention."""
+        if issubclass(tpl.cls, layers.transformers.StackedTransformer):
+            remat = tpl.remat
+            attention_dropout = tpl.atten_dropout_prob or tpl.dropout_prob
+        elif issubclass(tpl.cls, layers.transformers.StackedTransformerRepeated):
+            if not issubclass(tpl.block.cls, layers.transformers.StackedTransformer):
+                return
+            remat = True  # The current StackedTransformerRepeated always enables remat.
+            attention_dropout = tpl.block.atten_dropout_prob or tpl.block.dropout_prob
+        else:
+            raise ValueError(f'Unsupported class={tpl.cls}')
+
+        supported_checkpoint_policies = [
+            AutodiffCheckpointType.SAVE_CONTEXT,
+            AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ,
+            AutodiffCheckpointType.SAVE_DOT_FOR_MLPERF_200B,
+            AutodiffCheckpointType.SAVE_QUANTIZED,
+            AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS_FFN1,
+            AutodiffCheckpointType.SAVE_DOT_EXCEPT_LOGITS]
+        fused_attn_enabled = int(os.getenv("NVTE_FUSED_ATTN", "0"))
+        if remat and fused_attn_enabled and attention_dropout > 0.:
+            assert tpl.checkpoint_policy in supported_checkpoint_policies, \
+                "Fused attn in TE only permits policies that save 'context' tensors when dropout is " \
+                "enabled. This restriction is due to the maintenance of the dropout offset within TE, " \
+                "which is incompatible with JAX remat. Consequently, it is necessary to bypass " \
+                "recomputation in the attention layer when fused attention is activated. The supported " \
+                f"checkpoint_policies are {supported_checkpoint_policies} but the provided " \
+                f"checkpoint_policy is '{tpl.checkpoint_policy}'."
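+
+# Typical call flow (illustrative comments only; variable names below are
+# placeholders). With TE installed and ENABLE_TE=1 in the environment, the praxis
+# layers dispatch through the helper defined below:
+#
+#   helper = TransformerEngineHelper.get_helper()        # -> TEInstalledHelper
+#   helper.check_checkpoint_policy(stacked_transformer_tpl)
+#   layer_p = helper.set_layer_params_to_stack_transformer(
+#       stacked_transformer, layer_p, layer_id)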
+
+
+class TransformerEngineHelper(TransformerEngineHelperBase):
+
+    @staticmethod
+    def is_enabled_te():
+        enable_te = bool(int(os.environ.get("ENABLE_TE", "0")))
+        return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te)
+
+    @staticmethod
+    def get_helper():
+        if TransformerEngineHelper.is_enabled_te():
+            return TEInstalledHelper
+        return TENotInstalledHelper
+
+    @staticmethod
+    def get_fprop_caller_of_stack_transformer(fprop, deterministic):
+        return TransformerEngineHelper.get_helper().get_fprop_caller_of_stack_transformer(
+            fprop, deterministic)
+
+    @staticmethod
+    def set_layer_params_to_stack_transformer(stacked_transformer_obj, layer_p, layer_id):
+        return TransformerEngineHelper.get_helper().set_layer_params_to_stack_transformer(
+            stacked_transformer_obj, layer_p, layer_id)
+
+    @staticmethod
+    def get_input_bld(original_bld, batch_axes, mdl_axis):
+        return TransformerEngineHelper.get_helper().get_input_bld(
+            original_bld, batch_axes, mdl_axis)
+
+    @staticmethod
+    def get_bld_mapping_for_pipelined_transformer(xformer_layer_p):
+        return TransformerEngineHelper.get_helper().get_bld_mapping_for_pipelined_transformer(
+            xformer_layer_p)
+
+    @staticmethod
+    def check_checkpoint_policy(tpl):
+        return TransformerEngineHelper.get_helper().check_checkpoint_policy(tpl)
diff --git a/praxis/contrib/gpu/scripts_gpu/test_te_helper.py b/praxis/contrib/gpu/scripts_gpu/test_te_helper.py
new file mode 100644
index 00000000..c65a25d6
--- /dev/null
+++ b/praxis/contrib/gpu/scripts_gpu/test_te_helper.py
@@ -0,0 +1,86 @@
+from praxis import base_hyperparams
+from praxis import layers
+from praxis import pax_fiddle
+from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
+from paxml.contrib.gpu.scripts_gpu.llama_utils import BaseLLaMA
+from paxml.contrib.gpu.scripts_gpu.configs import Synthetic5B
+from paxml.tasks.lm.params.lm_cloud import SyntheticDataset
+
+import transformer_engine.jax.praxis as te_praxis
+
+
+class SyntheticLLaMA7B(BaseLLaMA, SyntheticDataset):
+    pass
+
+
+class TestGPT5B:
+
+    def test_te_tpl_convert(self):
+        task = Synthetic5B().task()
+        st_tpl = task.model.lm_tpl.stacked_transformer_tpl.block
+        te_tpl = TransformerEngineHelper().set_layer_params_to_stack_transformer(st_tpl, None, 0)
+        te_cls = base_hyperparams.instantiate(te_tpl)
+        assert te_cls.hidden_size == st_tpl.model_dims
+        assert te_cls.mlp_hidden_size == st_tpl.hidden_dims
+        assert te_cls.num_attention_heads == st_tpl.num_heads
+        assert te_cls.num_gqa_groups == te_cls.num_attention_heads
+        assert te_cls.layernorm_type == 'layernorm'
+        assert te_cls.layernorm_epsilon == 1e-5
+        assert te_cls.zero_centered_gamma == True
+        assert te_cls.hidden_dropout == 0.
+        assert te_cls.hidden_dropout_dims == ()
+        assert te_cls.attention_dropout == 0.
+        assert te_cls.intermediate_dropout == 0.
+        assert te_cls.intermediate_dropout_dims == ()
+        assert te_cls.mlp_activations == ('gelu',)
+        assert te_cls.use_bias == True
+        assert te_cls.apply_residual_connection_post_layernorm == False
+        assert te_cls.output_layernorm == False
+        assert te_cls.float32_attention_logits == False
+        assert te_cls.layer_type == te_praxis.TransformerLayerType.ENCODER
+        assert te_cls.self_attn_mask_type == 'padding_causal'
+        assert te_cls.self_attn_bias_type == None
+        assert te_cls.enable_rotary_pos_emb == False
+        assert te_cls.rotary_pos_emb_windows == (1, 10000)
+        assert te_cls.enable_relative_embedding == False
+        assert te_cls.drop_path == 0.
+        assert te_cls.transpose_batch_sequence == False
+        assert te_cls.scale_attn_logits == True
+        assert te_cls.scaled_query_init == False
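+
+
+# TestLLaMA7B below exercises the RmsNorm / gated-SiLU / rotary-embedding path of
+# the conversion; the expected values follow the 7B LLaMA config.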
+class TestLLaMA7B:
+
+    def test_te_tpl_convert(self):
+        task = SyntheticLLaMA7B().task()
+        st_tpl = task.model.lm_tpl.stacked_transformer_tpl
+        te_tpl = TransformerEngineHelper().set_layer_params_to_stack_transformer(st_tpl, None, 0)
+        te_cls = base_hyperparams.instantiate(te_tpl)
+        assert te_cls.hidden_size == 4096
+        assert te_cls.mlp_hidden_size == 16384
+        assert te_cls.num_attention_heads == 32
+        assert te_cls.num_gqa_groups == 32
+        assert te_cls.layernorm_type == 'rmsnorm'
+        assert te_cls.layernorm_epsilon == 1e-5
+        assert te_cls.zero_centered_gamma == False
+        assert te_cls.hidden_dropout == 0.
+        assert te_cls.hidden_dropout_dims == ()
+        assert te_cls.attention_dropout == 0.
+        assert te_cls.intermediate_dropout == 0.
+        assert te_cls.intermediate_dropout_dims == ()
+        assert te_cls.mlp_activations == ('linear', 'silu')
+        assert te_cls.use_bias == False
+        assert te_cls.apply_residual_connection_post_layernorm == False
+        assert te_cls.output_layernorm == False
+        assert te_cls.float32_attention_logits == False
+        assert te_cls.layer_type == te_praxis.TransformerLayerType.ENCODER
+        assert te_cls.self_attn_mask_type == 'padding_causal'
+        assert te_cls.self_attn_bias_type == None
+        assert te_cls.enable_rotary_pos_emb == True
+        assert te_cls.rotary_pos_emb_windows == (1, 10000)
+        assert te_cls.rotary_pos_emb_group_method == 'consecutive'
+        assert te_cls.enable_relative_embedding == False
+        assert te_cls.drop_path == 0.
+        assert te_cls.transpose_batch_sequence == False
+        assert te_cls.scale_attn_logits == True
+        assert te_cls.scaled_query_init == False
diff --git a/praxis/layers/pipeline.py b/praxis/layers/pipeline.py
index 62b65057..24a09953 100644
--- a/praxis/layers/pipeline.py
+++ b/praxis/layers/pipeline.py
@@ -28,6 +28,8 @@
 from praxis import py_utils
 from praxis import pytypes
 from praxis.layers import checkpoint_policy
+from praxis.contrib.gpu.scripts_gpu.te_helper import TE_PIPELINE_EXTRA_VMAP_VAR_AXES
+from praxis.contrib.gpu.scripts_gpu.te_helper import TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST
 
 NestedMap = py_utils.NestedMap
 JTensor = pytypes.JTensor
@@ -423,6 +425,7 @@ def layer_fprop(layer, *args, **kwargs):
             NON_TRAINABLE: 0,
             INTERMEDIATES: 0,
             HYPER_PARAMS: 0,
+            **TE_PIPELINE_EXTRA_VMAP_VAR_AXES
         },
         split_rngs={PARAMS: self.is_initializing(), RANDOM: True},
         metadata_params={
@@ -846,7 +849,7 @@ def _fill_nan_for_bubbles(x):
     #
     # Note that fprop should not use PARAMS rng because there is no var init.
     variable_carry = []
-    variable_broadcast = [PARAMS]
+    variable_broadcast = [PARAMS] + TE_PIPELINE_EXTRA_SCAN_VAR_BROADCAST
     if self.is_mutable_collection(NON_TRAINABLE):
       variable_carry.append(NON_TRAINABLE)
     else:
@@ -869,7 +872,7 @@ def _fill_nan_for_bubbles(x):
     if bf16_vars_to_convert is not None:
       scan_fn = nn.map_variables(
           scan_fn,
-          mapped_collections=[PARAMS],
+          mapped_collections=[PARAMS, 'fp8_meta_collection'],
           mutable=True,
           trans_in_fn=_get_to_f32_converter(bf16_vars_to_convert),
           trans_out_fn=_get_to_bf16_converter(bf16_vars_to_convert),
diff --git a/praxis/layers/transformer_models.py b/praxis/layers/transformer_models.py
index 8709b270..36ad3f63 100644
--- a/praxis/layers/transformer_models.py
+++ b/praxis/layers/transformer_models.py
@@ -33,6 +33,7 @@
 from praxis.layers import multi_query_attention
 from praxis.layers import normalizations
 from praxis.layers import transformers
+from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
 
 NestedMap = py_utils.NestedMap
 JTensor = pytypes.JTensor
@@ -539,6 +540,8 @@ def set_sharding_params_v1(
       if training_optimized
       else [batch_axes, None, None]
   )
+  bld = TransformerEngineHelper.get_input_bld(bld, batch_axes, mdl_axis)
+
   egcm = (
       [data_axis, None, None, mdl_axis]
       if training_optimized
diff --git a/praxis/layers/transformers.py b/praxis/layers/transformers.py
index fde24460..69ebba01 100644
--- a/praxis/layers/transformers.py
+++ b/praxis/layers/transformers.py
@@ -40,6 +40,7 @@
 from praxis.layers import repeats
 from praxis.layers import stats
 from praxis.layers import stochastics
+from praxis.contrib.gpu.scripts_gpu.te_helper import TransformerEngineHelper
 
 NestedMap = py_utils.NestedMap
 WeightInit = base_layer.WeightInit
@@ -1738,28 +1739,10 @@ def _layer_params(i):
         p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
       else:
         p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
-      p_i.name = f'layer_{i}'
-      p_i.use_cross_attention = self.use_cross_attention
-      p_i.num_heads = self.num_heads
-      p_i.dim_per_head = self.dim_per_head
-      p_i.input_dims = self.model_dims
-      p_i.packed_input = self.packed_input
-      p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
-      p_i.residual_dropout_prob = (
-          self.residual_dropout_prob or self.dropout_prob
-      )
-      p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
-      p_i.hidden_dims = self.hidden_dims
-      if self.local_window_size is not None:
-        if isinstance(self.local_window_size[0], tuple):
-          p_i.tr_atten_tpl.local_window_size = self.local_window_size[i]
-        else:
-          p_i.tr_atten_tpl.local_window_size = self.local_window_size
-      if self.residual_droppath_prob > 0.0:
-        p_i.residual_droppath_prob = (
-            self.residual_droppath_prob * i / max(1, self.num_layers)
-        )
+      TransformerEngineHelper.check_checkpoint_policy(self._to_fdl_config())
+
+      p_i = TransformerEngineHelper.set_layer_params_to_stack_transformer(self, p_i, i)
 
       if self.moe_layers and i in self.moe_layers:
         assert self.num_experts > 0
@@ -1878,6 +1861,8 @@ def _fprop(
       )
       return x_out
 
+    _fprop = TransformerEngineHelper.get_fprop_caller_of_stack_transformer(_fprop, self.do_eval)
+
     fprop = _fprop
     if self.remat:
       fprop = nn.remat(
@@ -2070,6 +2055,8 @@ class WeightSharding(base_layer.BaseLayer.WeightSharding):
 
   def setup(self) -> None:
     wp = self.weight_split_dims_mapping
+    TransformerEngineHelper.check_checkpoint_policy(self._to_fdl_config())
+
     repeat_l_params = pax_fiddle.Config(
         repeats.Repeat,
         sub_tpl=self.block,
@@ -2347,7 +2334,7 @@ def __call__(
     else:
       assert self.pipeline_stage.cls == StackedTransformerRepeated
       xformer_layer_p = self.pipeline_stage.block.transformer_layer_params_tpl
-      bld_mapping = xformer_layer_p.tr_atten_tpl.activation_split_dims_mapping.bld
+      bld_mapping = TransformerEngineHelper.get_bld_mapping_for_pipelined_transformer(xformer_layer_p)
     if not self.stream_io:
       # Annotate the inputs before the pipeline to prevent unexpected
       # propagation from earlier layers.