diff --git a/megatron/initialize.py b/megatron/initialize.py index af801efa40..f85944e821 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -211,6 +211,7 @@ def _initialize_distributed(): args.pipeline_model_parallel_size, args.virtual_pipeline_model_parallel_size, args.pipeline_model_parallel_split_rank, + args.fp8_e4m3 or args.fp8_hybrid, ) if args.rank == 0: print( diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 26717789e8..1ddd3adedd 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1337,7 +1337,7 @@ def __init__(self, config, if self.use_fp8: assert args.transformer_impl == 'transformer_engine', \ 'transformer-engine required for fp8 training and inference' - self.fp8_group = mpu.get_data_parallel_group() + self.fp8_group = mpu.get_amax_reduction_group() if args.fp8_e4m3: fp8_format = transformer_engine.common.recipe.Format.E4M3 elif args.fp8_hybrid: