Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688497033
  • Loading branch information
gnecula authored and t5-copybara committed Oct 24, 2024
1 parent b642f30 commit 98bd34c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions t5x/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ def create_inference_function(
output_len: Optional[int] = None,
) -> Callable[[Mapping[str, Any], Any], PyTree]:
"""Fetches a model and returns the inference function based on inference_mode."""
# Always use native serialization. The non-native serialization is deprecated.
del native_lowering
if partitioner and train_state_initializer:
maybe_partition = lambda fn: partitioner.partition( # pylint:disable=g-long-lambda
fn,
Expand Down Expand Up @@ -390,13 +392,12 @@ def model_fn(
if jax2tf_disable_platform_checks
else []
)
if native_lowering and (not native_lowering_platforms):
if not native_lowering_platforms:
# Change default value to make the exported cpu model still work.
native_lowering_platforms = ['cpu', 'tpu']
model_fn = jax2tf.convert(
model_fn,
polymorphic_shapes=[None, polymorphic_shapes_inputs],
native_serialization=native_lowering,
native_serialization_platforms=native_lowering_platforms,
native_serialization_disabled_checks=disabled_checks,
enable_xla=enable_xla,
Expand Down Expand Up @@ -1517,8 +1518,7 @@ def save(
validation_examples: Optional list of validation examples. If proveded, they
will be used to validate the latency and numeric accuracy of the TPU saved
model.
native_lowering: for experimental purposes only -- if True, don't convert
Jax fns to TF fns.
native_lowering: deprecated, always True.
native_lowering_platforms: In conjunction with `native_lowering`, specify
the platform(s) for which to lower the code. Must be a tuple of strings,
including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default
Expand All @@ -1538,6 +1538,8 @@ def save(
create_polymorphic_shapes_fn: Optional function to create polymorphic shapes
for input tensors to the JAX model function.
""" # fmt: skip
# Always use native serialization. The non-native serialization is deprecated.
del native_lowering
jax.monitoring.record_event('/jax/t5x/export/beacon')
output_dirs = _standardize_output_dirs(output_dir)
del output_dir
Expand Down Expand Up @@ -1584,7 +1586,7 @@ def save(
if create_decoding_state_callback_fn is not None:
decoding_state_callback_fn = create_decoding_state_callback_fn(
vocab=output_vocab,
call_tf_graph=native_lowering,
call_tf_graph=True,
)

model_tf_fn = create_inference_function_fn(
Expand All @@ -1598,7 +1600,7 @@ def save(
polymorphic_shapes_inputs=create_polymorphic_shapes_fn(
input_signature, preprocessor
),
native_lowering=native_lowering,
native_lowering=True,
native_lowering_platforms=native_lowering_platforms,
)

Expand Down

0 comments on commit 98bd34c

Please sign in to comment.