ezyang's blog

the arc of software bends towards understanding

JAX

Replicate Forwards, Partial Backwards

A central thesis of sharding in types is that the backward sharding can be directly computed from the forward sharding. This is not true for DTensor today, e.g., as seen in sum to Partial, and it leads to confusion where users cannot easily predict what the shardings of the tensors in their program will be. The question now arises: given a forward sharding, what should its backward sharding be? There are some easy cases to fill in:
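To make this concrete before diving in, here is a toy NumPy sketch (not DTensor or JAX code; the two "devices" are just list entries) of why a replicated forward input to a column-parallel linear naturally pairs with a partial, pending-sum backward gradient:

```python
import numpy as np

# Toy column-parallel linear on two simulated "devices" (no real distribution).
x = np.random.randn(4, 8)            # input: replicated on both devices
w = np.random.randn(8, 6)            # full weight
w_shards = np.split(w, 2, axis=1)    # each device holds an 8x3 column shard

# Forward: each device computes its local output shard; no communication needed.
y_shards = [x @ w_i for w_i in w_shards]

# Backward for x: grad_x = grad_y @ w.T.  Each device can only form a partial
# contribution from its own columns, so the per-device results are "Partial":
# the true grad_x is their sum (the pending reduction).
grad_y = np.ones((4, 6))
grad_y_shards = np.split(grad_y, 2, axis=1)
grad_x_partials = [gy @ w_i.T for gy, w_i in zip(grad_y_shards, w_shards)]
grad_x = sum(grad_x_partials)        # Partial -> Replicate via an all-reduce

assert np.allclose(grad_x, grad_y @ w.T)   # matches the single-device formula
```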

Read more...

The JAX sharding type system

Conventionally, a type system is something that classifies values into data types like float32 or int64. However, fancy type systems go beyond data types, allowing us to talk about potentially arbitrary invariants on data; for example, if we were to talk about the “type” of an array, it would cover not only its data type, but also its shape, e.g., f32[40, 20]. JAX’s type system of abstract values (avals) goes further than just data types and shapes and is equipped to reason about sharding-related invariants. However, this type system is poorly documented, especially recent additions like reduced/unreduced axes (circa June 2025). In this blog post, I want to give a consolidated description of the sharding-related aspects of JAX’s typing in explicit sharding mode, as of 2026. Disclaimer: I am not a JAX developer, and there may be mistakes in this presentation; please let me know about errors on Twitter. I will assume that you have some knowledge about how to work with JAX sharding in the frontend; please refer to Distributed arrays and automatic parallelization, Explicit sharding and Manual parallelism with shard_map for a refresher on these topics.
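As a taste of what sharding showing up in the type looks like, here is a minimal sketch (assuming a recent JAX; helpers like jax.make_mesh, jax.sharding.use_mesh and jax.typeof have been shifting spellings across versions, so treat the exact names loosely):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # fake 2 CPU devices

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# An explicit-sharding mesh: the sharding becomes part of the abstract value (aval).
mesh = jax.make_mesh((2,), ("x",), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
    a = jax.device_put(jnp.zeros((40, 20), jnp.float32),
                       NamedSharding(mesh, P("x", None)))
    print(jax.typeof(a))   # prints something like float32[40@x,20]
```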

Read more...

Global vs Local SPMD

Global SPMD (also known as the “global view”, exposed by code using DTensor or jax.Array) refers to writing multi-device code as if it were running on a single device, with an orthogonal mechanism for expressing how these full tensors are distributed over multiple devices (this mechanism can be implicit or explicit, e.g., as seen in this table).

Local SPMD (also known as the “per-device view”, exposed by local_map and shard_map, as well as by traditional PyTorch distributed code operating on plain Tensors, e.g., Megatron-style) refers to writing code from the “local” view of a single device, with explicit collectives when communicating across devices.
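To make the contrast concrete, here is a minimal JAX sketch of the same reduction written both ways, using four fake CPU devices: the global version leans on the compiler to insert the collective, while the local shard_map version calls psum explicitly.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # fake 4 CPU devices

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # newer JAX also has jax.shard_map

mesh = jax.make_mesh((4,), ("i",))
x = jax.device_put(jnp.arange(16.0), NamedSharding(mesh, P("i")))

# Global SPMD: write the program as if it ran on one device; the compiler
# figures out that a cross-device reduction is needed.
global_sum = jax.jit(jnp.sum)(x)

# Local SPMD: write per-device code over a 4-element block and call the
# collective (psum) yourself.
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P())
def local_sum(x_block):
    return jax.lax.psum(jnp.sum(x_block), "i")

assert jnp.allclose(global_sum, local_sum(x))
```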

Read more...

Megatron via shard_map

In Computing sharding with einsum, we worked through an example of Megatron-style tensor parallelism where we discovered that the ordinary backwards formula for linear results in a pending reduction on grad_input, even though the input was replicated and no communication happened in forwards. In Megatron, which is implemented with plain Tensors and manual collectives, you just have to know that this reduction is necessary and manually insert it with a custom autograd function.
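For reference, the manual Megatron-style fix is roughly the following sketch (not Megatron-LM's actual code; it assumes torch.distributed has been initialized and the tensor-parallel ranks form the default group): an autograd Function that is the identity in forward and all-reduces the gradient in backward, wrapped around the input of the column-parallel linear.

```python
import torch
import torch.distributed as dist

class CopyToTensorParallelRegion(torch.autograd.Function):
    """Identity in forward; all-reduce the gradient in backward.

    This is the Megatron-style pattern for a column-parallel linear whose input
    is replicated: the forward needs no communication, but grad_input is a
    partial sum that must be reduced across the tensor-parallel group.
    """

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output)   # assumes an initialized process group
        return grad_output

# Usage inside the layer: out = local_linear(CopyToTensorParallelRegion.apply(x))
```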

Read more...

Learning to love mesh-oriented sharding

Famously, PyTorch and JAX don’t agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension of your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it’s easy enough to translate between the two (in simple situations, at least!). In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don’t claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.
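To make the two representations concrete, here is an illustrative sketch of the translation (it assumes a recent PyTorch where Shard and Replicate are importable from torch.distributed.tensor, and it only handles the easy case: at most one mesh axis per tensor dimension, and no Partial placements):

```python
from torch.distributed.tensor import Replicate, Shard

def placements_to_partition_spec(placements, mesh_axis_names, tensor_ndim):
    """Translate a DTensor-style per-mesh-dim placement tuple into a
    JAX-style per-tensor-dim spec (easy cases only)."""
    spec = [None] * tensor_ndim
    for axis_name, placement in zip(mesh_axis_names, placements):
        if isinstance(placement, Shard):
            assert spec[placement.dim] is None, "multi-axis sharding of one dim not handled"
            spec[placement.dim] = axis_name
    return tuple(spec)

# Mesh-dim oriented: for each mesh dim ("dp", "tp"), say what it does to the tensor.
# Tensor-dim oriented: for each tensor dim, say which mesh axis shards it.
assert placements_to_partition_spec((Shard(0), Shard(1)), ("dp", "tp"), 2) == ("dp", "tp")

# A weight replicated over "dp" and sharded over "tp" on its dim 0:
assert placements_to_partition_spec((Replicate(), Shard(0)), ("dp", "tp"), 2) == ("tp", None)
```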

Read more...

The Parallelism Mesh Zoo

When training large-scale LLMs, there is a large assortment of parallelization strategies that you can employ to scale your training runs across more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you’ve got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check of whether you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)
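Here is a minimal device mesh in JAX to fix intuitions (a sketch: I fake eight CPU devices with an XLA flag, and jax.make_mesh may be spelled differently in older releases):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # fake 8 CPU devices

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Organize 8 devices into a 2 x 4 mesh: "dp" (data parallel) x "tp" (tensor parallel).
mesh = jax.make_mesh((2, 4), ("dp", "tp"))
print(mesh.shape)   # axis names and sizes: dp=2, tp=4

# A weight sharded along "tp" and replicated along "dp": tensor-parallel compute
# does collectives along "tp"; data-parallel gradient sync happens along "dp".
w = jax.device_put(jnp.zeros((1024, 1024)), NamedSharding(mesh, P(None, "tp")))
print(w.sharding)
```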

Read more...

vmap in Haskell

vmap is an interface popularized by JAX which offers you a vectorizing map. Semantically, a vmap is exactly equivalent to a map in Haskell; the key difference is that operations run under a vmap are vectorized. If you map a convolution and a matrix multiply, you will have one big loop which repeatedly calls the convolution and matrix multiply for each entry in your batch. If you vmap a convolution and matrix multiply, you’ll call the batched versions of convolution and matrix multiply once. On most modern deep learning frameworks, unless you have a fuser, calling the batched implementations of these operations will be much faster.
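The post works this out in Haskell; as a quick reminder of the semantics in JAX itself, here is a minimal sketch (the function and names are mine, not from the post):

```python
import jax
import jax.numpy as jnp

def layer(x, w):                      # one example: x is [8], w is [8, 4]
    return jnp.tanh(x @ w)

xs = jnp.ones((32, 8))                # a batch of 32 examples
w = jnp.ones((8, 4))

# "map": one small matmul per example (conceptually a loop over the batch).
looped = jnp.stack([layer(x, w) for x in xs])

# vmap: semantically the same map, but executed as a single batched matmul.
batched = jax.vmap(layer, in_axes=(0, None))(xs, w)

assert jnp.allclose(looped, batched)
```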

Read more...