Support runtime type-checking of generic functions #130

Open
davnn opened this issue Oct 16, 2023 · 1 comment
Labels
feature New feature

Comments

@davnn

davnn commented Oct 16, 2023

Hi,

Do you think it would be feasible to implement runtime type checks for generics, e.g. generic array types or generic data types (dtypes)? See the examples below.

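from typing import TypeVar

from beartype import beartype
from jax import Array as JaxArray
from jaxtyping import Float16, Float32, Float64, Shaped, jaxtyped
from numpy import ndarray as NumpyArray
from torch import Tensor as TorchArray

# A TypeVar constrained to one of several array libraries...
GenericArray = TypeVar("GenericArray", JaxArray, NumpyArray, TorchArray)
# ...and one constrained to one of several floating-point dtypes.
GenericFloat = TypeVar("GenericFloat", Float16, Float32, Float64)

# Desired: input and output must be the same kind of array
# (e.g. both NumPy), both with the same length n.
@jaxtyped
@beartype
def f(a: Shaped[GenericArray, "n"]) -> Shaped[GenericArray, "n"]:
    return a

# Desired: input and output must share the same float dtype.
# Proposed syntax only: a plain TypeVar cannot be subscripted like this today.
@jaxtyped
@beartype
def g(a: GenericFloat[NumpyArray, "n"]) -> GenericFloat[NumpyArray, "n"]:
    return a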

I would be happy to contribute, but I am unsure whether this is even possible.

@patrick-kidger
Owner

Yup, I think this should be possible! I'd be happy to take a PR on this.
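The core of the request is a binding rule: within a single call, every annotation that uses the same TypeVar must resolve to the same constraint. Below is a minimal sketch of that rule in plain Python. It assumes a hypothetical decorator `typevar_consistent`, which is not part of jaxtyping or beartype. The decorator only handles parameters and return values annotated directly with a constrained TypeVar and ignores shapes and dtypes entirely. It also uses `list` and `tuple` as stand-ins for the array types, so it runs without JAX, NumPy, or PyTorch.

```python
import functools
import inspect
from typing import TypeVar, get_type_hints


def typevar_consistent(fn):
    """Hypothetical decorator: every argument (and the return value) annotated
    with the same constrained TypeVar must resolve to the same constraint."""
    hints = get_type_hints(fn)
    sig = inspect.signature(fn)

    def resolve(tv, value):
        # Simplification: pick the first constraint the value is an instance of.
        for constraint in tv.__constraints__:
            if isinstance(value, constraint):
                return constraint
        raise TypeError(f"{value!r} matches none of {tv.__constraints__}")

    def bind(memo, tv, value):
        # The first occurrence fixes the binding; later occurrences must agree.
        constraint = resolve(tv, value)
        previous = memo.setdefault(tv, constraint)
        if previous is not constraint:
            raise TypeError(
                f"{tv.__name__} bound to {previous.__name__}, got {constraint.__name__}"
            )

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        memo = {}  # TypeVar -> constraint chosen for this call only
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if isinstance(hint, TypeVar) and hint.__constraints__:
                bind(memo, hint, value)
        result = fn(*args, **kwargs)
        ret = hints.get("return")
        if isinstance(ret, TypeVar) and ret.__constraints__:
            bind(memo, ret, result)
        return result

    return wrapper


# Stand-ins for JaxArray / NumpyArray / TorchArray.
GenericSeq = TypeVar("GenericSeq", list, tuple)


@typevar_consistent
def concat(a: GenericSeq, b: GenericSeq) -> GenericSeq:
    return a + b
```

Here `concat([1], [2])` and `concat((1,), (2,))` pass, while `concat([1], (2,))` raises `TypeError` before the body runs, because `GenericSeq` is already bound to `list`. The "first match wins" rule is a simplification. A real implementation would need a policy for overlapping constraints (one a subclass of another). It would also need a way to look inside `Shaped[...]`/`Float[...]` annotations rather than bare TypeVars.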

@patrick-kidger patrick-kidger added the feature New feature label Oct 16, 2023