jaxtyping is a library providing type annotations and runtime type-checking for:

* shape and dtype of JAX arrays
* PyTrees