Global vs Local SPMD
January 27, 2026
Global SPMD (also known as the “global view”, exposed by code using DTensor or jax.Array) refers to writing multi-device code as if it were running on a single device, with an orthogonal mechanism for expressing how these full tensors are distributed over multiple devices (this mechanism can be implicit or explicit, e.g., as seen in this table).
Local SPMD (also known as the “per-device view”, exposed by local_map and shard_map, as well as by traditional PyTorch distributed code operating on plain Tensors, e.g., Megatron-style) refers to writing code from the “local” view of a single device, with explicit collectives whenever you communicate across devices.
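To make the two views concrete, here is a toy simulation in plain Python, assuming a "mesh" is just a list of per-device values and that nothing is actually distributed. The helpers shard and all_reduce_sum are hypothetical illustrations; all_reduce_sum only stands in for a real collective such as torch.distributed.all_reduce or jax.lax.psum. The global-view function is written against the full tensor, while the local-view function only ever sees one shard per device and has to communicate explicitly.

```python
# Toy simulation: a "mesh" of devices is a Python list with one entry per device.
# No real devices or communication are involved.

def shard(xs, num_devices):
    """Split a 1-D list into equal contiguous chunks, one per device."""
    assert len(xs) % num_devices == 0, "toy sharding requires an even split"
    n = len(xs) // num_devices
    return [xs[i * n:(i + 1) * n] for i in range(num_devices)]


def all_reduce_sum(per_device):
    """Hypothetical stand-in for an all-reduce: every device receives the total."""
    total = sum(per_device)
    return [total for _ in per_device]


# Global SPMD: the program is written against the full tensor;
# how it is sharded would be described separately.
def global_total(x):
    return sum(x)


# Local SPMD: the program is written per device, with an explicit collective.
def local_total(local_shards):
    partials = [sum(s) for s in local_shards]  # purely local compute
    return all_reduce_sum(partials)            # explicit communication


x = list(range(8))
print(global_total(x))           # 28
print(local_total(shard(x, 4)))  # [28, 28, 28, 28]
```

Both produce the same number here, but only the local version makes the communication step visible in the program text.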
The big question I want to address in this post is, how do I pick between these two modes? Conventional wisdom in the JAX ecosystem is that you should default to global SPMD (either via auto or explicit sharding mode), and drop down to manual local SPMD if the compiler isn’t giving you the correct communications and you want to do it by hand. I want to give a more nuanced version of this take.
First, there is nothing about global SPMD that precludes fine-grained control over when collectives happen. JAX doesn't directly support this kind of mode, but it's not difficult to imagine how it would work: take JAX with explicit mesh axes, but instead of only erroring out on ambiguous communication, error out whenever any implicit communication happens (i.e., you must explicitly call reshard to trigger communication). We actually added an explicit mode to DTensor along these lines, although it currently doesn't work super well because we lack some other important aspects of JAX's sharding in types. Specifically, we don't have the invariant that the cotangent sharding is directly computable from the primal sharding: DTensor is very much like an “auto” mode in that respect, where the forwards and backwards shardings can be chosen independently. This makes it difficult to rule out implicit redistributes in the backwards pass, since whether or not a redistribute occurs depends heavily on exactly how sharding has propagated through the backwards graph.
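A minimal sketch of the hypothetical "error on any implicit communication" mode, assuming a toy tensor that carries its full value plus a placement label. ToyTensor, toy_reshard, ImplicitCommunicationError, and the placement strings are all illustrative inventions, not JAX's or DTensor's API. The only rule modeled is elementwise add: matching placements are local, a replicated operand can be sliced locally to match a sharded one, and two tensors sharded on different dimensions would need an all-to-all, so the toy refuses unless the user reshards explicitly.

```python
from dataclasses import dataclass


class ImplicitCommunicationError(RuntimeError):
    """Raised when an op would need a collective the user did not ask for."""


@dataclass(frozen=True)
class ToyTensor:
    # Global (full) value plus a placement label; nothing is actually distributed.
    data: tuple       # 2-D data as a tuple of row tuples
    placement: str    # "replicate", "shard(0)" or "shard(1)" (hypothetical labels)


def toy_reshard(t, placement):
    """The single explicit entry point where communication is allowed."""
    return ToyTensor(t.data, placement)


def _add_placement(pa, pb):
    if pa == pb:
        return pa  # identical layouts: purely local
    if "replicate" in (pa, pb):
        # A replicated operand can be sliced locally to match a sharded one: no comms.
        return pb if pa == "replicate" else pa
    # shard(0) vs shard(1): a real system would insert an all-to-all here.
    raise ImplicitCommunicationError(
        f"add({pa}, {pb}) needs communication; call toy_reshard() explicitly")


def add(a, b):
    placement = _add_placement(a.placement, b.placement)
    data = tuple(tuple(x + y for x, y in zip(ra, rb))
                 for ra, rb in zip(a.data, b.data))
    return ToyTensor(data, placement)


a = ToyTensor(((1, 2), (3, 4)), "shard(0)")
b = ToyTensor(((10, 20), (30, 40)), "shard(1)")
try:
    add(a, b)  # would need an implicit all-to-all, so it is rejected
except ImplicitCommunicationError as e:
    print("rejected:", e)
print(add(a, toy_reshard(b, "shard(0)")))  # explicit reshard, then a local add
```

The point of the design is that communication can only originate from one visible call site; the global-view program is otherwise unchanged.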
For me, the more important difference is that global and local SPMD actually have different semantics. An obvious divergence is that in local SPMD, there is no source of truth about what the “global” view of the Tensor is: the local tensor just exists, and you know that there are different versions of it on the other nodes. You don't know how you're supposed to stack these tensors together to get a global tensor: you typically only know this at the boundary of the local region, via out_specs / out_placements. And even if you knew how to stack the tensors together, local SPMD has different semantics than global SPMD, because the exact computation you perform depends on exactly how the local tensor is sharded. You're not doing an operation on the global tensor: you're chunking the tensor, running the operation on each chunk, and then stacking the results back together. The whole point of sharding propagation in global SPMD is to figure out whether this is equivalent to running the operation on the full tensor, and there are many cases where it is not.
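The "chunk, run per chunk, stack" semantics can be checked directly in a few lines of Python. This is a sketch under simple assumptions: a 1-D list stands in for the tensor, it is split into equal contiguous chunks, and local_apply is a hypothetical helper describing what running an op inside a local region means. For an elementwise op like ReLU the result matches the global op; for softmax it does not, because each chunk normalizes only over its own elements.

```python
import math


def relu(xs):
    return [max(0.0, x) for x in xs]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def local_apply(op, xs, num_shards):
    """Local-region semantics: chunk the tensor, apply op per chunk, stack results."""
    n = len(xs) // num_shards
    chunks = [xs[i * n:(i + 1) * n] for i in range(num_shards)]
    out = []
    for c in chunks:
        out.extend(op(c))
    return out


x = [1.0, -2.0, 3.0, 0.5]
print(relu(x) == local_apply(relu, x, 2))   # True: elementwise ops commute with sharding
print(sum(softmax(x)))                      # ~1.0
print(sum(local_apply(softmax, x, 2)))      # ~2.0: each chunk sums to 1 on its own
print(softmax(x) == local_apply(softmax, x, 1))  # True: with one shard the views agree
```

The last line is the uncomfortable part: with sharding turned off, the local program and the global program coincide, so the divergence only shows up once the tensor is actually split.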
If you are not thinking carefully about your distributed computation, local SPMD can be a source of bugs. It is common to write distributed code where certain parallelisms can be enabled or disabled. Suppose you do a reduction over an axis: if that axis is replicated, the result is replicated, but if it is sharded, you end up with a partial reduction that has to be accounted for in some other way (e.g., with an all-reduce). If you forget, the code will work when the parallelism is turned off and silently break when it is turned on. A bug like this is horrible enough that frameworks invest in ways to catch it. In JAX's shard_map with check_vma=True, a type system detects if you did a reduction on a sharded dimension and then tried to declare the result as replicated on the way out: the varying/invariant type system notices that the value varies across that mesh axis, which is inconsistent with an out_specs that claims it is replicated. In PyTorch, something like run_check checks at runtime that tensors you claim are replicated are actually replicated (run_check is horribly inefficient, but you can implement faster checks in other ways, e.g., with async checksums).
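Here is a small simulation of that bug and of a checksum-style replication check, assuming a 2-D list stands in for the tensor and each device's local view is a list entry. The helpers shard_columns, row_sums_local, and check_replicated are hypothetical; check_replicated only mimics the idea behind run_check by comparing a SHA-256 digest per device rather than the full values. A real implementation would compute the checksum on device and exchange it with a small collective, which this toy does not model.

```python
import hashlib
import struct


def shard_columns(matrix, num_devices):
    """Shard a 2-D list along its column axis (the axis we will reduce over)."""
    n = len(matrix[0]) // num_devices
    return [[row[d * n:(d + 1) * n] for row in matrix] for d in range(num_devices)]


def row_sums_local(local_rows):
    """Local code: sum each row of this device's shard over the column axis."""
    return [sum(r) for r in local_rows]


def check_replicated(per_device_values):
    """Toy run_check-style validation: do all devices hold identical values?
    Compares one checksum per device instead of the full tensors."""
    def checksum(vals):
        packed = struct.pack(f"{len(vals)}d", *map(float, vals))
        return hashlib.sha256(packed).hexdigest()
    return len({checksum(v) for v in per_device_values}) == 1


M = [[1, 2, 3, 4], [5, 6, 7, 8]]

# Parallelism off (columns replicated): every device computes the full row sums.
replicated = [row_sums_local(M) for _ in range(2)]
print(replicated, check_replicated(replicated))  # [[10, 26], [10, 26]] True

# Parallelism on (columns sharded): the same local code yields partial sums.
sharded = [row_sums_local(s) for s in shard_columns(M, 2)]
print(sharded, check_replicated(sharded))        # [[3, 11], [7, 15]] False

# The missing step: an all-reduce over the partial results.
print([sum(col) for col in zip(*sharded)])       # [10, 26]
```

Declaring the sharded result "replicated" without the final all-reduce is exactly the silent bug described above; the check turns it into a detectable mismatch.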
This is perhaps why Megatron is sometimes considered unfriendly for experimentation. Everything is written in local SPMD (as it doesn't use DTensor), and if you want to experiment with something new, you must resolve all of its interactions with parallelism up front in your implementation. This is all doable, but it can be pretty confusing, and easy to get wrong, if you are not a parallelism expert.
There is a flip side to this, however: if you are thinking carefully about your parallelism and carefully orchestrating your local compute with your communications, it is much more natural to write things in local SPMD style. The local SPMD style only gives you operations that can be computed efficiently (since they are always local) and doesn't force you to say what the global interpretation of a Tensor is (especially when that interpretation is unnatural, as with online softmax). So once you get out of the experimentation phase and are working on efficiency, if you need some nontrivial communication pattern, it would be pretty normal to switch from global SPMD to local SPMD. But there are also plenty of pedestrian modules that don't need anything fancy, and those are better kept in global SPMD.
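Online softmax is a good example of per-device state with no natural global-tensor interpretation. The sketch below assumes logits are split into chunks, one per device, and uses a hypothetical combine function to stand in for an all-gather of each device's (local max, local sum of exponentials) pair. That pair is meaningful only per device; after combining, each device normalizes its own chunk locally, and the stacked result matches a softmax over the full vector.

```python
import math


def local_stats(chunk):
    """Per-device statistics: local max and sum of exps relative to that max."""
    m = max(chunk)
    return m, sum(math.exp(x - m) for x in chunk)


def combine(stats):
    """Stand-in for a collective that merges every device's (max, sum) pair."""
    m = max(mi for mi, _ in stats)
    s = sum(si * math.exp(mi - m) for mi, si in stats)  # rescale to the global max
    return m, s


def sharded_softmax(chunks):
    stats = [local_stats(c) for c in chunks]                   # local compute
    m, s = combine(stats)                                      # explicit communication
    return [[math.exp(x - m) / s for x in c] for c in chunks]  # local compute again


chunks = [[1.0, -2.0], [3.0, 0.5]]
out = sharded_softmax(chunks)
print(sum(sum(c) for c in out))  # ~1.0, like a softmax over the full vector
```

Only two scalars per device cross the network here, which is the kind of hand-orchestrated communication pattern that is awkward to express as operations on a global tensor.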
In the PyTorch ecosystem, there are also some lower-level reasons why you might prefer local SPMD over global SPMD. The most prominent is DTensor's eager overhead. Many parts of DTensor are implemented in Python rather than C++, and on the first invocation we must compute sharding propagation rules, which is done entirely in Python and is quite expensive. It is possible to get reasonable performance with DTensor: if you use torch.compile, you can eliminate the overhead entirely; CUDA graphs also work; and FSDP2 shows that careful, minimal use of DTensor can still have acceptable CPU overhead. But this overhead is perhaps one of the big reasons why distributed code with plain Tensors remains quite popular today.