diff --git a/t5x/eval.py b/t5x/eval.py index 73448560f..66061bd2d 100644 --- a/t5x/eval.py +++ b/t5x/eval.py @@ -30,6 +30,7 @@ os.environ['FLAX_LAZY_RNG'] = 'no' from absl import logging from clu import metric_writers +import gin import jax import seqio from t5x import checkpoints @@ -61,6 +62,7 @@ def __call__( ... +# @gin.configurable class InferenceEvaluator: """Runs evaluation of the model against a given SeqIo task.""" @@ -402,7 +404,6 @@ def _maybe_run_train_eval(train_state: train_state_lib.TrainState): from absl import app from absl import flags import fiddle as fdl - import gin from t5x import config_utils FLAGS = flags.FLAGS