Skip to content

Commit

Permalink
Running Partitioner.compile within Mesh context-manager
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxu1067 committed Nov 22, 2023
1 parent 4b8a1ad commit d08a684
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion t5x/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,13 @@ def lower(self, *args, **kwargs):
self._logical_axis_rules):
return self._pjitted_fn.lower(*args, **kwargs)

def lower_and_compile(self, *args, **kwargs):
with Mesh(self._mesh.devices,
self._mesh.axis_names), flax_partitioning.axis_rules(
self._logical_axis_rules):
return self._pjitted_fn.lower(*args, **kwargs).compile()



class BasePjitPartitioner(BasePartitioner):
"""Partitioner that uses T5X version of jax.pjit."""
Expand Down Expand Up @@ -816,7 +823,7 @@ def partition(

def compile(self, partitioned_fn: PjittedFnWithContext,
*args) -> CompiledPartitionedCallable:
return partitioned_fn.lower(*args).compile()
return partitioned_fn.lower_and_compile(*args)


class PjitPartitioner(BasePjitPartitioner):
Expand Down

0 comments on commit d08a684

Please sign in to comment.