diff --git a/t5x/eval.py b/t5x/eval.py index 73448560f..ab94b0eb6 100644 --- a/t5x/eval.py +++ b/t5x/eval.py @@ -30,6 +30,7 @@ os.environ['FLAX_LAZY_RNG'] = 'no' from absl import logging from clu import metric_writers +import gin import jax import seqio from t5x import checkpoints @@ -61,6 +62,7 @@ def __call__( ... +@gin.configurable class InferenceEvaluator: """Runs evaluation of the model against a given SeqIo task."""