When training large-scale LLMs, there is a large assortment of parallelization strategies that you can employ to scale your training runs to more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you've got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)
tl;dr
- DP, FSDP: ["dp"]
- HSDP: ["dp_replicate", "dp_shard"]
- DP+TP, DP+TP+SP: ["dp", "tp"]
- DP+UlyssesSP: ["dp", "sp"] (verl)
- DP+CP: ["dp", "cp"]
- DP+CP+TP: ["dp", "cp", "tp"]
- PP+DP+...: ["pp", "dp", ...] (torchtitan), ["dp", "pp", ...] (Megatron)
- PP+DP+CP+TP+EP: ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"] (torchtitan)
Prologue: Why device mesh? Before we jump into the zoo, why do we have multi-dimensional meshes in the first place? One intuition is that the dimensions of the device mesh reflect the physical constraints of networking between GPUs (there's a reason all of the scaling books talk extensively about how GPU networking works; you can't reason about what parallelization strategy you should use without knowing about this!) Let's imagine you have 1024 NVIDIA GPUs. You don't want to treat these 1024 GPUs as an undifferentiated blob. Physically, these GPUs are grouped into nodes of eight, with much faster NVLink connections within a node, while cross-node communication goes over a slower InfiniBand connection. Intuitively, you will want to do something different depending on whether you're doing intra-node or inter-node communication.
The device mesh imposes structure on this collection of GPUs. A mesh is typically specified as a tensor size (e.g., (128, 8)) along with string axis names a la named tensors (e.g., ["dp", "tp"]), and is simply an N-D tensor over a range of device indices (typically [0, 1, 2, 3, ...] for GPUs, and a mostly ascending but occasionally permuted sequence for TPUs). We typically think of 2D and 3D tensors as grids and cubes, but I find it more helpful (especially in higher dimensions) to think of the device mesh as imposing some self-similar (fractal) structure on the GPUs. In the simplest 2D mesh that accounts for intra- versus inter-node communication, GPUs are first organized into nodes on the inner-most dimension, and then the nodes are collected together in the outer-most dimension to form the cluster. (The self-similar nature of the nodes is important because it tells us how communication occurs across the cluster: to communicate over the outer-most mesh dimension, all the GPU 0s on each node talk to each other, all the GPU 1s, etc.) This is only the very simplest mesh we can create, however; with more complicated parallelization strategies we may impose extra levels of structure, e.g., we may organize nodes into pods of two and four, or we might further subdivide the eight GPUs of a single node. In other words, the mesh tells us which GPUs communicate with which other GPUs. This is important to know, because when I parallelize my model, I am making choices about how to shard tensors across my GPUs. The mesh tells me which GPUs have the other shards of my tensor; in other words, they are the GPUs I have to communicate with when I am doing a computation that requires information about the full tensor and cannot be done with the local shards alone.
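To make this concrete, here is a minimal sketch of building that simplest 2D mesh with PyTorch's DeviceMesh API; the 128 x 8 shape and the axis names are just illustrative, not a recommendation.

```python
# A minimal sketch of the simplest 2D mesh described above: 1024 GPUs organized as
# 128 nodes x 8 GPUs per node. Shape and axis names are illustrative.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda",
    (128, 8),                    # outer dim: nodes; inner dim: GPUs within a node
    mesh_dim_names=("inter_node", "intra_node"),
)

# Each 1D slice is backed by its own process group, which is what you'd hand to a
# collective: "the 8 GPUs on my node" vs. "my counterparts (same local index) on
# the other 127 nodes".
intra = mesh["intra_node"]
inter = mesh["inter_node"]
```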
In the zoo, when we talk about a parallelism strategy, we will talk about how it typically relates to the other parallelization strategies in the model, and the device mesh will tell us whether it is orthogonal to other parallelisms (a new dimension), multiplexed with another strategy (a reused dimension), or perhaps a completely different hierarchy of communication (multiple meshes in the same model that don't factor into one another).
Without further ado, here is the zoo!
Data parallelism (DP). Data parallelism predates the concept of device meshes, since you don't actually need any nontrivial mesh structure to do it: if you are only doing data parallelism, you just shard your input on the batch axis across however many devices you have. This sharding propagates through forwards and backwards until you allreduce to compute the final global gradient for each parameter. If you did make a 1D device mesh (this is useful to think about, because most higher-dimensional parallelisms will include some form of data parallelism), you'd probably name your mesh ["dp"], ["ddp"], or perhaps ["batch"].
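If you do spell DP out as a 1D mesh, the whole story fits in a few lines. This is a hand-written sketch of what DDP/FSDP wrappers do for you; the toy model, sizes, and batch are stand-ins.

```python
# A hand-written sketch of DP over a 1D mesh: each rank runs its own slice of the
# batch, and gradients are averaged over the "dp" dim. In practice DDP/FSDP does
# this reduction for you; the model and sizes here are stand-ins.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (64,), mesh_dim_names=("dp",))  # say, 64 GPUs

model = nn.Linear(4096, 4096).cuda()
local_batch = torch.randn(8, 4096, device="cuda")   # this rank's slice of the global batch

loss = model(local_batch).sum()
loss.backward()
for p in model.parameters():                         # allreduce to get the global gradient
    dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=mesh.get_group("dp"))
```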
Let's talk briefly about how people tend to name device mesh axes. In the PyTorch world, it's most common to name the axis after the parallelism it is responsible for, so either "dp" or "ddp" (you really shouldn't call it ddp, but the DataParallel taboo in PyTorch is very real!) The batch name is common in JAX, and is very natural there because when you annotate the sharding of your input, you need to say, for each dimension of the tensor, which mesh dim it is sharded over. So when you shard the batch dimension over the batch mesh dim, it looks just like you're labeling the batch dimension of your tensor as batch, e.g., P("batch", None). (This situation doesn't happen in PyTorch because shardings of a tensor are specified per device mesh dim, but that's a story for another day!)
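Here's a small JAX sketch of that naming convention (the shapes and mesh size are made up): sharding the input over the "batch" mesh axis reads just like labeling the tensor's batch dimension.

```python
# A small JAX sketch of the "batch" naming convention; shapes and mesh size are
# made up. Sharding dim 0 of the input over the "batch" mesh axis reads just like
# labeling that dimension of the tensor.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))   # 1D mesh named "batch"

x = np.zeros((32, 4096), dtype=np.float32)                    # (batch, hidden)
x = jax.device_put(x, NamedSharding(mesh, P("batch", None)))  # dim 0 sharded over "batch"
```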
Fully-sharded data parallel (FSDP). This is best understood as an augmentation of DP where weights are also sharded over all GPUs, and you just all-gather weights before performing operations (and reduce-scatter gradients in backwards). Because this all-gather is also over all devices, you don't need another axis in your mesh, and your mesh might still be called ["dp"] in this case, even though you're actually doing FSDP. Occasionally, you'll see people name their mesh ["fsdp"] instead.
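As a hedged sketch of what this looks like with FSDP2's fully_shard: a single "dp" dim, parameters sharded over every rank. The toy model and mesh size are assumptions, and the fully_shard import path has moved between PyTorch releases.

```python
# A minimal FSDP2-style sketch: one "dp" dim, parameters sharded over all ranks and
# all-gathered just-in-time in forward/backward. Toy model and mesh size are made up;
# fully_shard used to live in torch.distributed._composable.fsdp in older releases.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

mesh = init_device_mesh("cuda", (64,), mesh_dim_names=("dp",))

model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()
fully_shard(model, mesh=mesh)   # every parameter sharded over the one "dp" dim
```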
Hybrid sharded data parallel (HSDP). HSDP is an extension of FSDP where you shard weights (FSDP) up to the point where you can no longer do a giant all-gather/reduce-scatter over every GPU, and then replicate those shards to cover the rest of your cluster (DP). It's also amenable to fault tolerance techniques that make the modeling assumption that it's OK to lose samples of your batch if a replica fails (you won't model this with the device mesh, though!). This is probably the first time you will encounter a 2D device mesh (indeed, the DeviceMesh tutorial in PyTorch specifically uses hybrid sharding as its motivating example), since HSDP doesn't require any extra model changes on top of FSDP. There are a few common ways to name the mesh axes for HSDP. One way to think about it is that it is FSDP on the inner dimension and DP on the outer dimension, in which case you would say ["dp", "fsdp"]. Another way is to think about what happens to parameters at the various levels of the mesh: the inner dimension shards, while the outer dimension replicates, so you would say ["replicate", "shard"] or perhaps ["dp_replicate", "dp_shard"] to make it clear that you are still doing data parallelism across both of these device mesh dims (in particular, when you split your batches, you split on both the dp_replicate and dp_shard dims--although, to get the final gradients, you can do the reduction hierarchically by first doing a reduce-scatter on "dp_shard" and then doing an allreduce on "dp_replicate").
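The same fully_shard sketch with a 2D mesh gives you HSDP: replicate over the outer dim, shard over the inner dim. The 16 x 8 shape and toy model below are, again, made up.

```python
# A hedged HSDP sketch: with FSDP2's fully_shard, handing it a 2D mesh means
# "replicate over the outer dim, shard over the inner dim". 16 x 8 = 128 GPUs and
# the toy model are made up.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

mesh = init_device_mesh(
    "cuda", (16, 8), mesh_dim_names=("dp_replicate", "dp_shard")
)

model = nn.Linear(4096, 4096).cuda()
fully_shard(model, mesh=mesh)   # shard over "dp_shard", replicate over "dp_replicate"
```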
Tensor parallelism (TP). Depending on who you ask, tensor parallelism is either about letting you reduce your effective batch size for training or about moving you towards reducing the memory usage of activations in your model. In the "reduce effective batch size" framing, the idea behind TP is that you can only scale up DP until your cluster is as large as your batch size. From a modeling perspective, it can be undesirable to have a batch size that is too large, so you can't just keep increasing your batch size to get more parallelism. Instead, TP allows us to get some extra scaling by sharding over the feature dimension of our matrix multiplies (you can shard over either the columns or the rows of your weight matrix, so we will frequently specify whether a TP Linear is column-wise or row-wise; in attention, a column-wise linear effectively parallelizes the attention computation over attention heads). The communication needed to do TP is fairly exposed (unless you're doing async tensor parallel), so you typically want to keep its collectives within a single node. This leads to the classic 2D device mesh for DP+TP: ["dp", "tp"] (or, if you're a JAXer, you might write ["batch", "model"], where model indicates that the inner feature dimension of the model weights is being parallelized over.) When someone says 2D parallelism, they're usually referring to this combo of parallelisms (although I do not recommend using this term--as you can see, it is obviously ambiguous!) Note that tp is the inner mesh dimension, since it benefits the most from the high-bandwidth network between GPUs on a single node.
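Here's a sketch of the ["dp", "tp"] mesh together with PyTorch's tensor parallel API applied to a toy MLP; the mesh shape, layer sizes, and module names are illustrative.

```python
# A sketch of the classic ["dp", "tp"] mesh plus column-/row-wise sharding of a toy
# MLP with PyTorch's tensor parallel API. Mesh shape and layer sizes are illustrative.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh = init_device_mesh("cuda", (16, 8), mesh_dim_names=("dp", "tp"))  # tp = inner, intra-node

mlp = nn.Sequential(nn.Linear(4096, 16384), nn.GELU(), nn.Linear(16384, 4096)).cuda()
parallelize_module(
    mlp,
    mesh["tp"],                   # TP collectives stay on the fast intra-node links
    {"0": ColwiseParallel(),      # up-projection: shard the weight column-wise
     "2": RowwiseParallel()},     # down-projection: shard the weight row-wise
)
```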
You don't have to stop with DP+TP, however. If you're using FSDP with tensor parallelism (remember, "dp" can mean FSDP!), intra-node TP doesn't improve the amount of inter-node FSDP communication you have to do: however much TP you do, within one TP node you only have one slice of the model and have to talk to everyone else to get their slices. You could solve this by expanding TP to also cross nodes, but in practice mixed intra/inter-node collectives are a lot slower than pure inter-node collectives. This limits the scaling you can get from TP, and so if you're still hitting limits on FSDP, it can still be useful to apply HSDP to avoid running collectives that are too large. In that case, you'd end up with a mesh like ["dp_replicate", "dp_shard", "tp"].
Sequence parallelism (SP). For this section, we specifically take the definition of sequence parallelism from the Ultrascale Playbook (as distinguished from context parallelism). Although we said that TP is the first step towards reducing the memory usage of activations, if you literally implement DP+TP based on my descriptions above, you will still end up with more memory spent on activations than you want, because there are still parts of the model around the FFN, like the LayerNorm, that need the full hidden dimension to compute mean and variance. To reduce the memory usage in these segments, you need to shard on something else. So typically what you will see is that the model alternates between TP (hidden dimension is sharded) and SP (sequence dimension is sharded). Consequently, if you look at the device mesh for a model using DP+TP+SP, it will typically still look like ["dp", "tp"], and instead the tp dimension is multiplexed to be used for both TP and SP. Because TP and SP never occur at the same time, you don't need a separate dimension for them.
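To see the multiplexing concretely, here's a hedged sketch using PyTorch's tensor parallel API: the norm is marked sequence-parallel while the linears are tensor-parallel, and everything runs over the single "tp" dim. The toy Block and its submodule names are my own, not any particular codebase's.

```python
# A hedged TP+SP sketch: the LayerNorm is sequence-parallel (sequence dim sharded)
# while the linears are tensor-parallel (hidden dim sharded), all multiplexed over
# the same "tp" mesh dim. The toy Block and its submodule names are made up.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

class Block(nn.Module):            # stand-in for the FFN part of a transformer block
    def __init__(self, d=4096, h=16384):
        super().__init__()
        self.ffn_norm, self.w1, self.w2 = nn.LayerNorm(d), nn.Linear(d, h), nn.Linear(h, d)

    def forward(self, x):
        return self.w2(nn.functional.gelu(self.w1(self.ffn_norm(x))))

mesh = init_device_mesh("cuda", (16, 8), mesh_dim_names=("dp", "tp"))

plan = {
    "ffn_norm": SequenceParallel(),                  # SP: activations sharded on the sequence dim
    "w1": ColwiseParallel(input_layouts=Shard(1)),   # all-gather sequence, shard hidden
    "w2": RowwiseParallel(output_layouts=Shard(1)),  # reduce-scatter back to sequence shards
}
parallelize_module(Block().cuda(), mesh["tp"], plan)  # one "tp" dim serves both TP and SP
```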
Ulysses sequence parallelism. Ulysses sequence parallelism from DeepSpeed Ulysses is another sequence parallelism strategy that is implemented by verl (because verl is forked so often, it shows up quite prominently if you are looking for examples of init_device_mesh on GitHub code search). It aims to alleviate memory pressure from extremely long sequences, so sequences are sharded on input, and only when attention needs to be computed is an alltoall issued to re-shard on the attention heads rather than the sequence (doing another alltoall to restore the sequence sharding after the attention is done). Importantly, this means it competes with TP for sharding on the attention heads, which is why you also see people use it to replace TP in MoE models, since it has much less communication than TP (at the cost of having to replicate the attention weights). In verl, you will just see a device mesh ["dp", "sp"] when you are using their FSDP backend (which is what supports Ulysses).
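Conceptually (this is not verl's actual code), the Ulysses re-shard looks like the following DTensor sketch, where redistributing from a sequence shard to a head shard corresponds to the alltoall described above; shapes and mesh size are made up.

```python
# A conceptual sketch of the Ulysses re-shard (not verl's implementation): activations
# arrive sharded on the sequence dim over "sp"; re-sharding onto the head dim for
# attention (and back afterwards) is the alltoall described above. DTensor's
# redistribute performs the re-shard for us. Shapes and mesh size are made up.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (16, 8), mesh_dim_names=("dp", "sp"))

x = torch.randn(4, 8192, 32, 128, device="cuda")    # (batch, seq, heads, head_dim)
x = distribute_tensor(x, mesh["sp"], [Shard(1)])    # sharded on the sequence dim
x = x.redistribute(placements=[Shard(2)])           # re-shard onto heads for attention
# ... attention over the full sequence with the local heads ...
x = x.redistribute(placements=[Shard(1)])           # re-shard back onto the sequence dim
```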
Context parallelism (CP). Context parallelism is another form of "sequence" parallelism. Like Ulysses sequence parallelism, sequences are sharded on input; the difference, however, is that instead of using an alltoall to re-shard on attention heads, you just do a (distributed) attention over the entire context. You can do this the easy way by just using an allgather to get the full context (as was done in llama4), or you can use a fancy kernel like ring attention, which carefully overlaps communication and computation when performing attention. A popular implementation of context parallelism lives in Megatron, which doesn't directly use PyTorch's native DeviceMesh abstraction but has an analogous HyperCommGrid. The mesh we see here will be something like ["dp", "cp"] or, more commonly, ["dp", "cp", "tp"]. Notice that we can have a dedicated mesh dim for CP: CP operates very similarly to SP outside of the attention calls (it is just plain data parallelism when there is no cross-token dependency), but because it never shards on attention heads, it doesn't compete with TP and can be used completely orthogonally to it (TP shards hidden, CP shards sequence).
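As a hedged sketch of the "easy" all-gather variant (not ring attention), ignoring causal masking and load balancing: each rank keeps its local queries but gathers the full keys and values over the "cp" group before attending. The cp_attention helper and tensor layouts are my own stand-ins, not Megatron's code.

```python
# A hedged sketch of the "easy" all-gather variant of CP (not ring attention), ignoring
# causal masking and load balancing. cp_attention and the tensor layouts are stand-ins.
import torch
import torch.distributed as dist
import torch.nn.functional as F

def cp_attention(q, k, v, cp_group):
    # q, k, v: (batch, heads, local_seq, head_dim), each sharded on the sequence dim
    world = dist.get_world_size(cp_group)
    k_parts = [torch.empty_like(k) for _ in range(world)]
    v_parts = [torch.empty_like(v) for _ in range(world)]
    dist.all_gather(k_parts, k, group=cp_group)       # gather the full context's keys...
    dist.all_gather(v_parts, v, group=cp_group)       # ...and values
    k_full, v_full = torch.cat(k_parts, dim=2), torch.cat(v_parts, dim=2)
    return F.scaled_dot_product_attention(q, k_full, v_full)  # local queries, full K/V

# e.g. cp_group = init_device_mesh("cuda", (16, 8), mesh_dim_names=("dp", "cp")).get_group("cp")
```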
CP has a pretty interesting interaction with FSDP. Both DP and CP shard the input data (on batch and sequence, respectively). It's pretty common when you do FSDP to just shard over both "dp" ("dp_shard" in HSDP) and "cp". In torchtitan, we create a flattened mesh dim "dp_shard_cp" specifically for FSDP sharding (a flattened mesh dim is what happens if you take your mesh and "forget" about some of the structure; e.g., if you were to do an all-gather, you just all-gather over all of the flattened axes). In the HSDP world, "dp_cp" is still a useful concept, because this is the combination of axes you want to all-reduce over to, e.g., compute the global average loss.
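A sketch of how this is wired up, loosely following torchtitan; note that DeviceMesh._flatten is currently a private API, so treat the exact call as an assumption that may change.

```python
# A sketch of the flattened dims, loosely following torchtitan's wiring. Note that
# DeviceMesh._flatten is currently a private API; the exact call is an assumption.
# The 2 x 8 x 4 x 2 shape is made up.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda",
    (2, 8, 4, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
)

# FSDP shards parameters over the combined dp_shard x cp ranks...
dp_shard_cp = mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
# ...while things like the global average loss are all-reduced over dp x cp.
dp_cp = mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
```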
Pipeline parallelism (PP). Pipeline parallelism is kind of an ugly duckling, and people tend to hate on it because you have to rewrite your models to introduce pipeline stages, and you can't really use things like DTensor with it (unless you do really strange things like how the GSPMD paper "supports" pipeline parallelism--the general consensus is that automatic parallelism does not like PP). PP still goes in the device mesh, because it affects how you are organizing your GPUs, but, for example, torchtitan solely uses it to set up PGs for doing the point-to-point communications. I've seen both ["dp", "pp", ...] and ["pp", "dp", ...] for meshes with PP, but the order probably doesn't make too much of a difference, as you are likely solidly inter-node at this point. Pipeline parallelism's bandwidth use is very low, and latency can be covered up since you can immediately start processing the next batch after triggering an asynchronous send of the previous one.
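A hedged sketch of what the "pp" dim is actually used for: slicing out a process group and doing asynchronous point-to-point sends to the next stage. The mesh shape and the activation tensor are made up.

```python
# A hedged sketch of what the "pp" mesh dim buys you: a process group for asynchronous
# point-to-point communication between adjacent pipeline stages. Mesh shape and the
# activation tensor are made up.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (4, 4, 8), mesh_dim_names=("pp", "dp", "tp"))
pp_group = mesh.get_group("pp")
pp_rank, pp_size = mesh["pp"].get_local_rank(), mesh["pp"].size()

acts = torch.randn(8, 2048, 4096, device="cuda")      # activations leaving my stage
if pp_rank < pp_size - 1:
    nxt = dist.get_global_rank(pp_group, pp_rank + 1)  # global rank of the next stage
    work = dist.isend(acts, dst=nxt, group=pp_group)   # async send; go start the next batch
```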
Expert parallelism (EP). EP is its own kettle of fish. Expert parallelism only applies over the expert computation of the model, but within this region, we are not sharding parameters as FSDP conventionally sees it: we will commonly have the entire expert's weights on our node. torchtitan's WIP expert parallelism implementation, when it has ALL parallelisms on, would look like ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], where dp_shard has been split into two mesh dimensions (DP shard modulo EP, and DP shard in EP). dp_shard_mod_ep is conventionally one, but when it is not, it represents further FSDP-style sharding of expert weights inside of the expert region (there's some complication here if you have shared experts alongside your EP-sharded experts). But then dp_shard_in_ep, cp, and optionally tp are combined together to give you the expert parallel dimension. It's actually more intuitive to imagine that you have two distinct meshes: ["pp", "dp_replicate", "dp_shard", "cp", "tp"] and ["pp", "dp_shard_mod_ep", "ep", "tp"]. The keen-eyed may also notice that there is no intrinsic reason the tp mesh size inside and outside of the expert parallel region has to be the same, but having them differ is not easily done if you have to have a single global device mesh for everything. In fact, there is a WIP PR to have two meshes, one for inside the expert region and one for outside: https://github.com/pytorch/torchtitan/pull/1660
Conclusion. The general concept behind mesh parallelism is that you can compose parallelization strategies without too much fuss. Indeed, the appeal of, e.g., TP for scaling is precisely that it lets you cover your device space without having to expand DP beyond the batch size you want. However, as you can see from these concrete examples, it's not always quite as simple as just stacking all of the parallelisms on top of each other. In the end, all the device mesh is doing is creating PGs behind groups of devices as defined by the mesh, so if you want some weird setup where you're swapping between two device meshes, PyTorch's general philosophy has been to say: have fun!
Thanks to Horace He, Tianyu Liu and Natalia Gimelshein for helping fact check this post. Any remaining errors are mine!