From 028fab9b2b6bea329cb5f151c9e443cc8b5a7aa1 Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Wed, 11 Sep 2024 16:31:59 -0700 Subject: [PATCH] Rename `is_supported_empty_aggregation_type` and `is_supported_aggregation_type` functions. PiperOrigin-RevId: 673582513 --- t5x/checkpoint_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index c43a8f016..d817b61af 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -251,6 +251,12 @@ def detect_checkpoint_type( return checkpoint_type +def _is_supported_empty_value(value: Any) -> bool: + if hasattr(ocp.type_handlers, 'is_supported_empty_aggregation_type'): + return ocp.type_handlers.is_supported_empty_aggregation_type(value) + return ocp.type_handlers.is_supported_empty_value(value) + + def get_restore_parameters( directory: epath.Path, structure: PyTree, @@ -280,7 +286,7 @@ def _get_param_info( name: str, meta_or_value: Union[Any, ocp.metadata.tree.ValueMetadataEntry], ) -> Union[ocp.type_handlers.ParamInfo, Any]: - if ocp.type_handlers.is_supported_empty_aggregation_type(meta_or_value): + if _is_supported_empty_value(meta_or_value): # Empty node, ParamInfo should not be returned. return meta_or_value elif not isinstance(meta_or_value, ocp.metadata.tree.ValueMetadataEntry):