From 5747dc360c58b439ddfb31e9fa7f57837c627cfe Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 6 Oct 2024 06:40:51 -0700 Subject: [PATCH] Replace usage of jax.experimental.host_callback.call with jax.experimental.io_callback. The jax.experimental.host_callback module is deprecated and will be removed. See https://github.com/google/jax/issues/20385. PiperOrigin-RevId: 682885092 --- t5x/decoding_test.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py index 13bf10fc9..ebac83a67 100644 --- a/t5x/decoding_test.py +++ b/t5x/decoding_test.py @@ -21,7 +21,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax -from jax.experimental import host_callback as hcb +from jax.experimental import io_callback import jax.numpy as jnp import numpy as np from t5x import decoding @@ -156,12 +156,10 @@ def callback_fn(current_index_and_sequences): sequences[i, current_index[i] + 1] = EOS_ID return sequences - sequences = hcb.call( + sequences = io_callback( callback_fn, + jax.ShapeDtypeStruct(state.sequences.shape, state.sequences.dtype), (state.cur_index, state.sequences), - result_shape=jax.ShapeDtypeStruct( - state.sequences.shape, state.sequences.dtype - ), ) return state.replace(sequences=sequences)