Skip to content

Commit

Permalink
Replace usage of jax.experimental.host_callback.call with jax.experim…
Browse files Browse the repository at this point in the history
…ental.io_callback.

The jax.experimental.host_callback module is deprecated and will be removed.

See jax-ml/jax#20385.

PiperOrigin-RevId: 683196778
  • Loading branch information
gnecula authored and t5-copybara committed Oct 7, 2024
1 parent 705247b commit 9bb1e85
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions t5x/decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax.experimental import host_callback as hcb
from jax.experimental import io_callback
import jax.numpy as jnp
import numpy as np
from t5x import decoding
Expand Down Expand Up @@ -156,12 +156,10 @@ def callback_fn(current_index_and_sequences):
sequences[i, current_index[i] + 1] = EOS_ID
return sequences

sequences = hcb.call(
sequences = io_callback(
callback_fn,
jax.ShapeDtypeStruct(state.sequences.shape, state.sequences.dtype),
(state.cur_index, state.sequences),
result_shape=jax.ShapeDtypeStruct(
state.sequences.shape, state.sequences.dtype
),
)
return state.replace(sequences=sequences)

Expand Down

0 comments on commit 9bb1e85

Please sign in to comment.