Skip to content

Commit

Permalink
refactor: util routine
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Jan 31, 2024
1 parent ebd55da commit d9c5cfd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
Empty file added src/elisa/util/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions src/elisa/util.py → src/elisa/util/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Helper functions."""
"""Helper functions for computation environment configuration."""
from __future__ import annotations

import warnings
Expand All @@ -10,7 +10,7 @@


def jax_enable_x64(use_x64: bool) -> None:
"""Changes the default float precision of array in JAX.
"""Changes the default float precision of arrays in JAX.
Parameters
----------
Expand Down
14 changes: 14 additions & 0 deletions src/elisa/util/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Union

import jax.numpy as jnp
import numpy as np
from jax import Array

FloatType = jnp.result_type(float)
IntType = jnp.result_type(int)

JAXFloat = Array
PRNGKey = Array
JAXArray = Array
NumpyArray = np.ndarray
Array = Union[NumpyArray, JAXArray]

0 comments on commit d9c5cfd

Please sign in to comment.