Releases: patrick-kidger/jaxtyping

jaxtyping v0.2.34

01 Sep 14:44
  • Compatibility with ray -- this fixes crashes with the error message "has no attribute 'index_variadic'" (#198, #237)
  • Compatibility with Python 3.12: fixed deprecation warnings about ast.Str. (#236, thanks @phinate!)
  • Error message improvements + doc improvements (#233, #240, thanks @padix-key and @jjyyxx!)

New Contributors

Full Changelog: v0.2.33...v0.2.34

jaxtyping v0.2.33

12 Jul 12:02
  • Compatibility with Python 3.10 when using Any as the array type.
  • Compatibility with generic array types.
  • Array TypeVars now respect their __constraints__.

Full Changelog: v0.2.32...v0.2.33

jaxtyping v0.2.32

12 Jul 09:44
  • The array type can now be either Any or a TypeVar. In both cases this means that anything is allowed at runtime. As usual, static type checkers will only look at the array part of an annotation, so that an annotation of the form Float[T, "foo bar"] (where T = TypeVar("T")) will be treated as just T by static type checkers. This allows for expressing array-type-polymorphism with static typechecking. Here's an example:

    from typing import TypeVar

    import numpy as np
    import torch
    from jaxtyping import Float

    # Static type checkers only see TensorLike; at runtime, the constraints restrict the array type.
    TensorLike = TypeVar("TensorLike", np.ndarray, torch.Tensor)

    def stack_scalars(x: Float[TensorLike, ""], y: Float[TensorLike, ""]) -> Float[TensorLike, "2"]:
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            return np.stack([x, y])
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            return torch.stack([x, y])
        else:
            raise ValueError("Invalid array types!")
  • Fixed a bug in which the very first argument to a function was erroneously reported as the one at fault for a typechecking error. This bug occurred when using default arguments.

Full Changelog: v0.2.31...v0.2.32

jaxtyping v0.2.31

25 Jun 18:34
  • jaxtyping now duck-types on array shapes and dtypes, so you can use it for your own custom array-like objects:

    import jaxtyping

    class FooDtype(jaxtyping.AbstractDtype):
        dtypes = ["foo"]  # dtype names accepted by this annotation

    class MyArray:
        @property
        def dtype(self):
            return "foo"

        @property
        def shape(self):
            return (3, 1, 4)

    def f(x: FooDtype[MyArray, "3 1 4"]): ...
  • Improved compatibility when typeguard warns that you're typechecking a function without annotations: it will no longer mention the jaxtyping-internal check_params function and will instead mention the name of the function that is missing annotations.

  • Improved the error message when typechecking fails, to state the full some_module.SomeClass.some_method rather than just some_method.

  • Fixed a JAX deprecation warning for jax.tree_map. (Thanks @groszewn!)

New Contributors

Full Changelog: v0.2.30...v0.2.31

jaxtyping v0.2.30

13 Jun 17:25
  • Now reporting the correct source code line numbers when using the import hook. Makes debuggers useful again! #214
  • Now supports numpy structured dtypes (Thanks @alexfanqi! #211)
  • Now respecting typing.no_type_check. #216

New Contributors

Full Changelog: v0.2.29...v0.2.30

jaxtyping v0.2.29

27 May 14:29
  • Crash fix for when jax is available but jaxlib is not. (Thanks @ar0ck! #191)
  • Crash fix when used alongside old TensorFlow versions that don't support tensor.ndim (Thanks @dziulek! #193)
  • Crash fix when using a default argument as a symbolic dimension size. (Thanks @jaraujo98! #208)
  • Improved import times by defining the IPython magic lazily. (Thanks @superbobry! #201)
  • The import hook will now typecheck functions that do not have any annotations in the arguments or return value. This is useful for those who do manual isinstance checks in the body of the function. (Thanks @nimashoghi! #205)
  • Dropped the dependency on numpy. This makes it possible to just use jaxtyping+typeguard as the one-stop-shop for all runtime typechecking, even when you're not using arrays. Obviously that's a little unusual -- not really the main focus of jaxtyping -- but helps when wanting a single choice of runtime type checker across an entire codebase, only parts of which may use arrays. (#212)

New Contributors

Full Changelog: v0.2.28...v0.2.29

jaxtyping v0.2.28

07 Mar 17:31

Autogenerated release notes as follows:

What's Changed

Full Changelog: v0.2.27...v0.2.28

jaxtyping v0.2.27

06 Mar 19:38

Quick bugfix release:

  • Fixed some isinstance checks against variadics crashing (although this was when it was about to return False anyway). (Thanks @asford! #186)
  • Fixed docs for downstream libraries (Equinox, ...) not generating correctly (#182)

New Contributors

Full Changelog: v0.2.26...v0.2.27

jaxtyping v0.2.26

25 Feb 12:10

Features

  • Added jaxtyping.print_bindings to manually inspect the values of each axis, whilst inside a function.
  • Added support for jaxtyping.{Int4, UInt4}. (#174, thanks @jianlijianli!)

Bugfixes

  • We no longer import JAX at all, even if it is present. This ensures compatibility when using jaxtyping+PyTorch alongside an old JAX installation. (All JAX re-exports, like jaxtyping.Array = jax.Array, are looked up dynamically rather than at import time.) (#178)
  • We no longer raise false positives when @jaxtyped-ing generators (with yield statements). (#91, #171, thanks @knyazer!)

Internals

  • Added support for beartype's pseudostandard __instancecheck_str__ method. Instead of isinstance(x, Float[Array, "foo"]), one can now call Float[Array, "foo"].__instancecheck_str__(x), which will return either an empty string (success) or an error message describing why the check failed (wrong shape, wrong dtype, ...). In practice this feature probably isn't very usable yet; we'll need to wait until we've done a better job of ensuring compatibility between the jaxtyping import hooks and the beartype import hooks.

Docs

New Contributors

Full Changelog: v0.2.25...v0.2.26

jaxtyping v0.2.25

15 Dec 18:38

This release is primarily a usability release, designed to help ensure the library is being used correctly.

  • The error messages from a failed typecheck have been improved, to explicitly highlight more information about which argument was wrong. :)
  • If the jaxtyping.jaxtyped(typechecker=...) argument is not passed, then a warning will be displayed. In practice, this will trigger:
    • If using the old double-decorator syntax (@jaxtyped @beartype def foo(...): ...) -- upgrade to the new @jaxtyped(typechecker=beartype) def foo(...): ... syntax and get better error messages! :)
    • If making the easy mistake of writing @jaxtyped(beartype) def foo(...): ... -- in this case it's actually the beartype call that is jaxtype'd, not foo.
  • Incorrect use of jaxtyping annotations will now raise a jaxtyping.AnnotationError rather than a mix of RuntimeErrors, NameErrors, etc. For example, isinstance(x, Float) is not correct (you should write something like Float[Array, "..."] instead), and this will raise such an AnnotationError.
  • Introduced two config flags:
    • JAXTYPING_DISABLE=1 / jaxtyping.config.update("jaxtyping_disable", True): if enabled then all runtime type checking will be skipped.
    • JAXTYPING_REMOVE_TYPECHECKER_STACK=1 / jaxtyping.config.update("jaxtyping_remove_typechecker_stack", True): if enabled then type-checking errors will only show the jaxtyping.TypeCheckError, and won't include any extra stack trace from the underlying type-checker (beartype/typeguard). Some users have found that they preferred the conciseness over the extra information.

Full Changelog: v0.2.24...v0.2.25