Are there any plans to support Gemma2 in torchtitan? I tried to use torchtitan to finetune a Gemma2 model, but got stuck on the following problem: how do you parallelize the tied layers in the Gemma2 model? Maybe somebody knows the solution to this problem 😄
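For context, a minimal sketch of the tied-weight setup in question (class and attribute names here are illustrative, not the actual Gemma2 module names): the input embedding and the final projection share one parameter, so any parallelism plan has to account for both modules pointing at the same tensor.

```python
import torch.nn as nn

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # weight tying: the lm head reuses the embedding matrix
        self.output.weight = self.tok_embeddings.weight
```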
If you apply `fully_shard` to each transformer block and then to the root module, this should work for the tied embedding and final linear: the root module will manage both.
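A minimal sketch of that suggestion, assuming an FSDP2-style `fully_shard` API (the exact import path differs across recent PyTorch releases, and the `layers` attribute name is just an assumption): shard each transformer block individually, then shard the root module so that the parameters not covered by a block, including the tied embedding / output projection, end up in the root's group.

```python
from torch.distributed._composable.fsdp import fully_shard

def apply_fsdp(model, mesh):
    for block in model.layers:
        # each transformer block becomes its own FSDP parameter group
        fully_shard(block, mesh=mesh)
    # the root call picks up everything not already sharded,
    # including the tied embedding and final linear
    fully_shard(model, mesh=mesh)
    return model
```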
I want to shard the output embedding layer. I use the same strategy as in Llama, but training gets stuck after the first batch: `ColwiseParallel(input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel)`
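For reference, a sketch of the Llama-style tensor-parallel plan that snippet comes from (module names `tok_embeddings` and `output` follow torchtitan's Llama model; a Gemma2 port may name them differently, and with tied weights both entries below would target the same underlying parameter, which is exactly the conflict being asked about):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp(model, tp_mesh, loss_parallel: bool):
    parallelize_module(
        model,
        tp_mesh,
        {
            # embedding sharded row-wise over the vocab dimension
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            # final projection sharded column-wise; keep logits sharded
            # only when loss parallelism is enabled
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )
    return model
```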