Support typechecking of jax.sharding.NamedSharding #164
Haha, thank you! I like this idea, and I think the approach is using Regarding the syntax, I don't think we can use Regarding meshes: morally speaking, I don't think we should actually need a mesh? That is, given a sharding during lowering, can we get a
That's a shame! Not obvious to me how a prefix would work, because we still need a separator between the dimension name and the sharding, so e.g. a prefix of Some ideas:
In net I probably go for
If Thus, since we can't reliably recover names, my approach was to go in the other direction: reliably map names to
I like your syntax suggestions! My only suggestion is to probably switch them around: As an alternative strategy, we could consider something like explicitly taking For the shardings, I'm curious what @yashk2810 thinks. (Although I don't know if he checks GitHub comments like this :) )
Given the discussion of The "tuple of mesh axes" part is not supported in the syntax discussion above. (It's sometimes useful to, e.g., express sharding simultaneously over the x and y axes of a TPU mesh or to express sharding of a reshaped tensor.) The natural extension to the syntaxes above would be to support tuples via comma, e.g.
Take a look and see which you think is more readable:
I mildly prefer (2), because putting the axis name first (before the sharding) seems to put slightly more emphasis on the axis name. Also, I think
I can see the appeal of reusing the existing type! The main disadvantages I see are:
Wonderful! Ok, let's rely on this functionality then.
Thanks @yashk2810! Okay, so on balance, I think I'm inclined to go with the
On the points you've raised:
When it comes to handling meshes, I suppose we should simply do an Does all of the above sound reasonable to you? If so, then I'd be happy to take a pull request implementing this. :)
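For concreteness, here is a minimal sketch of the parsing step, assuming the `name/axis` form proposed in this issue (e.g. `batch/data_parallel`) plus the comma extension for tuples of mesh axes floated in the thread (e.g. `batch/x,y`). The syntax eventually adopted may differ, and `parse_sharded_dims` is a hypothetical helper, not part of jaxtyping. It returns the dimension names plus a list of entries that could be passed to `jax.sharding.PartitionSpec(*entries)`:

```python
def parse_sharded_dims(spec_str):
    """Hypothetical helper: split "name/axis" tokens into dimension names and
    PartitionSpec-style entries. A dimension without "/" is unsharded (None);
    a comma-separated axis list becomes a tuple of mesh axes."""
    names, entries = [], []
    for token in spec_str.split():
        name, sep, axes = token.partition("/")
        if not name:
            raise ValueError(f"missing dimension name in {token!r}")
        names.append(name)
        if not sep:
            # No "/" at all: this dimension is not sharded.
            entries.append(None)
            continue
        axis_names = tuple(axes.split(","))
        if any(not a for a in axis_names):
            raise ValueError(f"empty mesh axis name in {token!r}")
        # A single axis is kept as a plain string, several as a tuple,
        # matching how PartitionSpec entries are usually written.
        entries.append(axis_names[0] if len(axis_names) == 1 else axis_names)
    return names, entries


names, entries = parse_sharded_dims("batch/data_parallel seqlen channel/tensor_parallel")
# names   == ['batch', 'seqlen', 'channel']
# entries == ['data_parallel', None, 'tensor_parallel']
```

Keeping the parser free of JAX imports means the dimension names can still be fed to the existing shape checks, while only the sharding half needs a mesh.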
Thanks Patrick. I think your reasoning is mostly valid, although I value things substantially differently from you (I don't care about non-JAX support, and I suspect I care much more about actually using this feature than you do :)), which makes me land in a different place than you. One place where I somewhat disagree with your reasoning:
There's a particular (perhaps idiosyncratic to me) way of viewing things where this is not true. By way of example, let I recognize this may be a view that is somewhat idiosyncratic to users of I understand you've made your decision and I'm not trying to relitigate it. I think the syntax you've proposed is workable if not (for me) perfect. If I want a different syntax in my own codebases (where I am free from the constraints of non-JAX support, where I want to use
Sounds great.
Good enough! Happy to take a stab when I get some time. Might take some time...
I love jaxtyping! Can I have more of it please?
Specifically, I'd like to make assertions about the sharding of my `jax.Array` objects. Given an array `Float[Array, "batch seqlen channel"]`, I'd like to assert its sharding with syntax like this: `Float[ShardedArray, "batch/data_parallel seqlen channel/tensor_parallel"]`. This syntax is a commonly used plain-text representation for shardings, following e.g. the notation in Figure 5 of Efficiently Scaling Transformer Inference.

The intention is that the sharding part of this syntax would parse to a sharding spec of `jax.sharding.PartitionSpec('data_parallel', None, 'tensor_parallel')`. We could then assert equivalence of this partition spec against the array's actual sharding using a combination of `jax.debug.inspect_array_sharding` and `jax.sharding.XLACompatibleSharding.is_equivalent_to`.

There's a small hiccup: to convert a `jax.sharding.PartitionSpec` to a `jax.sharding.NamedSharding`, we need a `jax.sharding.Mesh`, which is non-constant data (it contains JAX device objects) that is undesirable to put in a type signature. I think the best user experience would be to put the mesh in a thread-local; perhaps even the one that JAX already uses for (now-superseded) pjit: `jax._src.mesh.thread_resources.env.physical_mesh` (unfortunately, this is private). In that case, the sharding assertion could look like this:

A complete Colab notebook that tries this out on 8 CPUs and shows that it also works under `jit`: https://colab.research.google.com/drive/1oLy66BjKOWmh7dFu8aZbo_gBypDtlNeQ?usp=sharing