From d9c5cfd61ac890b8ba9e184a9290db9e794bd052 Mon Sep 17 00:00:00 2001 From: xuewc Date: Thu, 1 Feb 2024 06:39:18 +0800 Subject: [PATCH] refactor: util routine --- src/elisa/util/__init__.py | 0 src/elisa/{util.py => util/config.py} | 4 ++-- src/elisa/util/typing.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 src/elisa/util/__init__.py rename src/elisa/{util.py => util/config.py} (86%) create mode 100644 src/elisa/util/typing.py diff --git a/src/elisa/util/__init__.py b/src/elisa/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/elisa/util.py b/src/elisa/util/config.py similarity index 86% rename from src/elisa/util.py rename to src/elisa/util/config.py index e8772bdb..66fbe0ff 100644 --- a/src/elisa/util.py +++ b/src/elisa/util/config.py @@ -1,4 +1,4 @@ -"""Helper functions.""" +"""Helper functions for computation environment configuration.""" from __future__ import annotations import warnings @@ -10,7 +10,7 @@ def jax_enable_x64(use_x64: bool) -> None: - """Changes the default float precision of array in JAX. + """Changes the default float precision of arrays in JAX. Parameters ---------- diff --git a/src/elisa/util/typing.py b/src/elisa/util/typing.py new file mode 100644 index 00000000..9bd7f5fd --- /dev/null +++ b/src/elisa/util/typing.py @@ -0,0 +1,14 @@ +from typing import Union + +import jax.numpy as jnp +import numpy as np +from jax import Array + +FloatType = jnp.result_type(float) +IntType = jnp.result_type(int) + +JAXFloat = Array +PRNGKey = Array +JAXArray = Array +NumpyArray = np.ndarray +Array = Union[NumpyArray, JAXArray]