ezyang's blog

the arc of software bends towards understanding

Global vs Local SPMD

Global SPMD (also known as the “global view”, exposed by code using DTensor or jax.Array) refers to writing multi-device code as if it were on a single device, with an orthogonal mechanism for expressing how these full tensors are distributed over multiple devices (this mechanism can be implicit or explicit, e.g., as seen in this table).

Local SPMD (also known as “per-device view”, and exposed by local_map and shard_map, and also traditional PyTorch distributed code operating on plain Tensors, e.g., Megatron-style) refers to writing code from the “local” view on a single device, with explicit collectives when communicating across devices.

The big question I want to address in this post is, how do I pick between these two modes? Conventional wisdom in the JAX ecosystem is that you should default to global SPMD (either via auto or explicit sharding mode), and drop down to manual local SPMD if the compiler isn’t giving you the correct communications and you want to do it by hand. I want to give a more nuanced version of this take.

First, there is nothing about global SPMD that precludes fine-grained control over when collectives happen. JAX doesn’t directly support this kind of mode, but it’s not difficult to imagine how it would work: take JAX with explicit mesh axes, but instead of only erroring out on ambiguous communication, error out when any implicit communication happens (e.g., you must explicitly call reshard to trigger communication). We actually added an explicit mode to DTensor along these lines, although it currently doesn’t work super well because we lack some other important aspects of JAX sharding in types.1

For me, the more important difference is that global and local SPMD are actually different semantics. An obvious divergence is that in local SPMD, there isn’t any source of truth about what the “global” view of the Tensor is: the local tensor just exists, you know that there are different versions of it on the other nodes. You don’t know how you’re supposed to stack these tensors together to get a global tensor: you typically only know this at the boundary of the local region using out_specs / out_placements. And even if you knew how to stack the tensors together, local SPMD has different semantics than global SPMD, as the exact computation you perform depends on how exactly the local tensor is sharded. You’re not doing an operation on the global tensor: you’re chunking the tensor, running the operation on each chunk, and then stacking it back together. The whole point of sharding propagation in global SPMD is to figure out if this is equivalent to running the operation on the full tensor, and there are many cases when it is not.

If you are not thinking carefully about your distributed computation, local SPMD can be a source of bugs. It is common to write distributed code where certain parallelisms can be enabled or disabled. If you do a reduction over an axis, the result is replicated when that axis is replicated, but when it is sharded you end up with a partial reduction that has to be accounted for in some other way. If you forget, the code will work when the parallelism is turned off and silently break when the parallelism is turned on. A bug like this is horrible enough that frameworks invest in ways to deal with this situation.2
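
To make this concrete, here is a tiny single-process sketch (no real distributed setup, just two simulated “ranks”) of why a local reduction silently goes wrong when the axis is sharded:

import torch

# Pretend we have a 2-rank mesh and a tensor of shape [8] that is either
# replicated or sharded on dim 0 across the two ranks.
full = torch.arange(8.0)
shard0, shard1 = full.chunk(2)

# Replicated: every rank holds the full tensor, so a local sum is already
# the right answer everywhere.
print(full.sum())                  # tensor(28.) on every rank

# Sharded: a local sum is only a partial reduction...
print(shard0.sum(), shard1.sum())  # tensor(6.) on rank 0, tensor(22.) on rank 1
# ...and the ranks only agree on 28. after an all-reduce over the mesh dim.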

This is perhaps the reason why Megatron is sometimes considered unfriendly for experimentation. Everything is written in local SPMD (as it doesn’t use DTensor), and if you want to experiment with something new you must resolve all of the interactions with parallelism in your implementation upfront. This is all doable, but it is easy to get wrong and can be pretty confusing if you are not a parallelism expert.

There is a flip side to this, however: if you are thinking carefully about your parallelism and are carefully orchestrating your local compute with your communications, it is much more natural to write things in local SPMD style. The local SPMD style only gives you operations that can be efficiently computed (since they are always local) and doesn’t force you to say what the global interpretation of a Tensor is (especially when it’s unnatural, like online softmax.) So once you get out of the experimentation phase and are working on efficiency, if you need some nontrivial communication pattern, it would be pretty normal to switch from global SPMD to local SPMD. But there’s also a lot of pedestrian modules that don’t need anything fancy, and it is better to keep them in global SPMD in that case.

In the PyTorch ecosystem, there are some more low level reasons why you might prefer local SPMD over global SPMD. The most prominent is DTensor’s eager overhead. Many parts of DTensor are implemented in Python rather than C++, and on the first invocation we must compute shard propagation rules, which is entirely in Python and quite expensive. It is possible to get reasonable performance with DTensor: if you torch.compile you can eliminate the overhead entirely, CUDA graphs also work, and FSDP2 shows that careful, minimal use of DTensor can still have acceptable CPU overhead. But this is perhaps one of the big reasons why distributed code with plain Tensors remains quite popular today.


  1. Specifically, we don’t have the invariant that the cotangent sharding is directly computable from the primal sharding–DTensor is very much like an “auto” mode in that respect, where the forwards/backwards sharding can be chosen differently. This makes it difficult to rule out implicit redistributes in backwards, since whether or not a redistribute occurs is heavily dependent on the exact details of how sharding has propagated through the backwards graph. ↩︎

  2. In JAX’s shard_map with check_vma=True, a type system detects if you did a reduction on a sharded dimension and then tried to declare it as replicated on the way out, since the varying/invariant type system would notice that the sharding axis is varying across the mesh and thus inconsistent with an out_specs that claims it is replicated. In PyTorch, something like run_check checks at runtime that tensors you claim are replicated are actually replicated (run_check is horribly inefficient but you can implement other ways to do this more quickly, like with async checksums). ↩︎

Megatron via shard_map

In Computing sharding with einsum, we worked an example of Megatron style tensor parallelism where we discover that the ordinary backwards formula for linear results in a pending reduction on grad_input, even though the input was replicated and no communications happened in forwards. In Megatron, which is implemented with plain Tensors and manual collectives, you just have to know that this reduction is necessary and manually insert it with a custom autograd function.

If we wanted to write a similar explicit-collective-style Megatron implementation in JAX, we would use shard_map. Like in Megatron, you have to call a function which is a no-op in forwards and an all-reduce in backwards. However, JAX has built this function into its core library, and has an interesting name for it: jax.lax.pcast(..., to='varying') (previously known as pvary, although apparently this is deprecated now.) Why is this a “cast”? The answer is that JAX’s shard_map actually comes with an optional type system (enabled with check_vma=True) which rejects your program if you forget to insert an all-reduce in your backwards!

Let’s see how this works. As a reminder, our shapes are:

input: [sequence, batch, in_features]
weight: [in_features, out_features]
output: [sequence, batch, out_features]

We can describe the input and output sharding of these in JAX style (tensor-dim oriented), which is what we would feed into the in_specs and out_specs of shard_map:

input: P(None, None, None)
weight: P(None, "tp")
output: P(None, None, "tp")

Although on the boundaries of shard_map we can say exactly what the sharding of the input/output tensors are (e.g., how to reassemble them back into full tensors), on the inside of a shard_map this is not a well-defined question: you only have the local tensors and can do whatever you want with them before you reassemble them back into global tensors with shardings.

However, when check_vma=True, JAX will still keep track of something weaker than sharding: whether or not the tensors are varying (i.e., different) across a mesh dimension. This is conventionally notated as dtype[local_shape]@{varying axes}, e.g., f32[3]{i} means that this tensor varies across the mesh axis i (the braces are omitted when nothing varies). Let’s write down the local shapes of our input/output tensors with variance information (note that the sharded tensor dimensions have shapes that are divided by the size of the mesh dimension they are sharded over):

input: f32[sequence, batch, in_features]  # nothing varying
weight: f32[in_features, out_features/tp]@{tp}  # varies over tp dim
output: f32[sequence, batch, out_features/tp]@{tp}  # varies over tp dim

The point of a type system is that you have typing rules that say whether or not an operation is legal between two types, rejecting programs that are ill-typed. In particular, in jaxpr it’s illegal to have an operation like matrix multiply between two tensors with differing VMA: you have to insert a cast to make the VMAs match before you can do the operation. Actually, JAX will typically insert these casts implicitly for you, but for clarity we’re going to insert the cast explicitly here:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax import shard_map

devices = jax.devices()
mesh = Mesh(devices, axis_names=("tp",))
sequence, batch, in_features, out_features = 4, 2, 8, 16

input = jnp.ones((sequence, batch, in_features))
# NB: the full weight, shard_map will automatically shard weight for us given the
# input spec
weight = jnp.ones((in_features, out_features))

def colwise_linear(input, weight):
    print('input', jax.typeof(input))
    print('weight', jax.typeof(weight))
    input = jax.lax.pcast(input, "tp", to="varying")
    print('pcast_input', jax.typeof(input))
    output = jnp.einsum("sbi,io->sbo", input, weight)
    print('output', jax.typeof(output))
    return output

sharded_linear = shard_map(
    colwise_linear,
    mesh=mesh,
    in_specs=(P(None, None, None), P(None, "tp")),
    out_specs=P(None, None, "tp"),
)

output = sharded_linear(input, weight)

This prints:

input float32[4,2,8]
weight float32[8,8]{V:tp}
pcast_input float32[4,2,8]{V:tp}
output float32[4,2,8]{V:tp}

It’s a little difficult to see why the program “doesn’t typecheck” when you omit the pcast, because even if you don’t pcast JAX will implicitly insert it for you, but you can verify that the cast happens anyway by inspecting the HLO with and without the explicit pcast (this is left as an exercise to the reader; an LLM can one-shot this program transformation). The type system here is also not that invasive: local operations simply propagate variance (varying to varying, invariant to invariant), and you only need a small menagerie of collective operations to take you between varying and invariant.
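
If you want to try the exercise, one way to dump the HLO (using JAX’s ahead-of-time lowering APIs) is along these lines:

# Compare this output with and without the explicit pcast in colwise_linear
print(jax.jit(sharded_linear).lower(input, weight).compile().as_text())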

Although JAX’s type system here is a bit difficult to explain from first principles, it seems quite beneficial as it makes it impossible to forget backwards reductions that are required when you do operations between sharded and replicated tensors. We’re hoping to bring a similar capability to PyTorch DTensor in the near future, introducing a new “LTensor” subclass that is able to track metadata along these lines.

Computing sharding with einsum

Mental arithmetic in grade school (e.g., memorizing your times tables) is typically justified on the grounds that facility in basic calculations makes it easier to focus on higher-level problems that require being able to do these manipulations. When working on DTensor, I have also found it important to be able to quickly calculate what shardings you get when you do matrix multiplies on sharded tensors. Without being able to do this quickly and accurately, working through examples becomes a slog. I’ve also found that while diagrammatic approaches (e.g., drawing a matrix and slicing it into shards) are intuitive, they are slow and unwieldy to do calculations with.

Recently, I’ve found that working on sharding with einsum is nice and efficient, and I hope to persuade you to do it this way when you need to reason about sharding! This post somewhat overlaps with Sharded Matrices and How to Multiply Them, but with some different emphasis and some different notation.

Einsum primer

Einstein summation is a compact way of representing many multi-dimensional linear algebra operations, including matrix multiplies. It is nice because you don’t have to puzzle through the abstruse differences between matrix multiply operations like @, torch.matmul, torch.bmm, torch.mm: for any “matrix multiply”, as long as you know the input and output shapes of your tensors, you can directly write out an einsum equation. For example, classic matrix multiply as you see it in math has a signature like mm(x: f32[A, B], y: f32[B, C]) -> f32[A, C]. In einsum notation, you would simply write torch.einsum("ij,jk->ik", x, y): each of the indices lines up exactly with the input sizes. As another example, in nn.Linear, your weight has shape (out_features, in_features). You don’t have to remember how to set up the transposition, just write torch.einsum("bi,oi->bo", input, weight).
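
As a quick sanity check (a sketch with random tensors), the einsum spelling agrees with F.linear, whose weight is laid out as (out_features, in_features):

import torch
import torch.nn.functional as F

batch, in_features, out_features = 3, 4, 5
input = torch.randn(batch, in_features)
weight = torch.randn(out_features, in_features)   # nn.Linear's weight layout

# F.linear computes input @ weight.T; the einsum says the same thing by shape
torch.testing.assert_close(
    torch.einsum("bi,oi->bo", input, weight),
    F.linear(input, weight),
)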

A useful piece of terminology that pops up for einsum is a contraction dimension. This is any index that appears in the input tensors but not the output tensors. The ones that show up in both inputs and outputs are free dimensions: if the free dimension is in all inputs it’s a batch dimension, and if it’s missing from some inputs we will broadcast those tensors.

Einsum backwards

Do you always forget how exactly you should transpose your tensors in the backward formula for matrix multiply? As long as you aren’t doing weird things in your einsum (e.g., no repeated indices, every input index is paired with another index), there is a very simple way to compute backwards: keep every input constant except the one you want to compute the gradient for, and swap its index set with the output index set.

For example, linear is "bi,oi->bo" for (input, weight -> output). Then we have:

grad_input  = torch.einsum("bo,oi->bi", grad_output, weight)
grad_weight = torch.einsum("bi,bo->oi", input, grad_output)

Intuitively, the reason this works is because reverse-mode AD actually is just transposing the linear function defined by our einsum, and transposed matrix multiplies can be implemented by just reading off its shapes.
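
Here is a quick autograd check of this rule for linear (a sketch with random tensors):

import torch

B, I, O = 3, 4, 5
input = torch.randn(B, I, requires_grad=True)
weight = torch.randn(O, I, requires_grad=True)

output = torch.einsum("bi,oi->bo", input, weight)
grad_output = torch.randn(B, O)
output.backward(grad_output)

# swap the differentiated input's index set with the output's index set
torch.testing.assert_close(
    input.grad, torch.einsum("bo,oi->bi", grad_output, weight.detach()))
torch.testing.assert_close(
    weight.grad, torch.einsum("bi,bo->oi", input.detach(), grad_output))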

Einsum sharding

Now that we’re thinking in terms of einsum formulas, all we need is the sharding rule for einsum. The sharding rule tells us under what situations we can perform a matrix multiply by simply doing matrix multiplies on the local shards, producing the output matrix under some output placement.

There are not too many rules. Taking "abi,aoi->abo" as a running example, we can write down these valid placements for a particular mesh dimension (I’ve replaced numeric dim indices with the einsum character index for readability):

  1. If everything is replicated, the output is replicated: Replicate(), Replicate() -> Replicate()
  2. If a batch dimension is sharded, the output batch dimension is also sharded: Shard("a"), Shard("a") -> Shard("a")
  3. If a free dimension is sharded, the output free dimension is sharded, but any broadcasted input must be replicated: Shard("b"), Replicate() -> Shard("b")
  4. If a contraction dimension is sharded, we will have a pending reduction: Shard("i"), Shard("i") -> Partial()

You can look at Computation With Sharded Arrays for a more detailed explanation for each of these cases.
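
If it helps to see the rules as code, here is a toy calculator for a single mesh dimension (just a sketch to internalize the rules, not DTensor’s actual propagation logic; placements are written as "R" for replicate or as the einsum index that is sharded, and partial inputs aren’t handled):

def einsum_sharding(spec: str, *placements: str) -> str:
    inputs, output = spec.replace(" ", "").split("->")
    inputs = inputs.split(",")
    assert all(p != "P" for p in placements), "sketch: partial inputs not handled"
    sharded = {p for p in placements if p != "R"}
    if not sharded:
        return "R"   # rule 1: everything replicated -> replicated
    assert len(sharded) == 1, "sketch: at most one sharded index per mesh dim"
    (idx,) = sharded
    if idx not in output:
        return "P"   # rule 4: sharded contraction dim -> partial (pending reduction)
    # rules 2/3: sharded batch/free dim propagates to the output; any input
    # that broadcasts (doesn't mention the index) must be replicated
    for inp, p in zip(inputs, placements):
        if idx not in inp:
            assert p == "R", "broadcasted input must be replicated"
    return idx

print(einsum_sharding("abi,aoi->abo", "a", "a"))  # a (rule 2: batch dim)
print(einsum_sharding("abi,aoi->abo", "b", "R"))  # b (rule 3: free dim)
print(einsum_sharding("abi,aoi->abo", "i", "i"))  # P (rule 4: contraction dim)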

Worked example: Tensor parallelism

In 2019, Xiaolin Li asked this question about CopyToModelParallelRegion in Megatron:

Why the backward function of _CopyToModelParallelRegion calls reduce fuction? Can somebody share the mathematical proof?

Let’s answer Xiaolin’s question. In Megatron, ColumnParallelLinear is defined as:

input: [sequence, batch, in_features]
weight: [in_features, out_features]
output: [sequence, batch, out_features]

In einsum notation, this is torch.einsum("sbi,io->sbo", input, weight).

On the TP mesh dimension, we have this sharding:

input: Replicate()
weight: Shard("out_features")
output: Shard("out_features")

Let us assume that grad_output: Shard("out_features"). Let’s compute the placements of grad_weight and grad_input. First the derivative formulas:

grad_input = torch.einsum("sbo,io->sbi", grad_output, weight)
grad_weight = torch.einsum("sbi,sbo->io", input, grad_output)

So we see:

grad_input: Partial()  # o is sharded and a contraction dim
grad_weight: Shard("out_features")  # o is sharded and a free dim

We see that grad_input has a pending reduction, and if downstream backwards is expecting to receive replicated tensors, we must trigger an all-reduce (e.g., in Megatron this all-reduce is manually triggered by _CopyToModelParallelRegion; if you use DTensor, it will just propagate the Partial() until a redistribution to Replicate() is required.)
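
For reference, the custom autograd function in question boils down to something like this (a simplified sketch; the real Megatron version all-reduces over the tensor-parallel process group and handles a few more details):

import torch
import torch.distributed as dist

class CopyToModelParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # no communication needed in forwards: input is already replicated
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # grad_input is Partial(): sum the per-rank partial gradients
        dist.all_reduce(grad_output)
        return grad_output

def copy_to_model_parallel_region(input):
    return CopyToModelParallelRegion.apply(input)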

Worked example: Sequence parallel with a replicated scaling factor

In sequence parallel, we will shard the sequence dimension of an input, but not the weight. Let’s say we have a learnable scaling factor:

input: [sequence, batch, hidden]
weight: [hidden]
output: [sequence, batch, hidden]

In einsum notation, this is torch.einsum("sbh,h->sbh", input, weight).

On the SP mesh dimension, we have this sharding:

input: Shard("sequence")
weight: Replicate()
output: Shard("sequence")

Then we have:

grad_input = torch.einsum("sbh,h->sbh", grad_output, weight)
grad_weight = torch.einsum("sbh,sbh->h", input, grad_output)

So we see:

grad_input: Shard("sequence")  # s is sharded and a free dim
grad_weight: Partial()  # s is sharded and a contraction dim

Here, we must do an all-reduce over grad_weight to get the true replicated gradient.

Notice that this example is very similar to the tensor parallelism one, but the roles of input and weight have been swapped!

Hugo Migration

This blog has lived on WordPress since it was initially created during a social challenge at MIT to write a blog post a week or pay up with beer. I remember a very important piece of advice I had been given at that time: don’t fuck around with your blog authoring software, just do the minimum viable thing (use WordPress) and focus on writing posts.

It’s 2026 now, the world is different, and in particular the existence of coding agents means that this particular advice falls flat now: it has never been easier to vibe code your own blog software and be done in an afternoon of token generation. Similarly, over the years, I had been increasingly unhappy about my WordPress setup (too hard to add images, ancient version of WordPress, Markdown has taken over the world why am I still writing in ReST, I love scripts.mit.edu but I definitely don’t want to use it to host serious things). So I typed this into ChatGPT and Claude and asked them what I should migrate to.

I currently have a Wordpress blog whose 633 posts are written in ReST using rest-wordpress with some manual code edits, and a theme based on Ashley that I also customized. I’d like to migrate to another blogging solution. I care a lot about ensuring the URLs are preserved. To a lesser extent, I also care about the comments, although I’m willing to compromise here (e.g., an offline flow where I have to explicitly publish comments might be OK; I know static site is difficult to support comments; I also know that email newsletter is popular and I’d like to support this modality if possible. I don’t use a WYSIWYG editor. It’s on Wordpress 5.1.19. It would be nice to have a way for people. Some more niche things plugins I’ve used is WP LaTeX and Share a Draft but I’m willing to do a lossy conversion if necessary (I don’t use LaTeX that much now; it’s just important to make sure the old posts still format correctly). Many of my posts have images and I’d like an easier flow than my current flow (where I have to manually upload my images to my server and then hyperlink them into the post). What do you recommend?

It suggested Hugo, which I had played around with before in AI Blindspots, and I figured, “Why not, I’ll just ask Claude to do the migration.” A few hours, two pro sessions worth of tokens and some PHP export scripts later, the entire blog was moved over, no muss, no fuss. I live streamed a portion of this migration process, although there’s nothing that special about it.

I actually wasn’t going to write a blog post about this, but I saw that Jeff Geerling’s blog had also made the frontpage of Hacker News. I too haven’t figured out how I am going to solve the comments problem on the new format; I also think I will figure out how to get an email newsletter going from the blog. Here’s hoping this encourages you to use LLMs to make the jump for your own personal site!

The gap between a Helpful Assistant and a Senior Engineer

Let’s suppose you asked an AI coding agent to “implement a CLI calculator”. Imagine if, instead of only writing a short Python script, it also started building an automated test suite, a crash reporting mechanism and a telemetry subsystem. You’d be like, “What the fuck is this?”

But now let’s say that you were planning to release this project to users. It would be clearly negligent to not have an automated test suite. A crash reporting mechanism might be overkill for a simple calculator, but for more complicated CLIs interacting with the real world, it may not always be feasible to have a reproducer, in which case crash logs are essential. Similarly, a telemetry subsystem would be wildly inappropriate for an open source local-only calculator, but it could make sense for a networked application or a corporate tool whose users have all consented. One of the important functions of a senior engineer is to be able to evaluate the context a software project lives in and figure out if we need to do something, even if it isn’t explicitly asked for. This is in contrast to a helpful assistant, who is first and foremost obligated to follow the user’s instructions. This leads to a gap between a Helpful Assistant and a Senior Engineer.

In principle, you could prompt the LLM agent to act like a Senior Engineer. In fact, why stop at Senior, let’s tell the LLM to be a Staff Engineer! Imagine that scaling continues: what would you expect the LLM to do when instructed to act in this way? Well, imagine a human L7 engineer who has just been hired by a big tech company to head up some big, new, multi-year initiative. Will they say, “Sure, I can help with that!” and start busily coding away? Of course not: they will go out and start reviewing code, reading docs, talking to people, asking questions, shadowing oncalls, doing small starter tasks–they will start by going out and building context. Here, the “helpful assistant” frame for LLMs is limiting: sure, Claude might ask you a few questions to clarify the task upfront, but if your coding agent starts asking you about “relevant partner teams” and “org-wide priorities for this half” you are definitely going to raise an eyebrow.

What would it take for an LLM to be able to act like a Senior Engineer?

  • Perhaps prompting is all you need: you just need to write enough information about the surrounding context for a project, and once you feed in enough tokens, a smart model can infer the rest of the details you didn’t explicitly write down. This context would be bespoke for every project; you would have to redo this exercise every time you had a new project!

  • Perhaps you can instead prompt a model on how to operate agentically to get the context it needs. This prompt here might be more reusable. But the model may need to actually do wetwork (e.g., talk to humans) to get all of the information it needs. And remember the old saying: the more generic the advice is, the less useful it is. Specificity is king, which leads to…

  • Let’s say we solve continual learning. Instead of crafting the perfect prompt upfront, you could just drop the model in as an “embodied” software developer. It reads code, talks to people, does projects, and in doing so slowly develops its latent context, in the same way a human engineer does. Building context will often be bottlenecked in the same way humans are: you can’t get experience related to pushing a feature to production until you’ve actually pushed the feature to production (however long that takes).

But just like how you shouldn’t micromanage a Senior Engineer, all of these approaches involve fundamentally different expectations about what an AI coding agent should do, and so even if a model and scaffold are capable of doing these things, it is altogether another question if it will be asked to behave in this way. So let’s not take it as a foregone conclusion that METR task times will keep following the empirical trendline: I expect a phase transition when the context an LLM needs to do a good job exceeds the capability of scaffolding to provide on the fly.

Code review as human alignment, in the era of LLMs

I’ve recently been doing a lot of both submitting and reviewing pull requests to PyTorch that were authored with substantial LLM assistance. This is a big difference from earlier this year, where it was clear LLMs worked well for greenfield projects but the code was too hopelessly sloppy for a production codebase. Here are my merged PRs that mention claude code in their description; Jason Ansel has also had a similar experience (Meta only link, here is the list of issues he referenced in his writeup). There already has been increasing discourse (Simon Willison, LLVM) on how code review should adapt to this new era of LLMs. My contribution to this discourse is this: within teams, code review should change to being primarily a human alignment mechanism.

Here is a simple example: it is well known that LLMs are prone to generating overly defensive code: e.g., they will be constantly sprinkling try...catch everywhere or testing if a variable is some type when system invariants imply that it should always be that type. If someone sends me a PR with these problems, I am not commenting on these problems solely because I want them to be fixed. If that’s all I cared about, I could have just fed my comments directly to claude code. The real problem is that the human who was operating the LLM didn’t agree with me that this defensive code was bad, and the point of the review is to align them with me on what is overly defensive versus not. In the most trivial cases, maybe the engineer didn’t read the LLM output, in which case the remedy is to make them actually read the code. But sometimes real human work has to happen; for example, maybe there is a global system invariant that one has to understand to know if the defensiveness is necessary or not. If we agree about the global system invariants, there’s no reason the code review has to go through me: the original code author can just instruct the LLM to fix problems and keep me out of the loop until they have aligned the LLM output to themselves–at which point we should do the more expensive human to human alignment. The ideal is that I don’t need to ever write review comments about mechanical problems, because they have already been fixed by the original author ahead of time.

Conversely, when I am putting up an LLM generated PR for human review, I am trying to transmit higher level information. How does the new code work? What do I need to know about the existing system to understand this code? This doesn’t even have to be in the PR description: if the LLM proposes a fix that I myself don’t understand, or seems difficult to understand, I will simply instruct it to try it a different way, until the resulting diff is obviously correct. Tokens are cheap: we should expect more out of the author of code, because the cost of generating these PRs has gone way down. Similarly, I am willing to throw out the code and start again; you don’t have to feel bad about wasting my time (I didn’t type it! I spent my time understanding the problem, and none of that is regretted.)

There is a lot of scaremongering about how engineers who don’t pick up AI tools will be left behind. My take on this is that there are a number of different skills that make up what it means to be a good software engineer, and LLM coding, even today, is clearly reweighting the relative importance of these skills. I care a lot more about your ability to read code, reason about the big picture, communicate clearly and have good taste, than I care about your ability to mechanically write code. There is an archetype of junior engineer who is not that good at coding but very good at the softer, higher level skills, and I think they will be very valuable in this new world order. Conversely, I think going forward I will have substantially less patience if I have to keep telling you the same things over and over, because I just don’t value raw “ability to code” as much anymore. My ideal state is like that with long time senior teammates: I can trust that they have made good low level decisions, and I can focus on understanding the bigger picture and updating my mental model of how the system works.

Today’s LLMs have no memory: they have to rediscover everything in the system from first principles every time they are run. The purpose of the humans, of the team, is to collectively maintain a shared vision of what, platonically, the system should do. I want code review to reconfigure itself around this purpose.

Learning to love mesh-oriented sharding

Famously, PyTorch and JAX don’t agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension on your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it’s easy enough to translate between the two representations (in easy situations, at least!) In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don’t claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.

Closed versus open. I am going to make a precise technical claim: JAX sharding is closed, whereas PyTorch sharding is (in principle) open. Here, what I mean by closed/open refers to the capability for users to extend a system: traditional ADTs are closed (you can’t add another constructor to an ADT), whereas object-oriented classes are open (you can define a new subclass of a class). Now, technically JAX sharding is open: jax.sharding.Sharding is a base class that is intended to be subclassed, but to do this you have to define things like _to_xla_hlo_sharding, which is as good as not being supported. The regular class everyone uses, NamedSharding, consists of a mesh and a tuple of mesh axes, with no obvious extension points. I also offer for the defense this unanswered forum post: https://github.com/jax-ml/jax/discussions/23703

In contrast, PyTorch sharding is in principle extensible: the sharding is expressed as a list of Placement, a class which is subclassed to define custom shardings. The extensibility of Placement isn’t really well supported (for example, there’s no convenient way to add extra sharding propagation rules for new placements), but it works enough that both internally and externally there are implementations of weird placements (internally, StridedShard and NormPartial… and technically all of the non-sum reductions supported by Partial as well as uneven sharding; externally, see RaggedShard and InterleavedShard).

Why does mesh-dim oriented sharding support extensibility in this way? The key is that mesh-oriented sharding is very imperative in nature: you can think of the list of placements as a sequence of transformations you apply to the tensor from left-to-right. Concretely, given the current local tensor (as produced by all of the placements you handled for the mesh dims before the one you’re currently processing), run an invertible function to split this tensor along the current mesh dimension. This gives you a bunch of new local tensors which you recursively continue sharding with the rest of the mesh dims. The invertibility of the function is the only real constraint on what function you can provide (since you need to be able to reassemble the shards back into the original full tensor), but otherwise your choice of function is unconstrained. It is in this sense that Placement is morally extensible.
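
Here is a toy model of that recursion to make the left-to-right, imperative reading concrete (hypothetical classes, not DTensor’s real Placement API):

import torch

class Replicate:
    def split(self, tensor, num_chunks):
        # every device along this mesh dim gets the same tensor
        return [tensor] * num_chunks

class Shard:
    def __init__(self, dim):
        self.dim = dim
    def split(self, tensor, num_chunks):
        # invertible: torch.cat(chunks, dim) recovers the input
        return list(torch.chunk(tensor, num_chunks, dim=self.dim))

def shard_tensor(tensor, mesh_shape, placements):
    # process mesh dims left-to-right; each placement splits the local
    # tensors produced by the placements for the earlier mesh dims
    locals_ = [tensor]
    for size, placement in zip(mesh_shape, placements):
        locals_ = [piece for t in locals_ for piece in placement.split(t, size)]
    return locals_   # one local tensor per device, in row-major mesh order

full = torch.arange(8).reshape(4, 2)
# 2x2 mesh: replicate over the outer dim, shard dim 0 over the inner dim
print(shard_tensor(full, (2, 2), [Replicate(), Shard(0)]))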

When designing systems, it is not an unambiguous good to make the system more flexible. Closed systems like JAX’s mean you don’t have to worry about hilariously complicated situations like what if you unevenly shard on the same dimension multiple times (do you have any guarantees on the local sizes of tensors being somewhat balanced?) But sometimes, the use case demands a greater degree of expressivity (in the same way that manual memory management allows you to do more than you can conveniently do in a GC’ed language.)

How expressive does Sharding have to be? One of the primary value propositions of DTensor is that it specifies a standardized representation for saying how a tensor is sharded across your cluster. It’s very good to have this information, because it prevents accidents, like forgetting that a tensor dimension is sharded so you do a reduction on that dimension without first doing a collective and you get subtly wrong results that take weeks to debug. It’s better to have a system that is correct but slow, than it is to have a system that is fast but incorrect.

Being able to express all distributed states is not a terminal goal. There are lots of situations in distributed optimizations where you temporarily need to put the system in a state where it is very difficult to describe exactly how to interpret data across nodes. For example, when you implement ring attention, to avoid communications when performing softmax, you instead perform online softmax. It’s quite difficult to say what the “placements” of the running quantities in online softmax are. In this case, we shouldn’t overly stress ourselves with defining a placement: we should just use local_map or shard_map and absolve ourselves of needing to actually say exactly how data is laid out at any given point in time. But the key is that we should only do this in local regions of code; if we give up and local_map our entire model, we might as well have just not written our code with DTensor at all. So we should seek additional expressivity when it is needed to express how data is being communicated across system boundaries.

Here are some classic examples from LLM training where you need a little bit of extra expressivity, starting with simple cases and becoming more complicated:

  1. Suppose you are applying FSDP to a model, where the parameter sizes haven’t been chosen with parallelism in mind; and in particular, the size of your cluster doesn’t evenly divide with the parameter count. It can be convenient to allow for an uneven sharding to happen in this case, so that the user doesn’t have to manually take care of padding out their tensor so that it can be allgathered.
  2. Say you do a matrix multiply between two tensors which are sharded on the contraction dimension. A reduction is required to communicate the local results into the final tensor. However, sometimes, it can be profitable to delay this reduction, since it can be coalesced into a later reduction. This requires the ability to express that a tensor has a pending reduction on some mesh axis.
  3. If you have both FSDP and row-wise TP, if your FSDP implementation naively shards on the first dim of your weight tensor, you need to ensure that the TP sharding occurs before the FSDP sharding (so that when you undo your FSDP sharding, you have the expected TP sharding ready to go for TP.) This requires the ability to express the order of sharding in a non-standard order (right-to-left, as is supported by list of mesh axes aka PartitionSpec), OR the ability to express that the FSDP is a weird “strided” shard where you don’t have contiguous data, instead you have stripes of data that will then be further sharded by the TP sharding.
  4. Suppose you have a tensor going into a matrix multiply which is not sharded on batch (because you’re naughty and you’re following the standard PyTorch convention of not actually expressing batch-sharding in DTensor) but is sharded on sequence. If you want to treat both batch and sequence as “batch” for the matmul, in PyTorch, this typically requires flattening these two dimensions into a single flat batch dimension. However, this cannot be done, as there is no native Placement that can represent a flattened (Replicate, Shard); however, this works with StridedShard (or InterleavedShard, which is the same thing.) More generally, it is extremely irritating that DTensors cannot reliably have view operations applied to them (that would be supported on plain tensors), and you need weird shard types to be able to handle many view operations.
  5. Traditional FSDP2 assumes that there’s not any requirement for how parameters are placed on nodes; but things like block-wise quantization and structure-aware optimizers need the ability to place a block/parameter on a single device, so that you have access to all the data you need. This won’t be a standard sharding; the shards will be ragged.

I think it’s a worthy use of complexity budget to search for a system design that can handle all of these things, especially since PyTorch’s existing mesh-oriented sharding is already tantalizingly close to supporting this.

Why is adding a new Placement to PyTorch hard? I tend to think, fundamentally, that a mesh-oriented sharding strategy can support arbitrary Placement subclasses. So why does this not work so well in PyTorch? I think there really only are two issues:

  • There is no official way to retroactively add extra sharding propagation rules to existing operators. What I always tell people is that a sharding propagation rule is simply a mathematical equality, saying that map(f, in_placement(x)) == out_placement(f(x)). Mathematical equalities are infinitely compositional: you can always add more true equalities to your system without compromising correctness. But there isn’t actually a way to do this.
  • Many sharding propagation rules are written assuming only Shard exists. Placement provides an is_shard method to test if something is a shard (as opposed to replicate/partial), and sharding propagation rules often assume that if this is True, you specifically have a standard, even Shard, as if it was the only sharding in the universe. This means that rules are often secretly buggy when custom Placements are added. StridedShard, in particular, naughtily advertises that it is_shard(), which means that we will happily allow for it to contract with a plain Shard, leading to this bug: https://github.com/pytorch/pytorch/issues/166598 Now, to be clear, often rules WILL work for arbitrary sharding subclasses; for example, if an input dimension is a batch dimension, it doesn’t matter how the data is sliced up, your operation is functorial over that dimension. Will Constable has been working on refactoring our sharding rules to distinguish between the “it’s a batch dimension” situation versus the “I specifically need an even sharding” or “I need these shardings on two inputs to be exactly the same kind of sharding” situations.

I think that with these two issues fixed, and a bit of careful design on what the overrideable API on Placement is for subclasses, we can have a very good extensibility story for shardings.

Draw high dimensional tensors as a matrix of matrices

I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors in a way that makes it clear how to identify each of the dimensions in the representation. Common strategies I’ve seen people use in this situation include printing a giant list of 2D slices (what the default PyTorch printer will do) or flattening the Tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy that I like that makes it easy to identify all the axes of the higher dimensional tensor: draw it as a matrix of matrices.

Here are some examples, including the easy up-to-2D cases for completeness.

0D: torch.arange(1).view() :

0

1D: torch.arange(2) :

0  1

2D: torch.arange(4).view(2, 2) :

0  1
2  3

3D: torch.arange(8).view(2, 2, 2) :

0  1    4  5
2  3    6  7

4D: torch.arange(16).view(2, 2, 2, 2) :

 0  1    4  5
 2  3    6  7

 8  9   12 13
10 11   14 15

5D: torch.arange(32).view(2, 2, 2, 2, 2):

 0  1    4  5  :  16 17   20 21
 2  3    6  7  :  18 19   22 23
               :
 8  9   12 13  :  24 25   28 29
10 11   14 15  :  26 27   30 31

The idea is that every time you add a new dimension, you alternate between stacking the lower dimension matrices horizontally and vertically. You always stack horizontally before stacking vertically, to follow the standard row-major convention for printing in the 2D case. Dimensions always proceed along the x and y axis, but the higher dimensions (smaller dim numbers) involve skipping over blocks. For example, a “row” on dim 3 in the 4D tensor is [0, 1] but the “row” on dim 1 is [0, 4] (we skip over to the next block.) The fractal nature of the construction means we can keep repeating the process for as many dimensions as we like.

In fact, for the special case when every size in the tensor is 2, the generated sequence of indices form a Morton curve. But I don’t call it that, since I couldn’t find a popular name for the variation of the Morton curve where the radix of each digit in the coordinate representation can vary.
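
If you want to generate these drawings mechanically, here is a small sketch of a helper (a hypothetical function; it reproduces the 4D layout above, though not the colon separators of the 5D example):

import torch

def matrix_of_matrices(t: torch.Tensor) -> str:
    # Alternate stacking of lower-dimensional blocks: horizontally for odd
    # dims (3D, 5D, ...), vertically for even dims (4D, 6D, ...).
    def render(t):
        if t.dim() == 0:
            return [f"{t.item():>2}"]
        if t.dim() == 1:
            return [" ".join(f"{v:>2}" for v in t.tolist())]
        if t.dim() == 2:
            return [" ".join(f"{v:>2}" for v in row.tolist()) for row in t]
        blocks = [render(s) for s in t]          # split along dim 0
        if t.dim() % 2 == 1:                     # stack horizontally
            return ["   ".join(rows) for rows in zip(*blocks)]
        out = []                                 # stack vertically
        for b in blocks:
            out.extend(b + [""])
        return out[:-1]
    return "\n".join(render(t))

print(matrix_of_matrices(torch.arange(16).view(2, 2, 2, 2)))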

Knowledge check. For the 4D tensor of size (2, 2, 2, 2) arranged in this way, draw the line(s) that would split the tensor into the pieces produced by torch.split(x, 1, dim), for each possible dimension 0, 1, 2 and 3. Answer under the fold.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

dim=0

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=0)]
[tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15])]

     0  1    4  5
     2  3    6  7
   ----------------
     8  9   12 13
    10 11   14 15


dim=1

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=1)]
[tensor([ 0, 1, 2, 3, 8, 9, 10, 11]), tensor([ 4, 5, 6, 7, 12, 13, 14, 15])]

     0  1 |  4  5
     2  3 |  6  7
          |
     8  9 | 12 13
    10 11 | 14 15

dim=2

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=2)]
[tensor([ 0, 1, 4, 5, 8, 9, 12, 13]), tensor([ 2, 3, 6, 7, 10, 11, 14, 15])]

     0  1    4  5
   ------- -------
     2  3    6  7

     8  9   12 13
   ------- -------
    10 11   14 15

dim=3

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=3)]
[tensor([ 0, 2, 4, 6, 8, 10, 12, 14]), tensor([ 1, 3, 5, 7, 9, 11, 13, 15])]

     0 |  1    4 |  5
     2 |  3    6 |  7

     8 |  9   12 | 13
    10 | 11   14 | 15

So you want to control flow in PT2

With contributions from Richard Zou.

PT2’s dominant internal representation, the FX graph, does not directly support control flow (if statements, while loops): FX graphs only represent straight-line basic blocks. Most of our graph capture mechanisms are tracing based (fx.symbolic_trace, make_fx, Dynamo), which means that we expect to be able to linearize all conditionals we encounter into a straight line program. Sometimes, you want to work with code that has control flow while working with the compiler stack. There is no silver bullet; instead there are a lot of different options with different tradeoffs.

Regional compilation

We have a perfectly good general purpose language that supports control flow: Python. To handle control flow, compile only regions/submodules of your program that have no internal control flow, and then string them together with standard Python control flow constructs. PT2 compiled regions are compositional with non-compiled regions: “it works.”
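
Concretely, this looks something like the following (a sketch; the regions and the branch condition are made up):

import torch

@torch.compile
def region_a(x):
    return torch.relu(x @ x.T)

@torch.compile
def region_b(x):
    return torch.sigmoid(x).sum(dim=-1)

def model(x):
    # ordinary Python control flow, never captured by the compiler;
    # each compiled region is a straight-line graph
    if x.sum() > 0:
        return region_a(x)
    else:
        return region_b(x)

print(model(torch.randn(4, 4)))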

Pros:

  • Simple: requires no major model changes
  • Universal: it always works (including data dependent flow, calling into third-party libraries, making an HTTP request, anything!)

Cons:

  • You will not get a full graph this way; you will only get graphs for each region. In particular, you will not be able to do truly global optimizations, nor will you be able to serialize a self-contained Python-less representation of the entire model
  • It can sometimes be inconvenient to structure your program so all the regions you want are compilable. Suppose you have this call graph between modules: A -> B -> C. C is compileable; A is compileable except for its call to B, which is what does the control flow. It’s easy to compile C, but you can’t directly compile A, as it has a B-shaped bit that can’t be compiled. What to do? If you split A so it is pipelined as A1, B, A2, you can then compile A1 and A2, but not B. Dynamo also supports “graph breaks” to automatically perform this split for you, in which case you just disable compilation on B, but graph break generated graphs can be difficult to reason about as the inputs to A2 are implicitly inferred.

Link: Reducing torch.compile cold start compilation time with regional compilation

Multiple graphs dispatched with guards

When the control flow is controlled by arguments that are known ahead of time (not data-dependent), you can also compile at the top level and get the flattened straight-line program for the particular branching you took in this case. Because Dynamo is a symbolic bytecode interpreter, it can automatically determine what inputs were used as part of control flow, and generate guards to validate that we would take the same paths again. If those values change, we will recompile the program at the new values. We dispatch between all the different unrollings of the program we have generated.

Pros:

  • Simple: requires no major model changes
  • You get a full graph for a particular unrolling of loops / conditionals, so global optimizations are possible

Cons:

  • Doesn’t work with data-dependent shapes.
  • You will end up with a graph for every unrolling; for example, if you have a loop that ranges from 1 to 32, you will end up with 32 different graphs. This will increase compile time.

Black box via custom operator

An FX graph just calls operators. An operator can internally have whatever control flow it wants. So you can always black box a problematic region of your model into an operator and preserve compilation for everything else.
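
For example, using the torch.library.custom_op API (a sketch; the op name and its body are made up):

import torch

@torch.library.custom_op("mylib::normalize_until_small", mutates_args=())
def normalize_until_small(x: torch.Tensor) -> torch.Tensor:
    # arbitrary data-dependent Python control flow, hidden from the compiler
    out = x.clone()
    while out.norm() > 1.0:
        out = out / 2
    return out

@normalize_until_small.register_fake
def _(x):
    # shape/dtype propagation so the op can be traced symbolically
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return normalize_until_small(x).sum()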

Pros:

  • You get a single, full graph that works for all possible branches

Cons:

  • A custom operator only supports inputs/outputs that fall inside our type system, which means you can only pass simple types like Tensor, int, bool (or pytree-able containers containing these things). There is some in progress work to relax this to allow more opaque types.
  • You have to explicitly declare all the inputs/outputs for the custom operator. This can be tiresome if the black boxed region represents a Module, since all the parameters also have to be directly passed in as well. The larger the region you black box, the bigger the arguments are.
  • You don’t actually get to see the inside of the custom operator from the outside graph, so no optimization over both inside and outside of the custom operator is possible. (Of course, you can always special case this operator in a pass on the outer graph.)
  • There are some bugs related to doing another torch.compile region inside of a custom operator, although these are workaroundable: https://github.com/pytorch/pytorch/issues/151328

Conditional operators / Unroll to max iterations

Do you really, really need a conditional? If you’re doing an if-branch, can you instead rewrite it so that you run both branches and use torch.where to select between the results? If you’re doing a while-loop, can you unroll it to the max number of iterations and rely on dynamic shapes to make the extra iterations no-ops once you’re done? Basically, this option is to rewrite your model so it doesn’t have Python-level control flow anymore (the conditional can either be done host or GPU side).
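
For the if-branch case, the rewrite looks something like this sketch:

import torch

def branchy(x):
    # Python-level, data-dependent branch: not traceable into one graph
    if x.mean() > 0:
        return x * 2
    return x - 1

def branchless(x):
    # run both branches, select with torch.where; traces to a single graph
    return torch.where(x.mean() > 0, x * 2, x - 1)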

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model
  • For unrolling, if you are close to being CPU-dispatch bound, unrolling and running with zero size could push you over the brink (as zero size dispatches are still not free)
  • For conditional operators, unconditionally running both branches increases the compute you need to do, which can be bad if you are compute-bound.

Control flow HOP

torch has special structured control flow operators that avoid unrolling large loops or needing to execute both branches of a control flow statement. If you’re familiar with JAX, these are very similar to the JAX equivalents. They have specific constraints that allow them to be directly compilable by torch.compile. For example, torch.cond accepts two functions (a true_fn and a false_fn) for the two branches and requires that outputs of each function must have the same properties (e.g. shape, dtype).

So far, we have a number of “higher-order” operators (HOPs), including cond, while_loop, and scan (whose semantics are sketched below).

These are relatively new, have been used in torch.export for inference, but have not been battle tested for training or performance.

The semantics of these control flow operators are as follows:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)

def while_loop(cond_fn, body_fn, carried_inputs):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val

def scan(combine_fn, init, xs, length=None):
    carry = init
    ys = []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)
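
For example, here is a minimal torch.cond usage under torch.compile (a sketch; the branch functions are made up):

import torch

def true_fn(x):
    return torch.sin(x)

def false_fn(x):
    return torch.cos(x)

@torch.compile(fullgraph=True)
def f(pred, x):
    # both branches must return outputs with matching shape/dtype
    return torch.cond(pred, true_fn, false_fn, (x,))

x = torch.randn(4)
print(f(torch.tensor(True), x))
print(f(torch.tensor(False), x))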

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model.
  • The control flow HOPs are structured: they have specific constraints on the functions (true_fn, false_fn (cond) or body_fn (while_loop)) that can be passed to them. One such constraint is that these functions may not mutate any of their inputs. This may make rewrites difficult because you have to think about code in a “functional”, JAX-like way.
  • Still WIP and they have some quirks especially for training. For example, the backward pass of torch.scan currently requires re-computing the forward pass (instead of just saving intermediates from each iteration of scan).

CFG over FX graphs

If FX graphs give you basic blocks, you can use them as the building blocks for a language that does support conditionals, stringing the basic blocks together into a control flow graph. In fact, Helion, a kernel DSL, does exactly this, as it is common to need to directly write data-dependent conditionals and loops when writing kernels (it otherwise uses all PyTorch API functions, similar to conventional FX graphs). To do this, you would need to write your own Python frontend that parses Python directly to generate the CFG. TorchScript also does this, but the TorchScript frontend is unmaintained and we don’t recommend using it (and it also doesn’t generate FX graphs by default.)

Pros:

  • You get a single graph that works for all possible branches
  • You are able to optimize inside and outside of control flow
  • In principle, you can write exactly the control flow you want

Cons:

  • You have to write the frontend; we don’t have one ready for you (TorchScript is not it, your princess is in another castle)
  • If your language looks too much like Python and too general purpose, prepare to get on the endless treadmill of feature requests for adding “just one more Python feature” (can we have lists? dataclasses? etc etc) in the frontend (it is more tractable for Helion, as it’s not a general purpose language.)

The Parallelism Mesh Zoo

When training large scale LLMs, there is a large assortment of parallelization strategies which you can employ to scale your training runs to work on more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you’ve got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)

tl;dr

  • DP, FSDP: ["dp"]
  • HSDP: ["dp_replicate", "dp_shard"]
  • DP+TP, DP+TP+SP: ["dp", "tp"]
  • DP+UlyssesSP: ["dp", "sp"] (verl)
  • DP+CP: ["dp", "cp"]
  • DP+CP+TP: ["dp", "cp", "tp"]
  • PP+DP+…: ["pp", "dp", ...] (torchtitan), ["dp", "pp", ...] (Megatron)
  • PP+DP+CP+TP+EP: ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"] (torchtitan)

Prologue: Why device mesh? Before we jump into the zoo, why do we have multi-dimensional meshes in the first place? One intuition is that the dimensions of the device mesh are a reflection of the physical constraints of networking between GPUs (there’s a reason why all of the scaling books talk extensively about how the networking for GPUs works; you can’t reason about what parallelization strategy you should use without knowing about this!) Let’s imagine you have 1024 NVIDIA GPUs. You don’t want to treat these 1024 GPUs as an undifferentiated blob of GPUs. Physically, these GPUs are grouped into nodes of eight which have much faster NVLink connections compared to cross-node communication, which is done over a slower InfiniBand connection. Intuitively, you will want to do something different depending on if you’re doing intra-node communication or inter-node communication.

The device mesh imposes structure on this collection of GPUs. A mesh is typically specified as a tensor size (e.g., (128, 8)) as well as string axis names ala named tensor (e.g., ["dp", "tp"]), and is simply an N-D tensor over a range of GPU indices (typically [0, 1, 2, 3, ...] for GPUs, and a mostly ascending but occasionally permuted sequence for TPUs). We typically think of 2D and 3D tensors as grids and cubes, but I find it is more helpful (especially in higher dimensions) to think of the device mesh as imposing some self-similar (fractal) structure on the GPUs. In the simplest 2D mesh that accounts for intra versus inter node communication, GPUs are first organized into nodes on the inner-most dimension, and then the nodes are collected together in the outer-most dimension to form the cluster. (The self-similar nature of the nodes is important because it tells us how communication occurs across the cluster: to communicate over the outer-most mesh dimension, all the GPU 0s on each node talk to each other, all the GPU 1s, etc.) This is only the very simplest mesh we can create, however; with more complicated parallelization strategies we may impose extra levels of structure, e.g., we may organize nodes into pods of two and four, or we might further divide the eight GPUs of a single node. In other words, the mesh tells us about which GPUs communicate to which other GPUs. This is important to know, because when I want to parallelize our model, I am making choices about how to shard tensors across my GPUs. The mesh tells me which GPUs have the other shards of my tensor; in other words, they are who I have to communicate with when I am doing a computation that requires information about the full tensor and cannot be done with the local shards only.
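
In PyTorch, this is the structure you write down with init_device_mesh; for example, for the 1024 GPU cluster above (a sketch; it assumes torch.distributed has been initialized with 1024 ranks):

from torch.distributed.device_mesh import init_device_mesh

# 128 nodes x 8 GPUs per node; the inner "tp" dim stays within a node's NVLink
# domain, the outer "dp" dim communicates across nodes over the slower fabric
mesh = init_device_mesh("cuda", (128, 8), mesh_dim_names=("dp", "tp"))
tp_group = mesh["tp"].get_group()   # the 8 GPUs that share my node
dp_group = mesh["dp"].get_group()   # my GPU's peers across the 128 nodes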

In the zoo, when we talk about a parallelism strategy, we will talk about how it typically relates to other parallelization strategies in the model, and the device mesh will tell us if it is orthogonal to other parallelisms (a new dimension), multiplexed with another strategy (a reused dimension), or perhaps a completely different hierarchy of communication (multiple meshes in the same model that don’t factor into the others).

Without further ado, here is the zoo!

Data parallelism (DP). Data parallelism predates the concept of device meshes, since you don’t actually need any nontrivial mesh structure to do data parallelism: if you are only doing data parallel, you just shard your input on the batch axis for however many devices you have. This sharding propagates through forwards and backwards until you allreduce to compute the final global gradient for a parameter. If you did make a 1D device mesh (this is useful to think about, because most higher dimensional parallelisms will include some form of data parallelism), you’d probably name your mesh ["dp"], ["ddp"] or perhaps ["batch"].

Let’s talk briefly about how people tend to name device mesh axes. In the PyTorch world, it’s most common to name the axis after the parallelism that it is responsible for, so either “dp” or “ddp” (you really shouldn’t call it ddp, but the DataParallel taboo in PyTorch is very real!) The batch name is common in JAX, and is very natural there because when you annotate the sharding of your input, you need to say for each tensor dimension what mesh dim it is sharded over. So when you shard the batch dimension over the batch mesh dim, it looks just like you’re labeling the batch dimension of your tensor as batch, e.g., P("batch", None). (This situation doesn’t happen in PyTorch because shardings of a tensor are specified per device mesh dim, but that’s a story for another day!)
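To make the JAX convention concrete, here is a hedged sketch (the mesh shape and tensor sizes are made up, and it assumes 128 devices are available):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(16, 8)  # assumes 128 devices
mesh = Mesh(devices, axis_names=("batch", "model"))

x = jnp.zeros((1024, 4096))
# Shard dim 0 (the batch dimension) over the "batch" mesh axis; replicate dim 1.
x = jax.device_put(x, NamedSharding(mesh, P("batch", None)))
```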

Fully-sharded data parallel (FSDP). This is best understood as an augmentation over DP where weights are also sharded over all GPUs and you just all-gather weights before performing operations (and reduce-scatter gradients in backwards). Because this all-gather is also among all devices, you don’t need another axis in your mesh, and your mesh might still be called ["dp"] in this case, even though you’re actually doing FSDP. Occasionally, you’ll see people name their mesh ["fsdp"] instead.
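Here is a conceptual sketch of that pattern (this is not the real FSDP internals, just the shape of the communication, with a hypothetical row-sharded weight):

```python
import torch
import torch.distributed as dist

def fsdp_linear_forward(x, w_shard, dp_group):
    # Each rank holds 1/world of the rows of the weight; gather the full
    # weight just-in-time for the matmul (FSDP then frees it afterwards).
    world = dist.get_world_size(dp_group)
    w_full = torch.empty(world * w_shard.shape[0], w_shard.shape[1],
                         device=w_shard.device, dtype=w_shard.dtype)
    dist.all_gather_into_tensor(w_full, w_shard, group=dp_group)
    out = x @ w_full.t()
    # In backward, the full weight gradient is reduce-scattered so each rank
    # keeps (and runs the optimizer on) only the gradient for its own shard.
    return out
```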

Hybrid sharded data parallel (HSDP). HSDP is an extension of FSDP where you shard weights (FSDP) up to the point where you can’t actually do a giant all-gather/reduce-scatter over every GPU, and then replicate these shards to cover the rest of your cluster (DP). It’s also amenable to fault tolerance techniques that make the modeling assumption that it’s OK to lose samples of your batch if a replica fails (you won’t model this with device mesh though!). This is probably the first time you will encounter a 2D device mesh (indeed, the DeviceMesh tutorial in PyTorch specifically uses hybrid sharding as its motivating example), since HSDP doesn’t require any extra model changes on top of FSDP. There are a few common ways to name the mesh axes for HSDP. One way to think about it is that it is FSDP on the inner dimension and DP on the outer dimension, in which case you would say ["dp", "fsdp"]. Another way is to think about what happens to parameters at the various layers of the mesh: the inner dimension shards, while the outer dimension replicates, so you would say ["replicate", "shard"] or perhaps ["dp_replicate", "dp_shard"] to make it clear that you are still doing data parallelism across both of these device mesh dims (in particular, when you split your batches, you split on both the dp_replicate and dp_shard dims–although, to get the final gradients, you can do the reduction hierarchically by first doing a reduce-scatter on “dp_shard” and then doing an allreduce on “dp_replicate”).
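Here is a hedged sketch of that hierarchical reduction on an illustrative 16 x 8 HSDP mesh (the sizes and names are assumptions, and gradient averaging/scaling is omitted):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (16, 8),
                        mesh_dim_names=("dp_replicate", "dp_shard"))
shard_group = mesh.get_group("dp_shard")          # intra-node, FSDP-style
replicate_group = mesh.get_group("dp_replicate")  # inter-node, DP-style

def reduce_gradient(grad_full):
    # 1. Reduce-scatter within "dp_shard": each rank keeps only the gradient
    #    for the parameter shard it owns.
    n = dist.get_world_size(shard_group)
    grad_shard = torch.empty(grad_full.shape[0] // n, *grad_full.shape[1:],
                             device=grad_full.device, dtype=grad_full.dtype)
    dist.reduce_scatter_tensor(grad_shard, grad_full, group=shard_group)
    # 2. All-reduce that shard across "dp_replicate" to combine the replicas.
    dist.all_reduce(grad_shard, group=replicate_group)
    return grad_shard
```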

Tensor parallelism (TP). Depending on who you ask, tensor parallelism is either about letting you reduce your effective batch size for training or moving you towards reducing the memory usage of activations in your model. In the “reduce effective batch size” framing, the idea behind TP is that you can only scale up DP until your cluster is as large as your batch size. From a modeling perspective, it can be undesirable to have a batch size that is too large, so you can’t just keep increasing your batch size to get more parallelism. Instead, TP allows us to get some extra scaling by sharding over the feature dimension of our matrix multiplies1 (you can shard over either the columns or the rows of your weight matrix, so we will frequently specify if a TP Linear is column-wise or row-wise; in attention, column-wise linear effectively parallelizes the attention computation over attention heads). The communication needed to do TP is fairly exposed (unless you’re doing async tensor parallel), so you typically want to keep the communications for it within a single node. This leads to this classic 2D device mesh for DP+TP: ["dp", "tp"] (or, if you’re a JAXer, you might write ["batch", "model"], where model is used to indicate the inner feature dimension of the model weights being parallelized over.) When someone says 2D parallelism, they’re usually referring to this combo of parallelisms (although I do not recommend using this term–as you can see, it is obviously ambiguous!) Note that tp is the inner mesh dimension, since it benefits the most from the high bandwidth network between GPUs on a single node.
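To make the column-wise/row-wise distinction concrete, here is a hedged Megatron-style sketch over a "tp" process group (the shapes and the absence of bias terms are simplifications):

```python
import torch
import torch.distributed as dist

def column_parallel_linear(x, w_col_shard):
    # W is sharded along its output (column) dimension: each rank computes a
    # slice of the output features; no communication is needed here.
    return x @ w_col_shard.t()           # [..., hidden] -> [..., ffn / tp]

def row_parallel_linear(y_shard, w_row_shard, tp_group):
    # W is sharded along its input (row) dimension: each rank produces a
    # partial sum of the full output, which must be all-reduced over "tp".
    partial = y_shard @ w_row_shard.t()  # [..., ffn / tp] -> [..., hidden], partial
    dist.all_reduce(partial, group=tp_group)
    return partial
```

Note that the intermediate activation between the two linears is temporarily sharded on the TP dimension, which is the subtlety footnote 1 below points out.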

You don’t have to stop with DP+TP, however. If you’re using FSDP with tensor parallelism (remember, “dp” can mean FSDP!), intra-node TP doesn’t improve the amount of inter-node FSDP communication you have to do: however much TP you do, within one TP node you only have one slice of the model and have to talk to everyone else to get their slices. You could solve this by expanding TP to also cross nodes, but in practice mixed intra/inter-node collectives are a lot slower than pure inter-node collectives. This limits the scaling you can get from TP, and so if you’re still hitting limits on FSDP, it can still be useful to apply HSDP to avoid running collectives that are too large. In that case, you’d end up with a mesh like ["dp_replicate", "dp_shard", "tp"].

Sequence parallelism (SP). For this section, we specifically take the definition of sequence parallelism from the Ultra-Scale Playbook (as distinguished from context parallelism). Although we said that TP is the first step towards reducing the memory usage of activations2, if you literally implement DP+TP based on my descriptions above, you will still end up with more memory spent on activations than you want, because there are still parts of the model around the FFN, like the LayerNorm, that need the full hidden dimension to compute mean and variance3. To reduce the memory usage in these segments, you need to shard on something else. So typically what you will see is that the model will alternate between TP (hidden dimension is sharded) and SP (sequence dimension is sharded). Consequently, if you look at the device mesh for a model using DP+TP+SP, it will typically still look like ["dp", "tp"], and instead the tp dimension is multiplexed to be used both for TP and SP. Because TP and SP never occur at the same time, you don’t need a separate dimension for them.
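Here is a hedged sketch of what the multiplexing looks like at the boundaries, assuming the sequence dimension is dim 0 of the local activation (the helper names are made up; this mirrors the Megatron-style SP described in the Ultra-Scale Playbook):

```python
import torch
import torch.distributed as dist

def enter_tp_region(x_seq_sharded, tp_group):
    # SP -> TP: all-gather the sequence shards so the column-wise linear sees
    # the full sequence (the activation becomes sharded on hidden instead).
    world = dist.get_world_size(tp_group)
    full = torch.empty(world * x_seq_sharded.shape[0], *x_seq_sharded.shape[1:],
                       device=x_seq_sharded.device, dtype=x_seq_sharded.dtype)
    dist.all_gather_into_tensor(full, x_seq_sharded, group=tp_group)
    return full

def exit_tp_region(partial_out, tp_group):
    # TP -> SP: instead of the all-reduce after the row-wise linear, do a
    # reduce-scatter over the sequence so LayerNorm/dropout run on a shard.
    world = dist.get_world_size(tp_group)
    out = torch.empty(partial_out.shape[0] // world, *partial_out.shape[1:],
                      device=partial_out.device, dtype=partial_out.dtype)
    dist.reduce_scatter_tensor(out, partial_out, group=tp_group)
    return out
```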

Ulysses sequence parallelism. Ulysses sequence parallelism from DeepSpeed Ulysses is another sequence parallelism strategy that is implemented by verl (because verl is forked so often, it shows up quite prominently if you are looking for examples of init_device_mesh on GitHub code search). It aims to alleviate memory pressure from extremely long sequences, so sequences are sharded on input, and only when attention needs to be computed is an alltoall issued to re-shard on the attention heads rather than the sequence (doing another alltoall to restore the sequence sharding after the attention is done). Importantly, this means it competes with TP for sharding on the attention heads, which is why you also see people use it to replace TP in MoE models, since it has much less communication than TP (at the cost of having to replicate the attention weights). In verl, you will just see a device mesh ["dp", "sp"] when you are using their FSDP backend (which is what supports Ulysses).
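A hedged sketch of the alltoall resharding (the layout and shapes are my assumptions, not verl’s actual code): the input is sharded on sequence as [B, S/p, H, D], and after the alltoall each rank holds the full sequence for a slice of the heads, [B, S, H/p, D].

```python
import torch
import torch.distributed as dist

def seq_to_head_alltoall(x, sp_group):
    # x: [B, S/p, H, D], sharded on the sequence dimension.
    p = dist.get_world_size(sp_group)
    B, S_local, H, D = x.shape
    # Group the heads into p buckets, one destined for each rank, and move the
    # bucket index to dim 0 (all_to_all_single exchanges equal chunks along dim 0).
    x = x.reshape(B, S_local, p, H // p, D).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=sp_group)
    # out[i] is rank i's sequence chunk for our head bucket; stitch the
    # sequence chunks back together in rank order.
    return out.permute(1, 0, 2, 3, 4).reshape(B, p * S_local, H // p, D)
```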

Context parallelism (CP). Context parallelism is another form of “sequence” parallelism. Like Ulysses sequence parallelism, sequences are sharded on input; the difference, however, is instead of using an alltoall to re-shard on attention heads, you just do a (distributed) attention on the entire context. You can do this the easy way by just using allgather to get the full context (as was done in llama4) or you can use a fancy kernel like ring attention, which carefully overlaps communication and computation when performing attention. A popular implementation of context parallelism lives in Megatron, which doesn’t directly use PyTorch’s native DeviceMesh abstraction but has an analogous HyperCommGrid. The mesh we see here will be something like ["dp", "cp"] or more commonly ["dp", "cp", "tp"]. Notice that we can have a dedicated mesh dim for CP: CP operates very similarly to SP outside of the attention calls (as it is just plain data parallelism when there is no cross-token dependency), but because it never shards on attention heads, it doesn’t compete with TP and can be used completely orthogonally to TP (TP shards hidden, CP shards sequence).
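Here is a hedged sketch of the “easy way” (all-gather the full context before attending), with causal masking and load balancing omitted for brevity; the function name is made up:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def cp_attention_allgather(q, k, v, cp_group):
    # q, k, v: [B, H, S/cp, D], all sharded on the sequence dimension.
    world = dist.get_world_size(cp_group)

    def gather_seq(t):
        chunks = [torch.empty_like(t) for _ in range(world)]
        dist.all_gather(chunks, t, group=cp_group)
        return torch.cat(chunks, dim=2)  # reassemble the full sequence

    k_full, v_full = gather_seq(k), gather_seq(v)
    # Each rank attends its own query chunk against the full context.
    return F.scaled_dot_product_attention(q, k_full, v_full)
```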

CP has a pretty interesting interaction with FSDP. Both DP and CP shard the input data (on batch and sequence respectively). It’s pretty common when you do FSDP to just shard over both “dp” (“dp_shard” in HSDP) and “cp”. In torchtitan, we create a flattened mesh dim “dp_shard_cp” specifically for FSDP sharding (a flattened mesh dim is what happens if you take your mesh and “forget” about some of the structure; e.g., if you were to do an all-gather, you just all-gather over all the flattened axes). In the HSDP world, “dp_cp” is still a useful concept, because this is the combination of axes you want to all-reduce over in order to, e.g., compute the global average loss.
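To make flattening concrete, here is a small numeric illustration (the mesh sizes are made up): on a (dp_replicate=2, dp_shard=2, cp=2, tp=2) mesh over ranks 0..15, rank 0’s flattened groups are just the ranks that differ from it only in the flattened coordinates.

```python
import numpy as np

# [dp_replicate, dp_shard, cp, tp], ranks laid out row-major.
mesh = np.arange(16).reshape(2, 2, 2, 2)

# FSDP shards over the flattened "dp_shard_cp" dim: vary dp_shard and cp,
# holding dp_replicate and tp fixed at rank 0's coordinates.
dp_shard_cp_group = mesh[0, :, :, 0].flatten()  # -> [0, 2, 4, 6]

# The loss all-reduce goes over the flattened "dp_cp" dim: also vary dp_replicate.
dp_cp_group = mesh[:, :, :, 0].flatten()        # -> [0, 2, 4, 6, 8, 10, 12, 14]
```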

Pipeline parallelism (PP). Pipeline parallelism is kind of an ugly duckling and people tend to hate on it because you have to rewrite your models to introduce pipeline stages, and you can’t really use things like DTensor with it (unless you do really strange things like how the GSPMD paper “supports” pipeline parallelism–the general consensus is that automatic parallelism does not like PP). PP still goes in the device mesh, because it affects how you are organizing your GPUs, but, for example, torchtitan solely uses it to set up PGs for doing the point-to-point communications. I’ve seen both ["dp", "pp", ...] and ["pp", "dp", ...] for meshes with PP, but the order probably doesn’t make too much of a difference as you are likely solidly inter-node at this point. Pipeline parallelism bandwidth use is very low, and latency can be covered up as you can immediately start processing the next batch after triggering an asynchronous send of the previous batch.

Expert parallelism (EP). EP is its own kettle of fish. Expert parallelism only applies over the expert computation of the model, but within this region, we are not sharding parameters as FSDP conventionally sees it: we will commonly have the entire expert’s weights on our node. torchtitan’s WIP expert parallelism implementation, when it has ALL parallelisms on, would look like ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], where dp_shard has been split into two mesh dimensions (DP shard modulo EP, and DP shard in EP). dp_shard_mod_ep is conventionally one, but when it is not it represents further FSDP-style sharding of expert weights inside of the expert region (there’s some complication here if you have shared experts alongside your EP-sharded experts). But then dp_shard_in_ep, cp and optionally tp are combined together to give you the expert parallel dimension. It’s actually more intuitive to imagine that you have two distinct meshes: ["pp", "dp_replicate", "dp_shard", "cp", "tp"] and ["pp", "dp_shard_mod_ep", "ep", "tp"]. The keen-eyed may also notice that there is no intrinsic reason the tp mesh size inside and outside of the expert parallel region has to be the same, but varying it is not easily done if you have to have a single global device mesh for everything. In fact, there is a WIP PR to have two meshes, one for inside the expert region and one for outside: https://github.com/pytorch/torchtitan/pull/1660
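A hedged sketch of that two-mesh mental model (this is the conceptual picture, not torchtitan’s actual single-mesh implementation; all sizes are made up, with 256 GPUs, dp_replicate = 1, and ep = dp_shard_in_ep (8) x cp (2) = 16):

```python
from torch.distributed.device_mesh import init_device_mesh

# Outside the expert region: the usual dense parallelisms.
dense_mesh = init_device_mesh(
    "cuda", (2, 1, 16, 2, 4),
    mesh_dim_names=("pp", "dp_replicate", "dp_shard", "cp", "tp"),
)

# Inside the expert region: dp_shard_in_ep and cp fold into the expert-parallel
# dim, while dp_shard_mod_ep keeps FSDP-style sharding of the expert weights.
expert_mesh = init_device_mesh(
    "cuda", (2, 2, 16, 4),
    mesh_dim_names=("pp", "dp_shard_mod_ep", "ep", "tp"),
)
```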

Conclusion. The general concept behind mesh parallelism is that you can compose parallelization strategies without too much fuss. Indeed, the appeal of, e.g., TP for scaling is precisely that it lets you cover your device space without having to expand DP beyond the batch size you want. However, as you can see from these concrete examples, it’s not always quite as simple as stacking all of the parallelisms one on top of another. In the end, all the device mesh is doing is creating PGs behind groups of devices as defined by the mesh, so if you want some weird setup where you’re swapping between two device meshes, PyTorch’s general philosophy has been to say, have fun!

Thanks to Horace He, Tianyu Liu and Natalia Gimelshein for helping fact check this post. Any remaining errors are mine!


  1. One more subtlety I want to point out: while we tend to think of TP as sharding the feature dimension of parameters, when we “propagate” this sharding through the network, other intermediate tensors end up getting sharded on the TP dimension as well. In particular, in a transformer block, you will typically have a column-wise linear followed by a row-wise linear, and the intermediate activation will be temporarily sharded on the TP dimension before the row-wise linear runs. ↩︎

  2. I am very carefully using “activation memory” here and not total memory, because total memory usage (what you actually care about) is also a function of peak memory usage, which is subject to transient peaks such as when FSDP does an all-gather to collect parameters. In fact, even without SP, TP will improve your peak memory usage, because unlike FSDP, it’s not necessary to all-gather the full weight matrix to actually perform the matrix multiply. TP’s peak memory usage occurs when it all-gathers activations. ↩︎

  3. You will get a little improvement between the column-wise and row-wise linear, since the activations there are sharded. You can turn this into a big improvement by using selective activation checkpointing and forcing recomputation of activations that aren’t sharded! (Plain activation checkpointing tends not to work so well because of the all-gather of the activations.) ↩︎