ezyang’s blog

the arc of software bends towards understanding

Draw high dimensional tensors as a matrix of matrices

I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors in a way that makes it clear how to identify each of the dimensions in the representation. Common strategies I've seen people use in this situation include printing a giant list of 2D slices (what the default PyTorch printer will do) or flattening the tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy that I like that makes it easy to identify all the axes of the higher dimensional tensor: draw it as a matrix of matrices.

Here are some examples, including the easy up-to-2D cases for completeness.

0D: torch.arange(1).view(())

0

1D: torch.arange(2)

0  1

2D: torch.arange(4).view(2, 2)

0  1
2  3

3D: torch.arange(8).view(2, 2, 2)

0  1    4  5
2  3    6  7

4D: torch.arange(16).view(2, 2, 2, 2)

 0  1    4  5
 2  3    6  7

 8  9   12 13
10 11   14 15

5D: torch.arange(32).view(2, 2, 2, 2, 2):

 0  1    4  5  :  16 17   20 21
 2  3    6  7  :  18 19   22 23
               :
 8  9   12 13  :  24 25   28 29
10 11   14 15  :  26 27   30 31

The idea is that every time you add a new dimension, you alternate between stacking the lower dimension matrices horizontally and vertically. You always stack horizontally before stacking vertically, to follow the standard row-major convention for printing in the 2D case. Dimensions always proceed along the x and y axes, but the outer dimensions (smaller dim numbers) involve skipping over blocks. For example, a "row" on dim 3 in the 4D tensor is [0, 1] but the "row" on dim 1 is [0, 4] (we skip over to the next block.) The fractal nature of the construction means we can keep repeating the process for as many dimensions as we like.
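
If you want to mechanize the construction, here is a small sketch (mine, not something from the post) that maps an N-D coordinate to the (row, column) cell it occupies in the drawing, ignoring the whitespace between blocks:

def grid_position(coord, sizes):
    # Consume dimensions innermost-first, alternating between the column
    # axis and the row axis (the innermost dim lays out along columns).
    row, col = 0, 0
    row_scale, col_scale = 1, 1
    horizontal = True
    for i, s in zip(reversed(coord), reversed(sizes)):
        if horizontal:
            col += i * col_scale
            col_scale *= s
        else:
            row += i * row_scale
            row_scale *= s
        horizontal = not horizontal
    return row, col

# In the 4D example, element 13 has coords (1, 1, 0, 1):
print(grid_position((1, 1, 0, 1), (2, 2, 2, 2)))  # (2, 3): third row, fourth column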

In fact, for the special case when every size in the tensor is 2, the generated sequence of indices form a Morton curve. But I don't call it that, since I couldn't find a popular name for the variation of the Morton curve where the radix of each digit in the coordinate representation can vary.

Knowledge check. For the 4D tensor of size (2, 2, 2, 2) arranged in this way, draw the line(s) that would split the tensor into the pieces produced by torch.split(x, 1, dim), for each possible dimension 0, 1, 2 and 3. Answer under the fold.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

dim=0

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=0)]
[tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15])]

     0  1    4  5
     2  3    6  7
   ----------------
     8  9   12 13
    10 11   14 15


dim=1

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=1)]
[tensor([ 0, 1, 2, 3, 8, 9, 10, 11]), tensor([ 4, 5, 6, 7, 12, 13, 14, 15])]

     0  1 |  4  5
     2  3 |  6  7
          |
     8  9 | 12 13
    10 11 | 14 15

dim=2

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=2)]
[tensor([ 0, 1, 4, 5, 8, 9, 12, 13]), tensor([ 2, 3, 6, 7, 10, 11, 14, 15])]

     0  1    4  5
   ------- -------
     2  3    6  7

     8  9   12 13
   ------- -------
    10 11   14 15

dim=3

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=3)]
[tensor([ 0, 2, 4, 6, 8, 10, 12, 14]), tensor([ 1, 3, 5, 7, 9, 11, 13, 15])]

     0 |  1    4 |  5
     2 |  3    6 |  7

     8 |  9   12 | 13
    10 | 11   14 | 15
  • October 25, 2025

So you want to control flow in PT2

With contributions from Richard Zou.

PT2’s dominant internal representation, FX graphs, do not directly support control flow (if statements, while loops): they only represent straight-line basic blocks. Most of our graph capture mechanisms are tracing based (fx.symbolic_trace, make_fx, Dynamo), which means that we expect to be able to linearize all conditionals we encounter into a straight line program. Sometimes, you want to work with code that has control flow while working with the compiler stack. There is no silver bullet; instead, there are a lot of different options with different tradeoffs.

Regional compilation

We have a perfectly good general purpose language that supports control flow: Python. To handle control flow, compile only regions/submodules of your program that have no internal control flow, and then string them together with standard Python control flow constructs. PT2 compiled regions are compositional with non-compiled regions, so “it just works.”
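
Concretely, this looks something like the following sketch (Block is a hypothetical control-flow-free submodule; only it gets compiled, while the loop stays in ordinary Python):

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.compile(Block())  # compile only the region

    def forward(self, x, n_steps: int):
        for _ in range(n_steps):  # Python-level control flow, runs eagerly
            x = self.block(x)
        return x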

Pros:

  • Simple: requires no major model changes
  • Universal: it always works (including data dependent flow, calling into third-party libraries, making an HTTP request, anything!)

Cons:

  • You will not get a full graph this way; you will only get graphs for each region. In particular, you will not be able to do truly global optimizations, nor will you be able to serialize a self-contained Python-less representation of the entire model
  • It can sometimes be inconvenient to structure your program so all the regions you want are compilable. Suppose you have this call graph between modules: A -> B -> C. C is compileable; A is compileable except for its call to B, which is what does the control flow. It’s easy to compile C, but you can’t directly compile A, as it has a B-shaped bit that can’t be compiled. What to do? If you split A so it is pipelined as A1, B, A2, you can then compile A1 and A2, but not B. Dynamo also supports “graph breaks” to automatically perform this split for you, in which case you just disable compilation on B, but graph break generated graphs can be difficult to reason about as the inputs to A2 are implicitly inferred.

Link: Reducing torch.compile cold start compilation time with regional compilation

Multiple graphs dispatched with guards

When the control flow is controlled by arguments that are known ahead of time (not data-dependent), you can also compile at the top level and get the flattened straight-line program for the particular branching you took in this case. Because Dynamo is a symbolic bytecode interpreter, it can automatically determine what inputs were used as part of control flow, and generate guards to validate that we would take the same paths again. If those values change, we will recompile the program at the new values. We dispatch between all the different unrollings of the program we have generated.
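
As a hedged illustration (the names here are made up), a plain Python bool argument gets specialized and guarded on, so flipping it just produces a second graph:

import torch

@torch.compile
def f(x, use_residual):
    if use_residual:
        return x + torch.tanh(x)
    return torch.tanh(x)

x = torch.randn(8)
f(x, True)   # compiles the "residual" unrolling, guarded on use_residual == True
f(x, False)  # guard fails, so we recompile and dispatch to a second graph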

Pros:

  • Simple: requires no major model changes
  • You get a full graph for a particular unrolling of loops / conditionals, so global optimizations are possible

Cons:

  • Doesn’t work with data-dependent shapes.
  • You will end up with a graph for every unrolling; for example, if you have a loop that ranges from 1 to 32, you will end up with 32 different graphs. This will increase compile time.

Black box via custom operator

An FX graph just calls operators. An operator can internally have whatever control flow it wants. So you can always black box a problematic region of your model into an operator and preserve compilation for everything else.
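
Here is a rough sketch using the torch.library.custom_op API (the operator name mylib::branchy and its body are made up for illustration):

import torch

@torch.library.custom_op("mylib::branchy", mutates_args=())
def branchy(x: torch.Tensor) -> torch.Tensor:
    # Arbitrary data-dependent Python control flow is fine inside the operator.
    if x.abs().sum() > 1.0:
        return x * 2
    return x - 1

@branchy.register_fake
def _(x):
    # The compiler only needs output metadata; it never looks inside the op.
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return branchy(x).sin()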

Pros:

  • You get a single, full graph that works for all possible branches

Cons:

  • A custom operator only supports inputs/outputs that fall inside our type system, which means you can only pass simple types like Tensor, int, bool (or pytree-able containers containing these things). There is some in progress work to relax this to allow more opaque types.
  • You have to explicitly declare all the inputs/outputs for the custom operator. This can be tiresome if the black boxed region represents a Module, since all the parameters also have to be directly passed in as well. The larger the region you black box, the bigger the arguments are.
  • You don’t actually get to see the inside of the custom operator from the outside graph, so no optimization over both inside and outside of the custom operator is possible. (Of course, you can always special case this operator in a pass on the outer graph.)
  • There are some bugs related to doing another torch.compile region inside of a custom operator, although these are workaroundable: https://github.com/pytorch/pytorch/issues/151328

Conditional operators / Unroll to max iterations

Do you really, really need a conditional? If you're doing an if-branch, can you instead rewrite it so that you run both branches and use torch.where to select between the results? If you're doing a while-loop, can you unroll it to the max number of iterations and rely on dynamic shapes to make the extra iterations no-op once you're done? Basically, this option is to rewrite your model so it doesn't have Python-level control flow anymore (the conditional can be done either host-side or GPU-side).
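
For the if-branch case, the rewrite might look like this sketch (names are illustrative): a data-dependent Python branch becomes a straight-line torch.where.

import torch

# Before: the Python-level branch forces a graph break (or an error with fullgraph=True).
def f_branchy(x):
    if x.mean() > 0:
        return x.sin()
    return x.cos()

# After: compute both branches and select; no Python-level control flow remains.
def f_straightline(x):
    return torch.where(x.mean() > 0, x.sin(), x.cos())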

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model
  • For unrolling, if you are close to being CPU-dispatch bound, unrolling and running with zero size could push you over the brink (as zero size dispatches are still not free)
  • For conditional operators, unconditionally running both branches increases the compute you need to do, which can be bad if you are compute-bound.

Control flow HOP

torch has special structured control flow operators that avoid unrolling large loops or needing to execute both branches of a control flow statement. If you’re familiar with JAX, these are very similar to the JAX equivalents. They have specific constraints that allow them to be directly compilable by torch.compile. For example, torch.cond accepts two functions (a true_fn and a false_fn) for the two branches and requires that outputs of each function must have the same properties (e.g. shape, dtype).

So far, we have “higher-order” operators (HOPs) such as cond, while_loop, and scan.

These are relatively new, have been used in torch.export for inference, but have not been battle tested for training or performance.

The semantics of these control flow operators are as follows:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)

def while_loop(cond_fn, body_fn, carried_inputs):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val

def scan(combine_fn, init, xs, length=None):
    carry = init
    ys = []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)
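
As a hedged usage sketch (function names are illustrative), torch.cond under torch.compile looks like this:

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(fullgraph=True)
def f(x):
    # Both branches must return outputs with matching shape/dtype, and
    # neither branch may mutate its inputs.
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

f(torch.randn(4))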

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model.
  • The control flow HOPs are structured: they have specific constraints on the functions (true_fn, false_fn (cond) or body_fn (while_loop)) that can be passed to them. One such constraint is that these functions may not mutate any of their inputs. This may make rewrites difficult because you have to think about code in a “functional”, JAX-like way.
  • Still WIP and they have some quirks especially for training. For example, the backward pass of torch.scan currently requires re-computing the forward pass (instead of just saving intermediates from each iteration of scan).

CFG over FX graphs

If FX graphs give you basic blocks, you can use them as building blocks for a language that does support conditionals, stringing the blocks together with control-flow edges. In fact, Helion, a kernel DSL, does exactly this, as it is common to need to directly write data-dependent conditionals and loops when writing kernels (it otherwise uses all PyTorch API functions, similar to conventional FX graphs). To do this, you would need to write your own Python frontend that parses Python directly to generate the CFG. TorchScript also does this, but the TorchScript frontend is unmaintained and we don't recommend using it (and it also doesn't generate FX graphs by default.)

Pros:

  • You get a single graph that works for all possible branches
  • You are able to optimize inside and outside of control flow
  • In principle, you can write exactly the control flow you want

Cons:

  • You have to write the frontend; we don't have one ready for you (TorchScript is not it, your princess is in another castle)
  • If your language looks too much like Python and too general purpose, prepare to get on the endless treadmill of feature requests for adding “just one more Python feature” (can we have lists? dataclasses? etc etc) in the frontend (it is more tractable for Helion, as it’s not a general purpose language.)
  • September 5, 2025

The Parallelism Mesh Zoo

When training large scale LLMs, there is a large assortment of parallelization strategies which you can employ to scale your training runs to work on more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you've got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)

tl;dr

  • DP, FSDP: ["dp"]
  • HSDP: ["dp_replicate", "dp_shard"]
  • DP+TP, DP+TP+SP: ["dp", "tp"]
  • DP+UlyssesSP: ["dp", "sp"] (verl)
  • DP+CP: ["dp", "cp"]
  • DP+CP+TP: ["dp", "cp", "tp"]
  • PP+DP+...: ["pp", "dp", ...] (torchtitan), ["dp", "pp", ...] (Megatron)
  • PP+DP+CP+TP+EP: ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"] (torchtitan)

Prologue: Why device mesh? Before we jump into the zoo, why do we have multi-dimensional meshes in the first place? One intuition is that the dimensions of the device mesh are a reflection of the physical constraints of networking between GPUs (there's a reason why all of the scaling books talk extensively about how the networking for GPUs works; you can't reason about what parallelization strategy you should use without knowing about this!) Let's imagine you have 1024 NVIDIA GPUs. You don't want to treat these 1024 GPUs as an undifferentiated blob of GPUs. Physically, these GPUs are grouped into nodes of eight, which have much faster NVLink connections within the node; cross-node communication is done over a slower InfiniBand connection. Intuitively, you will want to do something different depending on if you're doing intra-node communication or inter-node communication.

The device mesh imposes structure on this collection of GPUs. A mesh is typically specified as a tensor size (e.g., (128, 8)) as well as string axis names ala named tensor (e.g., ["dp", "tp"]), and is simply an N-D tensor over a range of GPU indices (typically [0, 1, 2, 3, ...] for GPUs, and a mostly ascending but occasionally permuted sequence for TPUs). We typically think of 2D and 3D tensors as grids and cubes, but I find it is more helpful (especially in higher dimensions) to think of the device mesh as imposing some self-similar (fractal) structure on the GPUs. In the simplest 2D mesh that accounts for intra versus inter node communication, GPUs are first organized into nodes on the inner-most dimension, and then the nodes are collected together in the outer-most dimension to form the cluster. (The self-similar nature of the nodes is important because it tells us how communication occurs across the cluster: to communicate over the outer-most mesh dimension, all the GPU 0s on each node talk to each other, all the GPU 1s, etc.) This is only the very simplest mesh we can create, however; with more complicated parallelization strategies we may impose extra levels of structure, e.g., we may organize nodes into pods of two and four, or we might further divide the eight GPUs of a single node. In other words, the mesh tells us about which GPUs communicate to which other GPUs. This is important to know, because when I want to parallelize our model, I am making choices about how to shard tensors across my GPUs. The mesh tells me which GPUs have the other shards of my tensor; in other words, they are who I have to communicate with when I am doing a computation that requires information about the full tensor and cannot be done with the local shards only.
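
In PyTorch, setting up the simple two-level mesh described above looks roughly like this sketch (it assumes you are running under torchrun with 1024 ranks; the axis names are illustrative):

from torch.distributed.device_mesh import init_device_mesh

# 128 nodes x 8 GPUs per node, intra-node dimension innermost.
mesh = init_device_mesh("cuda", (128, 8), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh["tp"]  # the 8 GPUs within this rank's node
dp_mesh = mesh["dp"]  # the 128 GPUs across nodes that share this rank's local index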

In the zoo, when we talk about a parallelism strategy, we will talk about how it typically relates to other parallelization strategies in the model, and the device mesh will tell us if it is orthogonal to other parallelisms (a new dimension), multiplexed with another strategy (a reused dimension) or perhaps a completely different hierarchy of communication (multiple meshes in the same model that don't factor into each other).

Without further ado, here is the zoo!

Data parallelism (DP). Data parallelism predates the concept of device meshes, since you don't actually need any nontrivial mesh structure to do data parallelism: if you are only doing data parallel, you just shard your input on the batch axis for however many devices you have. This sharding propagates through forwards and backwards until you allreduce to compute the final global gradient for a parameter. If you did make a 1D device mesh (this is useful to think about, because most higher dimensional parallelisms will include some form of data parallelism), you'd probably name your mesh ["dp"], ["ddp"] or perhaps ["batch"].

Let's talk briefly about how people tend to name device mesh axes. In the PyTorch world, it's most common to name the axis after the parallelism that it is responsible for, so either "dp" or "ddp" (you really shouldn't call it ddp, but the DataParallel taboo in PyTorch is very real!) The batch name is common in JAX, and is very natural there because when you annotate the sharding of your input, you need to say for each tensor dimension what mesh dim it is sharded over. So when you shard the batch dimension over the batch mesh dim, it looks just like you're labeling the batch dimension of your tensor as batch, e.g., P("batch", None). (This situation doesn't happen in PyTorch because shardings of a tensor are specified per device mesh dim, but that's a story for another day!)

Fully-sharded data parallel (FSDP). This is best understood as an augmentation over DP where weights are also sharded over all GPUs and you just all-gather weights before performing operations (and reduce-scatter gradients in backwards). Because this all-gather is also among all devices, you don't need another axis in your mesh, and your mesh might also be called ["dp"] in this case, even though you're actually doing FSDP. Occasionally, you'll see people name their mesh ["fsdp"] in this case.

Hybrid sharded data parallel (HSDP). HSDP is an extension of FSDP where you shard weights (FSDP) up to the point where you can't actually do a giant all-gather/reduce-scatter over every GPU, and then replicate these shards to cover the rest of your cluster (DP). It's also amenable to fault tolerance techniques that make the modeling assumption that it's OK to lose samples of your batch if a replica fails (you won't model this with device mesh though!). This is probably the first time you will encounter a 2D device mesh (indeed, the DeviceMesh tutorial in PyTorch specifically uses hybrid sharding as its motivating example), since HSDP doesn't require any extra model changes on top of FSDP. There are a few common ways to name the mesh axes for HSDP. One way to think about it is that it is FSDP on the inner dimension and DP on the outer dimension, in which case you would say ["dp", "fsdp"]. Another way is to think about what happens to parameters at the various layers of the mesh: the inner dimension shards, while the outer dimension replicates, so you would say ["replicate", "shard"] or perhaps ["dp_replicate", "dp_shard"] to make it clear that you are still doing data parallelism across both of these device mesh dims (in particular, when you split your batches, you split on both the dp_replicate and dp_shard dims--although, to get the final gradients, you can do the reduction hierarchically by first doing a reduce-scatter on "dp_shard" and then doing an allreduce on "dp_replicate").
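
A minimal sketch of such a mesh (the group sizes here are made up):

from torch.distributed.device_mesh import init_device_mesh

# Shard parameters within groups of 64 GPUs, replicate across 16 such groups.
mesh = init_device_mesh("cuda", (16, 64), mesh_dim_names=("dp_replicate", "dp_shard"))
# Gradients reduce-scatter over "dp_shard", then all-reduce over "dp_replicate".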

Tensor parallelism (TP). Depending on who you ask, tensor parallelism is either about letting you reduce your effective batch size for training or moving you towards reducing the memory usage of activations in your model. In the "reduce effective batch size" framing, the idea behind TP is that you can only scale up DP until your cluster is as large as your batch size. From a modeling perspective, it can be undesirable to have a batch size that is too large, so you can't just keep increasing your batch size to get more parallelism. Instead, TP allows us to get some extra scaling by sharding over the feature dimension of our matrix multiplies [1] (you can shard over either the columns or the rows of your weight matrix, so we will frequently specify if a TP Linear is column-wise or row-wise; in attention, column-wise linear effectively parallelizes the attention computation over attention heads). The communication needed to do TP is fairly exposed (unless you're doing async tensor parallel), so you typically want to keep the communications for it within a single node. This leads to this classic 2D device mesh for DP+TP: ["dp", "tp"] (or, if you're a JAXer, you might write ["batch", "model"], where model is used to indicate the inner feature dimension of the model weights being parallelized over.) When someone says 2D parallelism, they're usually referring to this combo of parallelisms (although I do not recommend using this term--as you can see, it is obviously ambiguous!) Note that tp is the inner mesh dimension, since it benefits the most from the high bandwidth network between GPUs on a single node.

You don't have to stop with DP+TP, however. If you're using FSDP with tensor parallelism (remember, "dp" can mean FSDP!), intra-node TP doesn't improve the amount of inter-node FSDP communication you have to do: however much TP you do, within one TP node you only have one slice of the model and have to talk to everyone else to get their slices. You could solve this by expanding TP to also cross nodes, but in practice mixed intra/inter-node collectives are a lot slower than pure inter-node collectives. This limits the scaling you can get from TP, and so if you're still hitting limits on FSDP, it can still be useful to apply HSDP to avoid running collectives that are too large. In that case, you'd end up with a mesh like ["dp_replicate", "dp_shard", "tp"].

Sequence parallelism (SP). For this section, we specifically take the definition of sequence parallelism from the Ultrascale Playbook (as distinguished from context parallelism). Although we said that TP is the first step towards reducing the memory usage of activations [2], if you literally implement DP+TP based on my descriptions above, you will still end up with more memory spent on activations than you want, because there are still parts of the model around the FFN, like the LayerNorm, that need the full hidden dimension to compute mean and variance [3]. To reduce the memory usage in these segments, you need to shard on something else. So typically what you will see is that the model will alternate between TP (hidden dimension is sharded) and SP (sequence dimension is sharded). Consequently, if you look at the device mesh for a model using DP+TP+SP, it will typically still look like ["dp", "tp"], and instead the tp dimension is multiplexed to be used both for TP and SP. Because TP and SP never occur at the same time, you don't need a separate dimension for them.

Ulysses sequence parallelism. Ulysses sequence parallelism from DeepSpeed Ulysses is another sequence parallelism strategy that is implemented by verl (because verl is forked so often, it shows up quite prominently if you are looking for examples of init_device_mesh on GitHub code search). It aims to alleviate memory pressure from extremely long sequences, so sequences are sharded on input, and only when attention needs to be computed is an alltoall issued to re-shard on the attention heads rather than the sequence (doing another alltoall to restore the sequence sharding after the attention is done). Importantly, this means it competes with TP for sharding on the attention heads, which is why you also see people use it to replace TP in MoE models, since it has much less communication than TP (at the cost of having to replicate the attention weights). In verl, you will just see a device mesh ["dp", "sp"] when you are using their FSDP backend (which is what supports Ulysses).

Context parallelism (CP). Context parallelism is another form of "sequence" parallelism. Like Ulysses sequence parallelism, sequences are sharded on input; the difference, however, is instead of using an alltoall to re-shard on attention heads, you just do a (distributed) attention on the entire context. You can do this the easy way by just using allgather to get the full context (as was done in llama4) or you can use a fancy kernel like ring attention, which carefully overlaps communication and computation when performing attention. A popular implementation of context parallelism lives in Megatron, which doesn't directly use PyTorch's native DeviceMesh abstraction but has an analogous HyperCommGrid. The mesh we see here will be something like ["dp", "cp"] or more commonly ["dp", "cp", "tp"]. Notice that we can have a dedicated mesh dim for CP: CP operates very similarly to SP outside of the attention calls (as it is just plain data parallelism when there is no cross-token dependency), but because it never shards on attention heads, it doesn't compete with TP and can be used completely orthogonally to TP (TP shards hidden, CP shards sequence).

CP has a pretty interesting interaction with FSDP. Both DP and CP shard the input data (on batch and sequence respectively). It's pretty common when you do FSDP to just shard over both "dp" ("dp_shard" in HSDP) and "cp". In torchtitan, we create a flattened mesh dim "dp_shard_cp" specifically for FSDP sharding (a flattened mesh dim is what happens if you take your mesh and "forget" about some of the structure; e.g., if you were to do an all-gather, you just all-gather over all the flattened axes). In the HSDP world, "dp_cp" is still a useful concept because this is the combination of axes you want to all-reduce over to compute, e.g., the global average loss.

Pipeline parallelism (PP). Pipeline parallelism is kind of an ugly duckling and people tend to hate on it because you have to rewrite your models to introduce pipeline stages, and you can't really use things like DTensor with it (unless you do really strange things like how the GSPMD paper "supports" pipeline parallelism--the general consensus is automatic parallelism does not like PP). PP still goes in the device mesh, because it affects how you are organizing your GPUs, but, for example, torchtitan solely uses it to set up PGs for doing the point-to-point communications. I've seen both ["dp", "pp", ...] or ["pp", "dp", ...] for meshes with PP, but the order probably doesn't make too much of a difference as you are likely solidly inter-node at this point. Pipeline parallelism bandwidth use is very low, and latency can be covered up as you can immediately start processing the next batch after triggering an asynchronous send of the previous batch.

Expert parallelism (EP). EP is its own kettle of fish. Expert parallelism only applies over the expert computation of the model, but within this region, we are not sharding parameters as FSDP conventionally sees it: we will commonly have the entire expert's weights on our node. torchtitan's WIP expert parallelism implementation, when it has ALL parallelisms on, would look like ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], where dp_shard has been split into two mesh dimensions (DP shard modulo EP, and DP shard in EP). dp_shard_mod_ep is conventionally one, but when it is not it represents further FSDP-style sharding of expert weights inside of the expert region (there's some complication here if you have shared experts alongside your EP-sharded experts). But then dp_shard_in_ep, cp and optionally tp are combined together to give you the expert parallel dimension. It's actually more intuitive to imagine that you have two distinct meshes: ["pp", "dp_replicate", "dp_shard", "cp", "tp"] and ["pp", "dp_shard_mod_ep", "ep", "tp"]. The keen-eyed may also notice that there is no intrinsic reason the tp mesh size inside and outside of the expert parallel region has to be the same, but varying it is not easily done if you have to have a single global device mesh for everything. In fact, there is a WIP PR to have two meshes, one for inside the expert region and one for outside: https://github.com/pytorch/torchtitan/pull/1660

Conclusion. The general concept behind mesh parallelism is that you can compose parallelization strategies without too much fuss. Indeed, the use of, e.g., TP to improve scaling is precisely because it lets you cover your device space without having to expand DP beyond the batch size you want. However, as you can see from these concrete examples, it's not always quite as simple as just stacking all of the parallelisms together one on top of each other. In the end, all the device mesh is doing is creating PGs (process groups) for the groups of devices defined by the mesh, so if you want some weird setup where you're swapping between two device meshes, PyTorch's general philosophy has been to say, have fun!

Thanks to Horace He, Tianyu Liu and Natalia Gimelshein for helping fact check this post. Any remaining errors are mine!

[1]One more subtlety I want to point out: while we tend to think of TP as sharding the feature dimension of parameters, when we "propagate" this sharding through the network, other intermediate tensors end up getting sharded on the TP dimension as well. In particular, in a transformer block, you will typically have a column-wise linear followed by a row-wise linear, and the intermediate activation will be temporarily sharded on the TP dimension before the row-wise linear runs.
[2]I am very carefully using "activation memory" here and not total memory, because total memory usage (what you actually care about) is also a function of peak memory usage, which is subject to transient peaks such as when FSDP does an all-gather to collect parameters. In fact, even without SP, TP will improve your peak memory usage, because unlike FSDP, it's not necessary to all-gather the full weight matrix to actually perform the matrix multiply. TP's peak memory usage occurs when it all-gathers activations.
[3]You will get a little improvement between the column-wise and row-wise linear, since the activations there are sharded. You can turn this into a big improvement by using selective activation checkpointing and forcing recomputation of activations that aren't sharded! (Plain activation checkpointing tends not to work so well because of the all-gather of the activations.)
  • August 30, 2025

You could have invented CuTe hierarchical layout (but maybe not the rest of it?)

CuTe is a C++ library that aims to make dealing with complicated indexing easier. A key part of how it does this is by defining a Layout type, which specifies how to map from logical coordinates to physical locations (CuTe likes to say layouts are "functions from integers to integers.") In fact, CuTe layouts are a generalization of PyTorch strides, which say you always do this mapping by multiplying each coordinate with its respective stride and summing them together, e.g., i0 * s0 + i1 * s1 + .... Although NVIDIA's docs don't spell it out, CuTe's generalization here is actually very natural, and in this blog post I'd like to explain how you could have invented it (on a good day).

First, a brief recap about strides. PyTorch views allow us to reinterpret the physical layout of a tensor in different ways, changing how we map logical coordinates into physical locations. For example, consider this 2-D tensor:

>>> torch.arange(4).view(2, 2)
tensor([[0, 1],
        [2, 3]])
>>> torch.arange(4).view(2, 2).stride()
(2, 1)

The physical memory reads 0, 1, 2, 3, and if I want to know what the value at coordinate (0, 1) is (row 0, col 1), I compute 0 * 2 + 1 * 1, which tells me I should read out the value at index 1 in physical memory. If I change the strides, I can change the order I read out the physical locations. For example, if I transpose I have:

>>> torch.arange(4).view(2, 2).T
tensor([[0, 2],
        [1, 3]])
>>> torch.arange(4).view(2, 2).T.stride()
(1, 2)

The physical memory hasn't changed, but now when we read out coordinate (0, 1), we compute 0 * 1 + 1 * 2, which tells me I should read the value at index 2 (which is indeed what I see at this coordinate!)

PyTorch also allows us to "flatten" dimensions of a tensor, treating them as a 1D tensor. Intuitively, a 2-D tensor flattened into a 1-D one involves just concatenating all the rows together into one line:

>>> torch.arange(4).view(2, 2).view(-1)
tensor([0, 1, 2, 3])

We should be able to do this for the transpose too, getting tensor([0, 2, 1, 3]), but instead, this is what you get:

>>> torch.arange(4).view(2, 2).T.view(-1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

The dreaded "use reshape instead" error! The error is unavoidable under PyTorch striding: there is no stride we can select that will cause us to read the elements in this order (0, 2, 1, 3); after all, i0 * s0 is a pretty simple equation, we can't simultaneously have 1 * s0 == 2 and 2 * s0 == 1.

Upon learning this, an understandable reaction is to just shrug, assume that this is impossible to fix, and move on with your life. But today, you are especially annoyed by this problem, because you were only trying to flatten N batch dimensions into a single batch dimension so that you could pass it through a function that only works with one batch dimension, with the plan of unflattening it when you're done. It doesn't matter that this particular layout is inexpressible with strides; you aren't going to rely on the layout in any nontrivial way, you just care that you can flatten and then unflatten back to the original layout.

Imagine we're dealing with a tensor of size (2, 2, 2) where the strides for dim 0 and dim 1 were transposed as (2, 4, 1). It should be OK to flatten this into a tensor (4, 2) and then unflatten it back to (2, 2, 2). Intuitively, I'd like to "remember" what the original sizes and strides are, so that I can go back to them. Here's an idea: let's just store the original size/stride as a nested entry in our size tuple. So instead of the size (4, 2), we have ((2, 2), 2); and now analogously the stride can simply be ((2, 4), 1). When I write (2, 2) as the "size" of a dimension, I really just mean the product 4, but there is some internal structure that affects how I should index its inside, namely, the strides (2, 4). If I ask for the row at index 2, I first have to translate this 1D coordinate into a 2D coordinate (1, 0), and then apply the strides to it like before.
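
Here is a toy sketch of that idea in plain Python (mine, not CuTe; unlike CuTe, it splits flat indices row-major/lexicographically, which matches the worked example above):

def offset(coords, sizes, strides):
    off = 0
    for i, size, stride in zip(coords, sizes, strides):
        if isinstance(size, tuple):
            # Translate the flat index into nested coordinates (row-major),
            # then recurse with the nested strides.
            sub = []
            for s in reversed(size):
                sub.append(i % s)
                i //= s
            off += offset(tuple(reversed(sub)), size, stride)
        else:
            off += i * stride
    return off

# size ((2, 2), 2) with stride ((2, 4), 1): row index 2 becomes the nested
# coordinate (1, 0), for a physical offset of 1*2 + 0*4 + 0*1 = 2.
print(offset((2, 0), ((2, 2), 2), ((2, 4), 1)))  # 2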

Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if it's an N-D object.) The documentation of Layout does say this... but I actually suffered a lot extracting out the high level intuition of this blog post, because CuTe uses co-lexicographic ordering when linearizing (it iterates over coordinates (0,0), (1,0), (2,0), etc. rather than in the more normal lexicographic order (0,0), (0,1), (0,2)). This leads to some truly deranged example code where they print a 2D matrix in conventional lexicographic ordering, and then turn around and say, "But wait, if I have the layout take care of translating the 1D coordinate into an ND coordinate, it is colexicographic!!":

> print2D(s2xh4)
  0    2    1    3
  4    6    5    7
# sure, why not?

> print1D(s2xh4)
  0    4    2    6    1    5    3    7
# wtf???

In any case, if you want to engage with the documentation, s2xh4 is the important example to pay attention to for understanding the nested semantics. However, note the example is smeared across like five sections and also you need to know about the co-lexicographic thing to understand why the examples print the way they do.

  • August 22, 2025

State of torch.compile for training (August 2025)

The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. Nothing in here is something you couldn't already learn from elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large scale training runs.

First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs for both inference and training workloads. Speedups from 1.5-2x compared to eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).

What is torch.compile's functionality?

The headline functionality of torch.compile is a decorator you can attach to a function to compile it:

@torch.compile()
def f(x, y):
    ...

Here are some non-functional properties of compile which are important to know:

  • Just-in-time compilation. We don't actually compile the function until it is called for the first time, and execution blocks until compilation completes. There is both local and remote caching to skip compilation cost when you rerun the model. (Ahead-of-time compilation is possible for inference with AOTInductor, and is being worked on for training.)
  • Compositional with Eager. PyTorch's original success comes from the extreme hackability of eager mode, and torch.compile seeks to preserve this. The compiled function can be as big or as small a part of your training loop as you like; compiled functions compose with autograd, DDP, FSDP and other PyTorch subsystems. (This composition is sometimes imperfect, e.g., in the case of double backwards (not supported), tensor subclasses (requires specific support from the subclass), autograd (differentiating with respect to intermediates returned from a compiled region does not work).) If compilation doesn't work on a region, you can disable it entirely with torch.compiler.disable() and fall back to eager.
  • Gradient updates are delayed to the end of compiled regions. This arises because PyTorch eager autograd does not support streaming gradients incrementally from a large backward node. (This can be solved by using compiled autograd, but this requires that the entirety of your backwards be compileable.)
  • Graphs may be recompiled. We aggressively specialize on all non-Tensor arguments/globals used in the function to ensure we always generate straight-line computation graphs with no control flow. If those arguments/globals change we will recompile the graph. (Recompilations can be banned with torch._dynamo.config.error_on_recompile = True.)
  • Static by default, recompile to dynamic shapes. We aggressively specialize all sizes to static. However, if we discover that a size varies over time, on the first recompile we will attempt to generate a single compiled region that handles dynamic shapes. We are not guaranteed to be able to compile a model with dynamic shapes. (You can use mark_dynamic to force an input shape to be dynamic, and you can use mark_unbacked to error if we specialize; see the sketch after this list.)
  • Graph breaks transparently bypass non-capturable code. By default, if the compiler encounters a line of code that it is not able to handle, it will trigger a graph break, disabling compilation for that line of code, but still attempting to compile regions before and after it. (This behavior can be banned with fullgraph=True.)
  • Function calls are inlined and loops are unrolled by default. If you have many copies of a Transformer block in your model, your compile time will scale with the number of Transformer blocks. (You can reduce compile time by doing "regional compilation", where you only compile the Transformer block instead of compiling the entire model.)
  • NOT bitwise equivalent with eager PyTorch. The biggest divergence with eager PyTorch is that when float16/bfloat16 operations are fused together, we do not insert redundant down/up-conversions. (This can be disabled with torch._inductor.config.emulate_precision_casts = True; you can also rewrite eager code to perform operations in higher precision with the understanding that torch.compile will optimize it. XLA has a similar config xla_allow_excess_precision which JAX enables by default.) However, we may also make decisions to swap out, e.g., matmul implementations, and there may also be slight divergences that arise from differences in reduction ordering that are unavoidable when compilation occurs. We support ablating the graph capture frontend separately from the compiler backend to help diagnose these kinds of problems.
  • Distributed collectives and DTensor can be compiled, but are unoptimized by default. We are able to capture c10d collectives and also programs that handle DTensors, but we don't apply optimizations to collectives by default. (There are experimental optimizations that can be enabled, but this is active work in progress.) We generally do not expect to be able to trace through highly optimized distributed framework code.
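
Here is a hedged sketch of the static-by-default, recompile-to-dynamic behavior (and mark_dynamic) mentioned above; shapes are illustrative:

import torch

@torch.compile
def f(x):
    return x * 2

f(torch.randn(4))    # compiles a graph specialized to size 4
f(torch.randn(8))    # size changed: recompiles, this time with a dynamic size
f(torch.randn(16))   # the dynamic-shape graph is reused, no further recompile

x = torch.randn(32)
torch._dynamo.mark_dynamic(x, 0)  # force dim 0 to be treated as dynamic up front
f(x)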

State of advanced parallelism

For large scale training runs, torch.compile faces stiff competition from (1) PyTorch native distributed frameworks which embrace eager mode and implement all optimizations by hand (e.g., megatron), (2) custom "compiler" stacks which reuse our tracing mechanisms (e.g., symbolic_trace and make_fx) but implement their desired passes by hand, (3) JAX, which has always been XLA first and is years ahead in compile-driven parallelism techniques.

Here is where we currently are for advanced parallelism (with an emphasis on comparing with JAX):

  • DTensor, a "global tensor" abstraction for representing sharded tensors. DTensor is a tensor subclass which allows us to represent tensors which are sharded over an SPMD device mesh. The shape of a DTensor reflects the global shape of the original full tensor, but it only stores locally a shard of the data according to the placement. Here are some important details:
    • Shard placements. Unlike JAX placements, DTensor placements are "device mesh" oriented; that is to say, you conventionally specify a list of placements, one per device mesh dim, and Shard(i) indicates that the ith dimension of a tensor is sharded. This is the opposite of JAX, which is "tensor" oriented. For example, given a 2-D mesh ["dp", "tp"], a tensor with [Replicate, Shard(0)] in DTensor placement (or {"dp": Replicate, "tp": Shard(0)} with named device mesh axes), would correspond to a JAX placement of P("tp", None) (see the sketch after this list). The reason for this is that DTensor supports a Partial placement, which indicates that an axis on the device mesh has a pending reduction. Partial shows up ubiquitously from matrix multiplies, and it isn't associated with any particular tensor axis, making it more convenient to represent in a device-mesh oriented formulation. The tradeoff is that device-mesh oriented placements don't naively support specifying sharding ordering, e.g., suppose I want to shard a 1-D tensor on tp and then dp, in JAX I'd represent this as P(("tp", "dp"),) but this order cannot be disambiguated from [Shard(0), Shard(0)] and in fact DTensor always forces left-to-right sharding. There is currently a proposal to extend our sharding specification to support ordering to bring us to parity with JAX expressiveness, but it is not yet implemented.
    • Autograd. DTensor is directly differentiable; we run autograd on programs that have DTensors (as opposed to desugaring a DTensor program to one with regular Tensors and differentiating it). This ensures that the sharding strategy of a primal and its corresponding tangent can diverge. This is parity with JAX.
    • Python subclass of Tensor. Unlike JAX, DTensor is a separate subclass from Tensor. However, Tensor and DTensor interoperate fine; a Tensor can simply be thought of as a DTensor that is replicated on all dimensions. DTensor is implemented in Python, which makes it easy to modify and debug but imposes quite a bit of overhead (for example, FSDP2 does not directly accumulate gradients into DTensor, because with thousands of parameters, performing detach and add operations on DTensor is a bottleneck). Still, despite this overhead, DTensor was designed for good eager performance, and extensively caches the results of sharding propagation so that in the fastpath, it only needs to lookup what redistribute it should perform and then directly dispatches to the local eager operation. However, this caching strategy means that overhead can be quite high for workloads with dynamic shapes, as the cache requires exact matches of all input shapes.
    • Compilation. DTensor is compilable by torch.compile, and doing so will desugar it into its underlying collectives and eliminate any eager mode DTensor overhead (even if you do not perform any other optimizations.) However, DTensor with dynamic shapes in compile is not well supported, see http://github.com/pytorch/pytorch/issues/159635 (we don't think this is currently critical path for any critical use cases, so a relatively junior engineer has been chipping away at it.)
    • Greedy propagation. Because DTensor must work in eager mode, it only implements greedy shard propagation, where for every eager operation we greedily pick whatever output shard minimizes the collective costs of an operation. It is work in progress to support backward propagation of sharding with the assistance of a compiler-like framework.
    • Operator coverage. DTensor requires sharding propagation rules to work for operations. If a sharding propagation rule is not implemented, DTensor will fail rather than trigger an inefficient allgather to run the operator under replication. We don't currently have full coverage of all operators, but important operators for transformer models like llama3 are all covered (sharding rules are defined here). You can write custom shardings for user defined operators.
    • Jagged sharding. We do not support a "jagged sharding" concept which would be necessary for expert parallelism with imbalanced routing. However, we believe that our existing sharding rules could largely be reused to support such an idea. As dynamism would only be exposed in the local tensor for the jagged shard, jagged shards don't suffer from the dynamic shapes problems mentioned in the compilation section.
    • Ecosystem. We are committed to DTensor as the standard representation for sharded tensors, and DTensor is integrated with checkpointing, FSDP2, SimpleFSDP, AutoParallel, torchtitan, among others.
  • Functional collectives. If you don't like DTensor, we also support "functional collectives", which are non-mutating versions of collective operations that can be used to manually implement SPMD operations in a compiler-friendly way without needing DTensor. (In fact, if you use traditional collective APIs and compile them, we will silently translate them into functional collectives for compiler passes.) When compiled, functional collectives don't necessarily force allocation of the output buffer as they can be re-inplaced. Importantly, functional collectives currently do NOT support autograd, see https://discuss.pytorch.org/t/supporting-autograd-for-collectives/219430

  • Graph capture. There are two particularly popular graph capture mechanisms which people have used to perform distributed optimizations separately from model code. All graph capture mechanisms produce FX graphs, which are a simple Python basic block IR with no control flow that is entirely unopinionated about what actual operator set can occur in the graph.
    • Symbolic_trace. This was the original graph capture mechanism and is quite popular, despite its limitations. It is implemented entirely with Python operator overloading and will give you exactly whatever operations are overloadable in the graph. We consider this largely a legacy pipeline as you are unable to trace code involving conditionals on shapes and you end up with a graph that has no useful metadata about the shapes/dtypes of intermediate values. For example, PiPPY, a legacy stack for performing pipeline parallelism, was built on top of symbolic_trace graph capture.
    • make_fx/torch.export. This graph capture mechanism works by actually sending (fake) tensors through your program and recording ATen operators. There are a number of different variants: e.g., whether or not it is a Python tracing approach ala JAX jit, or whether it uses sophisticated bytecode analysis ala Dynamo; similarly, there are various levels of IR you can extract (pre-dispatch, post-dispatch; also, operators can be decomposed or kept as single units). Our compiler parallelism efforts are built on top of this capture mechanism, but there is nothing stopping you per se from writing your own graph pass on top of this IR. In practice, this can be difficult without PyTorch expertise, because (1) integrating a traced graph into PyTorch's autograd system so it can interoperate with other code is quite complicated to do in full generality, (2) the exact operator sets you get at various phases of compilation are undocumented and in practice very tied to the Inductor lowering stack, and it is poorly documented on how to prevent operators from getting decomposed before your pass gets to them.
  • Not SPMD compiler by default. torch.compile does not assume the program being compiled is SPMD by default, which means it will not do things like drop unused collectives (you can change this behavior with a config flag). Additionally, the default mode of use for torch.compile is to compile in parallel on all nodes, which means care has to be taken to ensure that every instance of the compiler compiles identically (only one rank recompiling, or compilers making different decisions, can lead to NCCL timeout). We ultimately think that we should compile a program once and send it to all nodes, but as this is not currently implemented, the general approach people have taken to solve this problem is to either (1) eliminate all sources of divergent behavior from ranks, e.g., don't allow the compiler to look at the actual size for dynamic inputs when making compiler decisions, or (2) introduce extra collectives into the compiler to communicate decisions that must be made consistently across all ranks.
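
To make the placement convention concrete, here is a hedged sketch of creating a DTensor with device-mesh-oriented placements (run under torchrun; the mesh and tensor sizes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
w = torch.randn(1024, 1024)
# One placement per mesh dim: replicate over "dp", shard tensor dim 0 over "tp"
# (the JAX analogue of P("tp", None) from the bullet above).
dw = distribute_tensor(w, mesh, placements=[Replicate(), Shard(0)])
print(dw.shape)             # global shape: torch.Size([1024, 1024])
print(dw.to_local().shape)  # this rank's local shard: torch.Size([512, 1024])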

Our vision for the future of advanced parallelism, spearheaded by the in-progress SimpleFSDP and AutoParallel, is that users should write single-node programs that express mathematically what they want to do. These are then transformed into efficient distributed programs in two steps: (1) first, collectives are inserted into the graph in a naive way (i.e., simply to express what the sharding of all intermediates should be), and (2) the collectives are optimized to handle scheduling concerns such as pre-fetching and bucketing. AutoParallel sets a GSPMD style goal of automatically determining a good enough sharding for a program--it should be able to rediscover data parallel, tensor parallel, even expert parallel(!)--but SimpleFSDP sets a smaller goal of just inserting collectives in the pattern that FSDP would mandate, and then writing FSDP-specific optimization passes for recovering FSDP2's performance. It is very common to write domain specific optimizations; for example, async tensor parallelism is also implemented as a pass that detects TP patterns and rewrites them to async TP operations. Unlike JAX, which started with a very generic solver and has needed to add more manual escape hatches over time, PyTorch has started with writing all of the distributed patterns exactly by hand, and we are only recently adding more automatic mechanisms as an alternative to doing everything by hand.

State of optimization

torch.compile performs many optimizations, but here are some particularly important ones to know about:

  • Inductor. Inductor is our backend for torch.compile that generates Triton kernels for PyTorch programs. It has very good coverage of PyTorch's operator set and can do fusions of pointwise and reductions, including in the patterns that typically occur for backwards. It also is able to fuse pointwise operations into matmuls and autotune different matmul backends (including cuBlas, cutlass and Triton) to select the best one for any given size. When people talk about torch.compile speeding up their programs, they are conventionally talking about Inductor; however, you don't have to use torch.compile with Inductor; for example, you could run with AOTAutograd only and skip Inductor compilation.
  • CUDA graphs. Inductor builds in support for CUDA graphing models. We can give better soundness guarantees than manual CUDA graphs application (e.g., we catch issues like forgetting to copy in all input buffers, or CPU compute inside the CUDA graph region). torch.compile CUDA graphs is typically used with Inductor but we also offer an eager-only cudagraphs integration (that is less well exercised).
  • Automatic activation checkpointing. With torch.compile, we can globally optimize the memory-compute tradeoff, much better than the activation checkpointing APIs that eager PyTorch supports (which require the user to manually specify what they want checkpointed or not). However, some folks have reported that it can be quite miserable tuning the hyperparameter for AC; we have also found bugs in it.
  • FP8 optimizations. One big success story for traditional compilation was adding support for a custom FP8 flavor. With torch.compile, they didn't have to write manual kernels for their variant. This has since been upstreamed to torchao.
  • Flex attention. Flex attention usage continues to grow, with 632 downstream repo users in OSS (vs 125 in Jan '25). It has been used to enable chunked attention, document masking and context parallelism in llama family models. It is a really good research tool, although sometimes people complain about slight numerical differences.
  • Helion. Helion is an actively developed project aiming to go beta in October this year which offers a higher level interface for programming Triton kernels that looks just like writing PyTorch eager code. It relies heavily on autotuning to explore the space of possible structural choices of kernels to find the best one. It is not production ready but it is worth knowing that it is coming soon.

State of compile time

torch.compile is a just-in-time compiler and as such, in its default configuration, compilation will occur on your GPU cluster (preventing you from using the GPUs to do other useful work!) In general, most pathological compile times arise from repeated recompilation (often due to dynamic shapes, but sometimes not). In Transformer models, compile time can also be improved by only compiling the Transformer block (which can then be compiled only once, instead of having to be compiled N times for each Transformer block in the model).

We don't think caching is an ideal long-term solution for large scale training runs, and we have been working on precompile to solve the gap here. Precompile simply means having compilation be an ahead-of-time process which produces a binary which you can directly run from your training script to get the compiled model. The compilation products are built on top of our ABI stable interface (developed for AOTInductor) which allows the same binaries to target multiple PyTorch versions, even though PyTorch the library does not offer ABI compatibility from version to version.

How do I get started?

The most typical pattern we see for people who want to make use of torch.compile for large-scale training is to fork torchtitan and use this codebase as the basis for your training stack. torchtitan showcases PyTorch native functionality, including torch.compile--in effect, it shows you how to use features in PyTorch together in a way that lets you do large-scale training. From there, swap out the components you are opinionated about and keep the things you don't care about.

  • August 13, 2025

Vibe coding case study: ScubaDuck

A lot of strong engineers that I know haven't really taken a serious look at AI coding; they've used LLMs to ask questions or write simple scripts and appreciate that it is a useful tool, but haven't actually tried building a nontrivial application entirely from scratch in vibe coding style (here, I use the term in its original meaning: when you do AI coding without carefully reviewing the output). This is understandable: if you're not working on a green field project, there aren't that many opportunities to write code in this style--standard practice for established projects is that someone else needs to review all of the code you write: this is a bad match for vibe coding! So in this post, I want to give a concrete case study of a nontrivial system that was entirely vibe coded (ScubaDuck), to argue the following claims:

  1. AI coding can be done on a manager's schedule: you don't need continuous blocks of coding time, and context-switching is considerably less harmful. ScubaDuck was implemented in three days of part-time work, all of which happened while the baby was napping.
  2. AI coding substantially lowers the cost of doing projects in tech stacks you are less familiar with. ScubaDuck is mostly JavaScript UI code, which is not something I write on a day-to-day basis.
  3. AI coding is an unlock for "sidequests": support software that's ancillary to your main task and nice to have, but not essential. If previously you would have decided the cost outweighed the benefit, AI coding reducing that cost means you should redo the calculation.
  4. Vibe coding works and can produce working software. ScubaDuck is an existence proof that vibe coding is a viable strategy for generating JavaScript UI code (NB: I don't claim vibe coding will work for all domains, nor do I claim this is the only domain for which it works. Hopefully you can also build some intuition for where it is more or less likely to work). You will not one-shot it (ScubaDuck was 150 prompts in the end), but if you are prompting the LLM to also generate tests, you can reliably fix issues without causing regressions to existing code.
  5. Vibe coding is good for situations where buggy software is low impact; be on the lookout for ways to engineer this sort of situation. ScubaDuck is a read-only interface, where the only downside to being buggy is you can't issue the queries you want to issue.

Update: You can see all of my prompts and the resulting agent trajectories at scubaduck-prompts.

What is ScubaDuck?

ScubaDuck is a discount implementation of Meta's internal Scuba realtime database system. You can read more about what exactly this is on GitHub, but it's not so important for the purposes of this post: the key detail you need to know about ScubaDuck is that it consists of a Python server that exposes an API to perform queries against a DuckDB database, and an HTML and JavaScript frontend application which implements the forms for building these queries and the rendering of the output data. Both the forms and output data rendering have nontrivial JavaScript enhancements: some form inputs are chip inputs and support autocomplete, and the time series view is an SVG chart. All of these components were coded from scratch, so the project has no third-party JavaScript dependencies.

So on the one hand, this project is pretty simple: there are no stringent performance or uptime requirements, and it's a pretty standard server-client program that the LLM has seen millions of times before (this is good!) On the other hand, the exact behavior of the frontend UI is quite intricate and would be very difficult to one-shot in a single prompt. Indeed, as I was coding and testing the application, I frequently ran into situations that I didn't anticipate in my original specification, and that I had to ask Codex to refine. Another way to put it is that ScubaDuck has a relatively simple functional specification (although this too was not one-shot), but I did a lot of polishing of small behaviors so that the interface behaved in the way that I expected Scuba to behave. Here, it was helpful that I had a very clear idea of what I wanted (since I've used Scuba quite a lot at work).

Going into ScubaDuck, I had a pretty good sense that this project should be a good fit for LLMs. HTML, JavaScript and Python are all extremely high resource languages, and I'd heard lots of people raving about how good LLMs were at transforming wireframes and mockups into fully functional websites. It is also fully self-contained and straightforward-ish to test (only "ish" because you do have to use something like Playwright to actually test the frontend UI, which honestly is a slog. But fortunately, the LLM can write the tests for you!) One design decision I made, which I didn't originally anticipate but worked out in the end, was the decision to not use any third-party JavaScript libraries. This was by accident: Python has no native way of bundling third-party JavaScript, but I wanted the tool to work offline. I wasn't sure if you could vibe code an SVG charting library from scratch, but apparently you can and it's quite easy!

Agent setup

ScubaDuck was implemented with OpenAI Codex in the cloud (not the CLI tool). Codex's cloud offering requires you to initialize a hermetic environment which the coding agent can execute commands in. It's pretty well known now that AI coding agents work much better if they are able to run the code they write and see if it worked or not, so this is quite an important part of the process. Unfortunately, setting this up took a fair amount of trial and error. I had a fairly detailed initial prompt, and what I would do was submit it to Codex, watch it fail, read over the trajectory (the agent logs) to see what happened (Codex wanted to use npm! Codex couldn't download something from the internet! Codex tried to use a package that wasn't available!) and then fix whatever environment misconfiguration had caused it to fail, or edit AGENTS.md to instruct it not to do some behavior. According to my history, the first day of the project was spent unsuccessfully trying to get the project set up, and my first successful Codex PR only happened on May 19.

At the end of setup, I had the following:

  1. A pyproject.toml with exactly the dependencies I wanted to be used (duckdb, flask and python-dateutil), a lockfile for it (since I was using uv) and my preferred configuration for various tools (pytest, ruff). I'm a big fan of pytest-xdist for vibe coded projects, since you can prompt the LLM to write tests that will work when run in parallel and it does a pretty good job at this. Later I'd also add a pyright configuration, though initially I left it out because I saw Codex doing some strange things on account of duckdb being untyped, and I didn't want to debug it at the time (the fix, by the way, is instructing the LLM to define stubs as necessary in this case.)
  2. An AGENTS.md file with some basic instructions to try to get Codex to stop doing things I saw it doing in the initial trajectories that I didn't want it to do. Nothing fancy, just if you see Codex do something bad, tell it not to do it in AGENTS.md. A good example of this is "There are no nested AGENTS.md files, this is the only agents file": Codex is post-trained to look for nested AGENTS.md files, but you can save a few tool calls if you tell it there aren't any. (Note: folklore for Claude 3.7 is that instruction following for this sort of rule was not great. Word on the street is that both Codex and Claude 4 are substantially better at it. Extra note: for uv users, another notable instruction in AGENTS.md is how to activate the venv, since at time of writing I couldn't get Codex to make this happen automatically.)
  3. A setup script for the environment. This took the most debugging, because Codex runs all Internet access through a proxy and sometimes it works imperfectly.

After I got my initial prompt to generate a first draft of the application, I was able to begin vibe coding in earnest.

The Human-Agent loop

The basic vibe coding loop works like this:

  1. Interact with the application and find things that are broken
  2. Prompt the LLM to fix them
  3. Repeat

For example, after the very first PR, some very mild poking around immediately revealed the bugs fixed in #2:

There's a race condition in the current test logic for matching against table contents in run_query. Specifically, if there were previously valid results in lastResults, and for some reason Dive doesn't do anything, then we will still see the old results. The testing framework should explicitly clear lastResults before attempting an interaction.

...and #3:

Filter functionality does not work. We will first add a failing test, and then fix it. The failing test should click "Add Filter", then select "user" as the field, and then add an "alice" chip (by typing alice in the text box and pressing ENTER). Then when we dive, we should see two alice rows. Right now, NO request is issued at all when we click Dive. Diagnose and then fix the problem.

Prompt the agent to write tests. It's very helpful to prompt the agent to generate tests for whatever bugs it's fixing. For frontend code, I decided to use Playwright to write these tests. An example in #11:

def test_header_and_tabs(page: Any, server_url: str) -> None:
    page.goto(server_url)
    page.wait_for_selector("#order_by option", state="attached")

    header = page.text_content("#header")
    assert "sample.csv" in header
    assert "events" in header

    assert page.is_visible("#settings")
    assert page.is_hidden("#columns")
    page.click("text=Columns")
    assert page.is_visible("#columns")
    cols = page.locator("#column_list li").all_inner_texts()
    assert "timestamp" in cols
    assert "event" in cols
    page.click("text=View Settings")
    assert page.is_visible("#settings")

    btn_color = page.evaluate("getComputedStyle(document.querySelector('#dive')).backgroundColor")
    assert "rgb(0, 128, 0)" == btn_color

    sidebar_overflow = page.evaluate("getComputedStyle(document.querySelector('#sidebar')).overflowY")
    view_overflow = page.evaluate("getComputedStyle(document.querySelector('#view')).overflowY")
    assert sidebar_overflow == 'auto'
    assert view_overflow == 'auto'

This is kind of a very manual and somewhat brittle test. I doubt I would have had the patience to manually write and debug this. But the LLM will happily do it. There is also a fear that the test might not actually be testing anything for real. But because these tests are all generated from bug reports, you can prompt the model to write the test first, check that it's failing, and then fix the problem. I didn't instruct this in the above prompt, but I did instruct it for #155, for example. The trajectory shows that the LLM demonstrates it can repro the bug before fixing the issue:

I’ll add a case-insensitive test for "samples" to check server behavior with lowercase input.

(scubaduck) root@5b69743466e4:/workspace/scubaduck# python - <<'EOF'
> import json
> from scubaduck import server
>
> app = server.app
> client = app.test_client()
> payload = {
>     "table": "events",
>     "start": "2024-01-01 00:00:00",
>     "end": "2024-01-03 00:00:00",
>     "graph_type": "table",
>     "order_by": "samples",  # lowercase
>     "order_dir": "DESC",
>     "limit": 10,
>     "columns": [],
>     "group_by": ["user"],
> }
> rv = client.post("/api/query", data=json.dumps(payload), content_type="application/json")
> print('status', rv.status_code)
> print('data', rv.get_json())
> EOF
status 400
data {'error': 'Unknown column: samples'}
(scubaduck) root@5b69743466e4:/workspace/scubaduck#

I’ve found the bug happens when "Samples" is lowercase. I’ll adjust the query to handle case insensitivity and add a test for it. Time to modify server.py to fix the issue!

Beware BC-breaking changes. The LLM-generated tests worked well when I didn't need to change behavior in a BC-breaking way. Example prompt:

Change time column to default to none unconditionally. (I want to default it to a time column if we find a good candidate but this threshold needs tuning and I don't have the tuning right now.)

This hit a Codex timeout:

This attempt to update the default time column led to extensive frontend test failures that require deeper changes. The environment made it difficult to fully diagnose the Playwright test behaviors within the time available, so the work could not be completed.

As is standard software engineering practice, when this happens, decouple the BC-compatible and BC-breaking changes!

Make it so that Time Column can be set to (none). When it is set this way, the Start/End fields are hidden and we don't apply a filter on time range. (#115)

and then later, instead of defaulting the time column to none, I added a heuristic to pick a column that looked like time, which chose the same column that all of the existing tests already expected.

Refactors have to be split up. Codex's timeout means that you can't ask it to do too much in one go. Here's a prompt that timed out:

scubaduck/index.html has gotten a bit long. Let's split out some of the JS code into dedicated JS files for their functionality. Also setup the necessary Flask scaffolding to serve these JS files. I think splitting out these specific components would be good:

  • Dropdown implementation
  • Sidebar resizing
  • JS controlling the View Settings (e.g., updateDisplayTypeUI, as well as one off interactions on form elements, columns handling, filter handling, the actual Dive implementation (including query updating), reading in defaults from query string)
  • Table rendering (e.g., formatNumber, sorting)
  • Chip input implementation
  • Chart rendering (showTimeSeries)

Make changes to AGENTS.md or README.md describing the structure so you can quickly find where the components you need are

I eventually did manage the refactor by prompting Codex to individually move out the pieces I wanted to extract one-by-one. This is a place where I think Claude Code probably would have performed better.

Parallelizing tasks. As you can see from the lengths of my prompts, it does take a while to write a good prompt; you're basically writing a bug report with enough detail that the LLM can repro it and then fix it. So sometimes I would be bottlenecked on prompt writing. However, sometimes the prompts were quite short. In those cases, Codex encourages you to submit more tasks that can run in parallel. I found this worked well, and I'd sometimes have as many as five instances going (once again, rate limited by discovering problems, making designs and typing prompts!) One irritation is when the tasks end up conflicting with each other. Sometimes the conflicts are easy to fix, but if it feels nontrivial, it's often better to just ask Codex to redo one of the PRs on latest main after the other has landed. To avoid merge conflicts, it helps to have only one "main feature" agent going at any time, and then ask the agent to do random bugfixes in parallel with it. Once you have no more tasks to get running, you can go do something else while you wait for the agents to finish (manager schedule!)

Prompting

As a reminder, I've posted all of my prompts (including the ones that failed) at scubaduck-prompts, and I think it's helpful to skim through them to get a flavor of what I was asking the LLM. But to summarize, what did I spend most of my time prompting Codex to do? My general vibe (ahem) is that I spent most of my time doing minor enhancements, where I instructed Codex to make some part of the program work slightly differently, in a way that had been left unspecified by the previous prompt. The metaphor I had in my head while I was working on the project was that of a sculptor chiseling away marble: in the beginning, anything is possible, but as I kept prompting, I continuously narrowed down the space of possible programs until I had exactly the one I wanted. One big thing I want to note is that Codex rarely needed to make updates to my tests; for the most part, tests that were added never got taken away, because I never "changed my mind". I suspect that the vibe coding process would have been rockier if I had been changing behavior frequently.

One of the things that surprised me the most about the process was how easy it was to implement a line chart in SVG with Codex. My first prompt resulted in a chart that looked broken on the test data:

We're going to add a new View type, to go along with Samples and Table: Time Series. Time Series supports all the fields that Table supports, and a few more:

  • X-axis: Main group by dimension, e.g., the x-axis on time series view. This is our custom dropdown selector, but only time columns are populated here. It should prefer a default setting from the following list, most preferred first: "time", "timestamp"
  • Granularity: Choose the time interval between data points on the chart. For example, a granularity of 1 hour means there will be a data point every 60 minutes that is aggregated with the chosen Aggregate function over the data for the granularity period before point. This is a plain drop down. The valid values are: Auto, Fine, 1 second, 5 seconds, 10 seconds, 30 seconds, 1 minute, 4 minutes, 5 minutes, 10 minutes, 15 minutes, 30 minutes, 1 hour, 3 hours, 6 hours, 1 day, 1 week, 30 days. The semantics of the Auto setting is that it sets the interval to whatever would result in maximum 100 buckets (if there are not enough data points for that many buckets, it just picks the finest time interval that makes sense), and Fine which sets the interval to 500 buckets.
  • Fill Missing Buckets: This is a dropdown. For now, it has the settings "Fill with 0 (Per Series)" (default), "Connect (Per Series)" and "Leave blank".

Additionally, the default setting of Limit is 7, as it controls how many elements from group by will be plotted (the actual number of lines plotted could be a multiple of this, as we will plot every selected Column).

Unlike Samples and Table, we will instead display a line chart in the right panel. To plot the line chart, we will implement it by hand with JS and SVG, similar to how highcharts implements it. We will not use any third party dependencies. Lines will be plotted as paths, no smoothing, no dots for individual data points. Each series (as generated by group by) should be plotted with a different color, assigned using a best practices color palette for graph design. There should be a rendering of x-axis and y-axis; the x-axis should have slanted labels to aid readability. When we mouse over the chart, a vertical line should snap to the center of the time bucket that we are closest to. We should also display a crosshair on all of the series showing us their values at that data point, and highlight the closest point we are on, and increase the thickness of the series that point is on. To the left of the graph (still in the right panel), there should be a legend. The legend looks like this:

[GROUP BY VALUE] [AGGREGATE]
[First Column name, with series color]
[Number of samples for the first column]
[Second Column name, with series color]
[Number of samples for the second column]
... for all columns
----
... for all group by values (up to the limit)

So for example, if I group by user, I might see:

Alice AVG
value
4 (samples)

The highlighted series (which has a thicker line) should also be highlighted in the legend).

This was kind of terrifying, because I initially thought I didn't have a good way to test the SVG outputs. But after doing some regular old-fashioned debugging and reading the code (yes, this part not vibe coded), I figured out the problem, and also realized that Playwright can test that an SVG path is not just entirely straight. After the initial bugs were fixed, I mostly had to add missing features like x-axis/y-axis and interactivity features (amusingly, Codex ignored most of the instructions in the latter half of the prompt, giving only the barest bones legend. I suspect this was because I had some files which were too long). My general take after this was that JS chart libraries are going to become obsolete: it's much easier to vibe code a bespoke implementation and then customize the heck out of it.

Conclusion

ScubaDuck was implemented in about 150 Codex prompts. As you can see from the sample prompts above, the prompts are recognizably programming; they just happen to be written in plain English. This is a big help, because I never had to keep track of the nest of callbacks and state machines for implementing complex UI elements in JavaScript. I had to be fluent in what I wanted my program to do, and be a good QA tester for the application to discover new problems that needed to be fixed, but I did not have to worry at all about the vagaries of SVG DOM elements or pixel position computation minutiae. It's hard to say how long it would have taken to code this by hand, but I think reproducing a UI that's been in production for years at Meta in three (part-time) days is pretty good!

Despite having done a bit of AI coding before, I still learned a few things from working with Codex. Codex made it blindingly clear that the parallel modality (and the subsequent conflict resolution) is important. It made me adjust up my estimation of the capability of LLMs to write raw HTML/JS, and evoked a future where people vibe code components in place of taking on a third-party dependency. I was very appreciative of no-rate-limit Codex (though I doubt it's going to last.) It also reminded me how difficult it will be to set up agent environments for "real" projects (like PyTorch).

Hopefully, this case study has given you some ideas for things to try. Go forth and vibe code, responsibly!

  • June 2, 2025

Why you should maintain a personal LLM coding benchmark

Do you use an LLM for coding? Do you maintain a personal benchmark based on problems you have posed the LLM? The purpose of this blog post is to convince you that you should do this: that you can do so with marginal effort on top of your day-to-day vibe coding, and that you will get both short and long term benefits from making your own personal benchmark exist.


I started thinking about benchmarks for coding in part because of my frustration with the discourse around LLMs in the public squares I frequent (Reddit and Twitter). People often want to know "what's the best model" or "what's the best coding IDE"? One might imagine that the way to answer this question would be to test the models on a variety of problems from real world uses of the LLM for coding, and then compare how well various systems do on this. Indeed, whenever a new SOTA model releases, the lab will usually tell you about the model's performance against a few well known coding benchmarks. Problem solved?

https://blog.ezyang.com/wp-content/uploads/2025/03/Screenshot-2025-03-31-at-10.10.14%E2%80%AFAM.png

Of course not! In fact, for the most part, no one really talks about benchmarks when comparing models. Why? I argue the most popular benchmarks measure tasks that are largely different from what a user wants out of an LLM. For example, take the recent Gemini 2.5 Pro release. In their headline table, they test against LiveCodeBench, Aider Polyglot and SWE-bench Verified. LiveCodeBench and Aider Polyglot derive their problems from contest programming and pedagogical exercises (respectively), while SWE-bench assesses bug fixes to preexisting codebases. While useful, this is only a small slice of the things people want to do with LLMs.

Wouldn't it be great if you had your own, personal benchmark, based on problems you actually care about? If you are tweaking your .cursorrules, you could run your benchmark to see if a change you made helped or not. When a new model comes out, you could spend a few bucks to run your eval and make a decision if you should switch your daily driver. And then on social media, if you wanted to stan the new model, instead of asking the model to drop a ball inside a rotating hexagon or vagueposting about how the new model is incredible, you could just post your benchmark results.


Nicholas Carlini's Yet Another Applied LLM Benchmark is an existence proof that this playbook can work. As Nicholas describes it:

It's a collection of nearly 100 tests I've extracted from my actual conversation history with various LLMs.

There are two defining features of this benchmark that make it interesting. Most importantly, I've implemented a simple dataflow domain specific language to make it easy for me (or anyone else!) to add new tests that realistically evaluate model capabilities. This DSL allows for specifying both how the question should be asked and also how the answer should be evaluated. Most questions are evaluated by actually running the code the model writes but the framework supports a bunch of other evaluation methods as well. And then, directly as a result of this, I've written nearly 100 tests for different situations I've actually encountered when working with LLMs as assistants.

I have been working on my own benchmark based off of Carlini's benchmark, and I can confirm that this works well for the traditional style of coding eval, where you have a one-shot task that generates and executes the code against some test cases. My basic strategy is to vibe code as usual, but whenever I give an LLM a task that it isn't able to one-shot, I consider adding it to the benchmark. In more detail (a minimal sketch of such a harness follows the list):

  • I only add a task if a SOTA LLM failed it. This ensures the benchmark consists of all appropriate difficulty problems: easy enough that I thought an LLM should be able to do it, but hard enough that a SOTA model failed on it. I don't need problems that are too hard (this is already well covered by well known benchmarks like SWE-Bench or SWE-Lancer), and I don't mind if my problems saturate because, hey, that means the models are that much better for my use cases!
  • After I have added the task to the benchmark, I can use the benchmark runner to tell if changing the model, tweaking the prompt, or even just running the prompt again at nonzero temperature can make it pass. Indeed, it's helpful to find some configuration that makes the eval pass, as this is good for debugging issues in the evaluation function itself... also it means you have working code for whatever task you were working on. Conversely, you can make the task harder by leaving things out from the prompt.
  • Writing the test is the labor intensive part, but you can always vibe code a test. Importantly, you have a failing implementation (your initial generation) and some way you (manually?) determined that the implementation was wrong, so just turn this into your evaluation function! (And for all you yak shaving aficionados, if the model fails to vibe code your test, well, you have another task for your benchmark!)
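
For concreteness, here is a minimal sketch of the kind of one-shot harness described above. run_llm is a placeholder for however you call the model under test, and the details (temp files, subprocess, timeouts) are just one reasonable way to do it.

import subprocess
import sys
import tempfile

def run_llm(prompt: str) -> str:
    """Placeholder: call whatever model/provider you are evaluating and return its code."""
    raise NotImplementedError

def run_task(prompt: str, test_code: str, timeout: int = 120) -> bool:
    """One-shot eval: generate code from the prompt, append the test you wrote when
    you originally determined the generation was wrong, and run the combined file."""
    candidate = run_llm(prompt)
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(candidate + "\n\n" + test_code)
        path = f.name
    result = subprocess.run([sys.executable, path], capture_output=True, timeout=timeout)
    return result.returncode == 0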

For example, the other day I needed to take an asciinema recording and convert it into a sequence of frames rendered as plain text. However, the only project for doing these conversions was agg, which converts recordings into animated gifs. In agg_to_text, I ask an LLM to take agg's source code and create a new program which dumps the frames as plain text rather than gif images. This task is difficult because there is some discretion in deciding when to emit a frame, and with my original prompt the LLM didn't precisely replicate the original behavior in agg. While working on the benchmark, I realized that instructing the model specifically about how frame batching worked was enough to get it to preserve the original behavior. But I don't think I should need to do this: thus this task. (P.S. If this test saturates, well, I can always make it harder by removing the agg source code from the prompt.)


The ability to benchmark one-shot tasks is here today, but I would like to speculate a bit about what lies beyond them. In particular, most of my LLM coding activity involves asking the LLM to make changes to a pre-existing project, which makes it less amenable to "single prompt creates self-contained program". (Also, I usually only ask one-shot questions that the LLM can answer, so most of them would never go in my benchmark.)

In short, how can I extract tasks from my day-to-day work? There seems to be two big extra levers we have:

  • Codebase tasks. This is the heavy-weight approach: you record the Git commit of your codebase at the time you prompted for some new feature to be added, and then when you want to run an eval on a new model you just check out the codebase at that commit and let the end-to-end system go. You'll typically want to execute the modified code, which means you'll also need a way to reliably setup the runtime environment for the code; things like lockfiles can help a lot here.
  • Transcript tasks. You don't actually need the entire codebase to be available to ask an LLM for a completion; you only need the conversation transcript up to the point of the critical generation. If the transcript is mostly your agent system reading in files for context, you can end up with a relatively system-generic prompt that can tell you something about other systems. Of course, if you want to actually run the change, you still need the full codebase, which is why this approach works best if you're only going to do some static analysis on the output. For example, if a model keeps adding try: ... except: ... blocks that suppress errors, you can take some transcripts where you've caught the model red-handed doing this and make an eval that checks if the model is still doing it (a small sketch of such a check follows this list). I suspect testing on transcripts works best for checking whether changing prompts or rules improves performance, since the transcript itself puts the model into some particular latent space; a different model might have made different choices, leading to a different latent space. Transcripts from thinking models are especially susceptible to this!
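
As a sketch of the kind of static check a transcript task might use (the helper name and the exact pattern it flags are illustrative):

import ast

def has_error_suppressing_except(completion: str) -> bool:
    """Flag completions containing a bare `except:` (or `except Exception:`)
    whose body is just `pass`."""
    try:
        tree = ast.parse(completion)
    except SyntaxError:
        return False  # not valid Python; a different eval should catch that
    for node in ast.walk(tree):
        if isinstance(node, ast.ExceptHandler):
            broad = node.type is None or (
                isinstance(node.type, ast.Name) and node.type.id == "Exception"
            )
            swallowed = all(isinstance(stmt, ast.Pass) for stmt in node.body)
            if broad and swallowed:
                return True
    return False

# The eval replays the saved transcript up to the critical generation, asks the
# model under test for a completion, and then asserts the check does not fire.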

I have started adapting Carlini's framework to work better for these cases, although I would love to be told someone has already solved this problem for me. In particular, I am very excited about using transcript tasks to evaluate whether or not things I add to my prompts / triggered rules are helping. Current SOTA model instruction following isn't great, and I regularly catch models doing things that I explicitly told them not to do in the system prompt. I have started some initial analysis over all of my chat logs to find cases where the model misbehaved, although I haven't quite worked out how I want to build an eval out of it.

One word of warning: to make transcript tasks, you need an AI coding system that doesn't obscure how it assembles its underlying prompts (which rules out most of the popular closed source AI code editors.)


I started building evals for a selfish reason: I wanted to be able to tell if modifications to my prompts were doing anything. But I also think there is a broader opportunity that arises if we also publish these benchmarks to the world.

For one, building a real world benchmark on use cases we care about is a way to communicate to the people training AI models whether or not they are doing well or not. Historical evals have focused on LeetCoding, and consequently we have models that would ace any big tech interview and yet on real world tasks will drive you off a cliff at the first opportunity. And this is not just free labor for the top labs: if you believe in open source models, one of the biggest barriers to good small models is having really high quality data. We, the OSS vibe coding community, can directly help here.

I think there is a tremendous opportunity for the open source community to really push the state of the art in coding evaluations. There's only so many benchmarks that I, personally, can create, but if everyone is making benchmarks I could eventually imagine a universe of benchmarks where you could curate the problems that are relevant to your work and quickly and cheaply judge models in this way: a Wikipedia of Coding Benchmarks.

To summarize: every time an LLM fails to solve a problem you ask it for, this is a potential new benchmark. As long as there is a way to automate testing if the LLM has solved the problem, you can turn this into a benchmark. Do this for yourself, and you can quickly have a personal benchmark with which to evaluate new models. Do this at scale, and you can help push the frontier in coding models.

  • April 4, 2025

New Years resolutions for PyTorch in 2025

In my previous two posts "Ways to use torch.compile" and "Ways to use torch.export", I often said that PyTorch would be good for a use case, but there might be some downsides. Some of the downsides are foundational and difficult to remove. But some... just seem like a little something is missing from PyTorch. In this post, here are some things I hope we will end up shipping in 2025!

Improving torch.compile

A programming model for PT2. A programming model is an abstract description of the system that is both simple (so anyone can understand it and keep it in their head all at once) and can be used to predict the system's behavior. The torch.export programming model is an example of such a description. Beyond export, we would like to help users understand why all aspects of PT2 behave the way they do (e.g., via improved error messages), and give them simple, predictable tools for working around problems when they arise. The programming model helps us clearly define the intrinsic complexity of our compiler, which we must educate users about. This is a big effort involving many folks on the PyTorch team and I hope we can share more about this effort soon.

Pre-compilation: beyond single graph export. Whenever someone realizes that torch.compile compilation is taking a substantial amount of time on expensive cluster machines, the first thing they ask is, "Why don't we just compile it in advance?" Supporting precompilation with the torch.compile API exactly as-is is not so easy; unlike a traditional compiler which gets the source program directly as input, users of torch.compile must actually run their Python program to hit the regions of code that are intended to be compiled. Nor can these regions be trivially enumerated and then compiled: not only must you know all the metadata of the input tensors flowing into a region, a user might not even know what the compiled graphs are if a model has graph breaks.

OK, but why not just run the model, dump all the compiled products, and then reuse them later? This works! Here is a POC from Nikita Shulga where a special decorator aot_compile_sticky_cache swaps between exporting a graph and running the exported product. Zhengxu Chen used a similar idea to export Whisper as a few distinct graphs, which he then manually stitched together in C++ to get a Python-free version of Whisper. If you want training to work, you can more directly integrate AOTInductor as an Inductor backend, e.g., as seen in this POC. We are a stone's throw away from working precompilation, which can guarantee no compilation at runtime; we just need to put the pieces together!

Improving caching further. There are some gaps with caching which we hope to address in the near future: (1) loading Triton cache artifacts takes a long time because we still re-parse the Triton code before doing a cache lookup (James Wu is on this), (2) if you have a lot of small graphs, remote cache ends up having to do lots of small network requests, instead of one batched network request at the beginning (Oguz Ulgen recently landed this), (3) AOTAutograd cache is not fully rolled out yet (James Wu again). These collectively should be worth a 2x speedup or even more on warm cache time.

Fix multithreading. We should just make sure multithreading works, doing the testing and fiddly thread safety auditing needed to make it work. Here's a list of multithreading related issues.

Improving torch.export

Draft mode export. Export requires a lot of upfront work to even get an exported artifact in the first place. Draft mode export capitalizes on the idea that it's OK to generate an unsound "draft" graph early in the export, because even an incorrect graph is useful for kicking the tires on the downstream processing that happens after export. A draft export gives you a graph, and it also gives you a report describing what potential problems need to be fixed to get some guarantees about the correctness of the export. You can then chip away at the problems in the report until everything is green. One of the biggest innovations of draft-mode export is pervasive use of real tensor propagation when doing export: you run the export with actual tensors, so you can always trace through code, even if it is doing spicy things like data-dependent control flow.

Libtorch-free AOTInductor. AOTInductor generated binaries have a relatively small ABI surface that needs to be implemented. This hack from the most recent CUDA Mode meetup shows that you can just create an alternate implementation of the ABI that has no dependence on libtorch. This makes your deployed binary size much smaller!

Support for bundling CUDA kernels into AOTInductor. AOTInductor already supports directly bundling Triton kernels into the generated binary, but traditional CUDA kernels cannot be bundled in this way. There's no reason this has to be the case though: all we're doing is bundling cubins in both cases. If we have the ability to bundle traditional CUDA kernels into AOTInductor, this means you could potentially directly embed custom operators into AOTInductor binaries, which is nice because then those operators no longer have to be offered by the runtime (especially if you're commonly iterating on these kernels!)

Export multigraphs. Export's standard model is to give you a single graph that you call unconditionally. But it's easy to imagine a level of indirection on top of these graphs, where we can dispatch between multiple graphs depending on some arguments to the model. For example, if you have a model that optionally takes an extra Tensor argument, you can simply have two graphs, one for when the Tensor is absent, and one for when it is present.

ABI stable PyTorch extensions. It's hard work being a third-party PyTorch extension with native code, because whenever there's a new release of Python or PyTorch you have to rebuild all of your wheels. If there was a limited ABI that you could build your extension against that didn't expose CPython and only relied on a small, stable ABI of PyTorch functions, your binary packaging situation would be much simpler! And if an extension relied on a small ABI, it could even be bundled with AOTInductor binary, letting these export products be truly package agnostic (one of our lessons we learned with torch.package is picking the split between "what is packaged" and "what is not" is very difficult, and people would much rather just have everything be packaged.) Jane Xu is investigating how to do this, and separately, Scott Wolchok has been refactoring headers in libtorch so that a small set of headers can be used independently of the rest of libtorch.

  • January 9, 2025

Ways to use torch.export

Previously, I discussed the value proposition of torch.compile. While doing so, I observed a number of downsides (long compile time, complicated operational model, lack of packaging) that were intrinsic to torch.compile's API contract, which emphasized being able to work on Python code as is, with minimal intervention from users. torch.export occupies a different spot in the tradeoff space: in exchange for more upfront work making a model exportable, it allows for use of PyTorch models in environments where using torch.compile as is would be impossible.

Enable end-to-end C++ CPU/GPU Inference

Scenario: Like before, suppose you want to deploy your model for inference. However, now you have more stringent runtime requirements: perhaps you need to do inference from a CPython-less environment (because your QPS requirements require GIL-less multithreading; alternately, CPython execution overhead is unacceptable but you cannot use CUDA graphs, e.g., due to CPU inference or dynamic shapes requirements). Or perhaps your production environment requires hermetic deploy artifacts (for example, in a monorepo setup, where infrastructure code must be continually pushed but model code should be frozen). But like before, you would prefer not to have to rewrite your model; you would like the existing model to serve as the basis for your Python-less inference binary.

What to do: Use torch.export targeting AOTInductor. This will compile the model into a self-contained shared library which can then be directly invoked from a C++ runtime. This shared library contains all of the compiler generated Triton kernels as precompiled cubins and is guaranteed not to need any runtime compilation; furthermore, it relies only on a small runtime ABI (with no CPython dependency), so the binaries can be used across versions of libtorch. AOTInductor's multithreading capability and low runtime overhead also make it a good match for CPU inference!
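
Here is a minimal sketch of the Python side of this flow. The packaging entry point (torch._inductor.aoti_compile_and_package) reflects recent PyTorch releases and has changed across versions, so treat this as a sketch and check the AOTInductor docs for your version; the model is a toy.

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x.T)

model = Model().cuda().eval()
example_inputs = (torch.randn(8, 8, device="cuda"),)

# Export to an ExportedProgram, then compile and package it into a single
# artifact that a C++ runtime (or Python) can load without any runtime compilation.
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")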

You don't have to go straight to C++ CPU/GPU inference: you can start with using torch.compile on your code before investing in torch.export. There are four primary extra requirements export imposes: (1) your model must compile with fullgraph=True (though you can sometimes bypass missing Dynamo functionality by using non-strict export; sometimes, it is easier to do non-strict torch.export than it is to torch.compile!), (2) your model's inputs/outputs must only be in torch.export's supported set of argument types (think Tensors in pytrees), (3) your model must never recompile--specifically, you must specify what inputs have dynamic shapes, and (4) the top-level of your model must be an nn.Module (so that export can keep track of all of the parameters your model has).

Some tips:

  • Check out the torch.export programming model. The torch.export programming model is an upcoming doc which aims to help set expectations on what can and cannot be exported. It talks about things like "Tensors are the only inputs that can actually vary at runtime" and common mistakes such as module code which modifies NN modules (not supported!) or optional input types (you will end up with an export that takes in that input or not, there is no runtime optionality).
  • Budget time for getting a model to export. With torch.compile for Python inference, you could just slap it on your model and see what happens. For torch.export, you have to actually finish exporting your entire model before you can even consider running the rest of the pipeline. For some of the more complicated models we have exported, there were often dozens of issues that had to be worked around in one way or another. And that doesn't even account for all of the post-export work you have to do, like validating the numerics of the exported model.
  • Intermediate value debugging. AOTInductor has an option to add dumps of intermediate tensor values in the compiled C++ code. This is good for determining, e.g., the first time where a NaN shows up, in case you are suspecting a miscompilation.

Open source examples: Among other things, torchchat has an example end-to-end AOTInductor setup for server-side LLM inference, which you can view in run.cpp.

torch.export specific downsides:

  • No built-in support for guard-based dispatch (multiple compilations). Earlier, I mentioned that an exported model must not have any recompiles. This leads to some fairly common patterns of code not being directly supported by torch.export: you can't export a single model that takes an enum as input, or has an optional Tensor argument, or accepts two distinct tensor shapes that need to be compiled individually. Now, technically, we could support this: you could imagine a package that contains multiple exported artifacts and dispatches between them depending on some conditions (e.g., the value of the enum, whether or not the optional Tensor argument was provided, the shape of the input tensor). But you're on your own: torch.compile will do this for you, but torch.export will not (see the sketch after this list).
  • No built-in support for models that are split into multiple graphs. Similarly, we've mentioned that an exported model must be a single graph. This is in contrast to torch.compile, which will happily insert graph breaks and compile distinct islands of code that can be glued together with Python eager code. Now, technically, you can do this with export too: you can carve out several distinct subnets of your model, export them individually, and then glue them together with some custom written code on the other end (in fact, Meta's internal recommendation systems do this), but there's no built-in support for this workflow.
  • The extra requirements often don't cover important components of real world models. I've mentioned this previously as the extra restrictions export places on you, but it's worth reiterating some of the consequences of this. Take an LLM inference application: obviously, there is a core model that takes in tokens and produces logit predictions--this part of the model is exportable. But there are also important other pieces such as the tokenizer and sampling strategy which are not exportable (tokenizer because it operates on strings, not tensors; sampling because it involves complicated control flow). Arguably, it would be much better if all of these things could be directly bundled with the model itself; in practice, end-to-end applications should just expect to directly implement these in native code (e.g., as is done in torchchat). Our experience with TorchScript taught us that we don't really want to be in the business of designing a general purpose programming language that is portable across all of export's targets; better to just bet that the tokenizer doesn't change that often and eat the cost of natively integrating it by hand.
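
To illustrate what rolling it yourself might look like, here is a toy sketch that exports two artifacts for a model with an optional Tensor argument and hand-writes the dispatch; this is not a built-in feature.

import torch

class M(torch.nn.Module):
    def forward(self, x, bias=None):
        out = x * 2
        return out + bias if bias is not None else out

m, x, b = M(), torch.randn(4), torch.randn(4)

# One exported artifact per calling convention: export itself cannot express
# the optionality, so we export the model twice...
ep_with_bias = torch.export.export(m, (x, b))
ep_without_bias = torch.export.export(m, (x,))

# ...and hand-roll the dispatch that torch.compile would otherwise do for us.
def run(x, bias=None):
    if bias is None:
        return ep_without_bias.module()(x)
    return ep_with_bias.module()(x, bias)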

AOTInductor specific downsides:

  • You still need libtorch to actually run the model. Although AOTInductor binaries bundle most of their compiled kernel implementation, they still require a minimal runtime that can offer basic necessities such as tensor allocation and access to custom operators. There is not yet an official offering of an alternative, lightweight implementation of the stable ABI AOTInductor binaries depends on, so if you do want to deploy AOTInductor binaries you will typically have to also bring libtorch along. This is usually not a big deal server side, but it can be problematic if you want to do client side deployments!
  • No CUDA graphs support. This one is not such a big deal since you are much less likely to be CPU bound when the host side logic is all compiled C++, but there's no support for CUDA graphs in AOTInductor. (Funnily enough, this is also something you technically can orchestrate from outside of AOTInductor.)

Edge deployment

Scenario: You need to deploy your PyTorch model to edge devices (e.g., a mobile phone or a wearable device) where computational resources are limited. You have requirements that are a bit different from server side: you care a lot more about minimizing binary size and startup time. Traditional PyTorch deployment with full libtorch won't work. The device you're deploying to might also have some strange extra processors, like a DSP or NPU, that you want your model to target.

What to do: Use torch.export targeting Executorch. Among other things, Executorch offers a completely separate runtime for exported PyTorch programs (i.e., it has no dependency on libtorch, except perhaps there are a few headers which we share between the projects) which was specifically designed for edge deployment. (Historical note: we spent a long time trying to directly ship a stripped down version of libtorch to mobile devices, but it turns out it's really hard to write code that is portable on server and client, so it's better to only share when absolutely necessary.) Quantization is also a pretty important part of deployment to Edge, and Executorch incorporates this into the end-to-end workflow.

Open source examples: torchchat also has an Executorch integration letting you run an LLM on your Android phone.

Downsides. All of the export related downsides described previously apply here. But here's something to know specifically about Executorch:

  • The edge ecosystem is fragmented. At time of writing, there are seven distinct backends Executorch can target. This is not really Executorch's fault, it comes with the territory--but I want to call it out because it stands in stark contrast to NVIDIA's server-side hegemony. Yes, AMD GPUs are a thing, and various flavors of CPU are real, but it really is a lot easier to be focused on server side because NVIDIA GPUs come first.

Pre-compiled kernels for eager mode

Scenario: You need a new function or self-contained module with an efficient kernel implementation. However, you would prefer not to have to write the CUDA (or even Triton) by hand; the kernel is something that torch.compile can generate from a higher-level PyTorch implementation. At the same time, however, you cannot tolerate just-in-time compilation at all (perhaps you are doing a massive training job, and any startup latency makes it more likely that one of your nodes will fail during startup and then you make no progress at all; or maybe you just find it annoying when PyTorch goes out to lunch when you cache miss).

What to do: Use torch.export targeting AOTInductor, and then load and run the AOTInductor generated binary from Python.
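
A sketch of what this looks like, assuming you have already produced an AOTInductor package (as in the C++ inference section above); the loader entry point may differ slightly across PyTorch versions.

import torch

# "model.pt2" is a package produced earlier by aoti_compile_and_package.
# Loading it performs no compilation; it just loads the precompiled kernels.
runner = torch._inductor.aoti_load_package("model.pt2")
out = runner(torch.randn(8, 8, device="cuda"))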

Downsides. So, we know this use case works, because we have internally used this to unblock people who wanted to use Triton kernels but could not tolerate Triton's just-in-time compilation. But there's not much affordance in our APIs for this use case; for example, guard-based dispatch is often quite useful for compiled functions, but you'll have to roll that by hand. More generally, when compiling a kernel, you have to make tradeoffs about how static versus dynamic the kernel should be (for example, will you force the inputs to be evenly divisible by eight? Or would you have a separate kernel for the divisible and not divisible cases?) Once again, you're on your own for making the call there.

An exchange format across systems

Scenario: In an ideal world, you would have a model, you could export it to an AOTInductor binary, and then be all done. In reality, maybe this export process needs to be a multi-stage process, where it has to be processed to some degree on one machine, and then finish processing on another machine. Or perhaps you need to shift the processing over time: you want to export a model to freeze it (so it is no longer tied to its original source code), and then repeatedly run the rest of the model processing pipeline on this exported program (e.g., because you are continuously updating its weights and then reprocessing the model). Maybe you want to export the model and then train it from Python later, committing to a distributed training strategy only when you know how many nodes you are running. The ability to hermetically package a model and then process it later is one of the big value propositions of TorchScript and torch.package.

What to do: Use torch.export by itself, potentially using pre-dispatch if you need to support training use-cases. torch.export produces an ExportedProgram which has a clean intermediate representation that you can do processing on, or just serialize and then do processing on later.
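
A minimal sketch of the round trip, with a toy module; the point is that the ExportedProgram can be inspected, serialized, and picked back up later, possibly on a different machine.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x).sum()

ep = torch.export.export(M(), (torch.randn(4, 4),))
print(ep.graph)  # a clean IR you can write processing passes over

# Serialize now, finish processing later (e.g., after updating weights).
torch.export.save(ep, "frozen.pt2")
ep_later = torch.export.load("frozen.pt2")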

Downsides:

  • Custom operators are not packaged. A custom operator typically refers to some native code which was linked with PyTorch proper. There's no way to extract out this kernel and embed it into the exported program so that there is no dependence; instead, you're expected to ensure the eventual runtime relinks with the same custom operator. Note that this problem doesn't apply to user defined Triton kernels, as export can simply compile it and package the binary directly into the exported product. (Technically, this applies to AOTInductor too, but this tends to be much more of a problem for use cases which are primarily about freezing rapidly evolving model code, as opposed to plain inference where you would simply just expect people to not be changing custom operators willy nilly.)
  • Choose your own decompositions. Export produces IR that only contains operators from a canonical operator set. However, the default choice is sometimes inappropriate for use cases (e.g., some users want aten.upsample_nearest2d.vec to be decomposed while others do not), so in practice for any given target you may have a bespoke operator set that is appropriate for that use case. Unfortunately, it can be fiddly getting your operator set quite right, and while we've talked about ideas like a "build your own operator set interactive tool" these have not been implemented yet.
  • Annoyingly large FC/BC surface. Something I really like about AOTInductor is that it has a very small FC/BC surface: I only need to make sure I don't make breaking changes to the C ABI, and I'm golden. With export IR, the FC/BC surface is all of the operators produced by export. Even a decomposition is potentially BC breaking: a downstream pass could be expecting to see an operator that no longer exists because I've decomposed it into smaller pieces. Matters get worse in pre-dispatch export, since the scope of APIs used inside export IR expands to include autograd control operators (e.g., torch.no_grad) as well as tensor subclasses (since Tensor subclasses cannot be desugared if we have not yet eliminated autograd). We will not break your AOTInductor blobs. We can't as easily give the same guarantee for the IR here.

Next time: What's missing, and what we're doing about it

  • December 23, 2024

Ways to use torch.compile

On the surface, the value proposition of torch.compile is simple: compile your PyTorch model and it runs X% faster. But after having spent a lot of time helping users from all walks of life use torch.compile, I have found that actually understanding how this value proposition applies to your situation can be quite subtle! In this post, I want to walk through the ways to use torch.compile, and within these use cases, what works and what doesn't. By the way, some of these gaps are either served by export, or by missing features we are actively working on; those will be covered in other posts!

Improve training efficiency on a small-medium scale

Scenario: You have a model in PyTorch that you want to train at a small-medium scale (e.g., below 1K GPUs--at the 1K point there is a phase change in behavior that deserves its own section). You would like it to train faster. Locally, it's nice to get a trained model faster than you would have otherwise. But globally, the faster everyone's models train, the less GPU hours they use, which means you can run more jobs in a given time window with a fixed cluster. If your supply of GPUs is inelastic (lol), efficiency improvement means you can support more teams and use cases for the same amount of available GPUs. At a capacity planning level, this can be a pretty big deal even if you are GPU rich.

What to do: In some sense, this is the reason we built torch.compile. (When we were initially planning torch.compile, we were trying to assess if we were going after inference; but inference compilers are a much more crowded space than training compilers, and we reasoned that if we did a good job building a training compiler, inference would work too--which it did!) The dream which we sold with torch.compile is that you could slap it on the top of your model and get a speed up. This turns out to... not quite be true? But the fact remains that if you're willing to put in some work, there is almost always performance waiting at the end of the road for you. Some tips:

  • Compile only the modules you need. You don't have to compile the entire model; there might be specific modules which are easy to compile and which will give you most of the benefit. For example, in recommendation systems, there is not much compute improvement to be had from optimizing the embedding lookups, and their model parallelism is often quite hard to handle in the compiler, so torch.compiler.disable them (a small sketch of this follows the list). NB: This doesn't apply if you want to do some global graph optimization which needs the whole model: in that case, pass fullgraph=True to torch.compile and ganbatte!
  • Read the missing manual. The missing manual is full of guidance on working with the compiler, with a particular emphasis on working on training.
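
Here is a minimal sketch of the "compile only the modules you need" tip. The module names (Embeddings, Dense, Model) and sizes are made up for illustration: the embedding lookup is fenced off with torch.compiler.disable, and only the compute-heavy block is compiled.

import torch
import torch.nn as nn

class Embeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.table = nn.Embedding(1000, 64)

    # Fence this module off from the compiler: even if the whole model is
    # ever wrapped in torch.compile, this lookup stays in eager.
    @torch.compiler.disable
    def forward(self, idx):
        return self.table(idx)

class Dense(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 1))

    def forward(self, x):
        return self.mlp(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = Embeddings()
        self.dense = Dense()

    def forward(self, idx):
        return self.dense(self.embeddings(idx))

model = Model()
# Compile only the compute-heavy block; everything else stays eager.
model.dense = torch.compile(model.dense)
out = model(torch.randint(0, 1000, (8,)))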

Open source examples: torchtune and torchtitan are two first party libraries which are intended to showcase modern PyTorch using torch.compile in a training context. There's also some training in torchao.

Downsides:

  • The compiler is complicated. One of the things we've slowly been coming to terms with is that, uh, maybe promising you could just slap torch.compile on a model and have it run faster was overselling the feature a teensy bit? There seems to be some irreducible complexity with compilers that any user bringing their own model to torch.compile has to grapple with. So yes, you are going to spend some of your complexity budget on torch.compile, in hopes that the payoff is worth it (we think it is!). One ameliorating factor is that the design of torch.compile (graph breaks) means it is very easy to incrementally introduce torch.compile into a codebase, without having to do a ton of upfront investment.
  • Compile time can be long. The compiler is not a straightforward unconditional win. Even if the compiler doesn't slow down your code (which it can, in pathological cases), you have to spend some amount of time compiling your model (investment), which you then have to make back by training the model more quickly (return). For very small experimentation jobs, or jobs that are simply crashing, the time spent compiling is just dead weight, increasing the overall time your job takes to run. (teaser: async compilation aims to solve this.) To make matters worse, if you are scheduling your job on systems that have preemption, you might end up repeatedly compiling over and over again every time your job gets rescheduled (teaser: caching aims to solve this.) But even when you do spend plenty of time training, it is not obvious without an A/B test whether or not you are actually getting a good ROI. In an ideal world, everyone using torch.compile would actually verify this ROI calculation, but it doesn't happen automatically (teaser: automatic ROI calculation would help here), and in large organizations we see people running training runs without even realizing torch.compile is enabled.
  • Numerics divergence from eager. Unfortunately, the compiler does not guarantee exact bitwise equivalence with eager code; we reserve the right to do things like select different matrix multiply algorithms with different numerics or eliminate unnecessary downcast/upcasts when fusing half precision compute together. The compiler is also complicated and can have bugs that can cause the loss not to converge. Expect to also have to evaluate whether applying torch.compile affects accuracy. Fortunately, for most uses of the compiler for training efficiency, the baseline is the eager model, so you can just run an ablation to figure out whether the compiler is actually causing the accuracy problem (a minimal sketch follows this list). (This won't be true in a later use case where the compiler is load bearing, see below!)
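
To make the ablation concrete, here is a minimal numerics-check sketch. The tiny nn.Sequential model and random batch are stand-ins for your own model and a representative input: run the same input through an eager copy and a compiled copy and look at the worst-case divergence.

import copy
import torch
import torch.nn as nn

# Stand-in model and input purely for illustration; substitute your own.
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8)).eval()
batch = torch.randn(32, 64)

compiled = torch.compile(copy.deepcopy(model))

with torch.no_grad():
    out_eager = model(batch)
    out_compiled = compiled(batch)

print("max abs diff:", (out_eager - out_compiled).abs().max().item())
# Small differences are expected (e.g., different matmul algorithms, or fused
# up/downcasts when running in half precision); dig deeper only if the gap is
# large or the loss curves diverge in an ablation of the full training run.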

Improve Python inference efficiency

Scenario: You've finished training your model and you want to deploy it for inference. Here, you want to make inference more efficient, to reduce response latency or the overall resource requirements of the system, so you can use fewer GPUs to serve the traffic you are receiving. Admittedly, it is fairly common to just use some other, more inference-friendly system (which I will decline to name by name lol) to serve the model. But let's say you can't rewrite the model in a more serving-friendly language (e.g., because the model authors are researchers and they keep changing the model, or there's a firehose of models and you don't have the money to keep continuously porting each of them, or you depend on an ecosystem of libraries that are only available in CPython).

What to do: If Python can keep up with the CPU-side QPS requirements, a way of getting good performance without very much work is taking the Python model, applying torch.compile to it in the same way as you did in training, and using that directly as your inference solution. Some tips that go beyond training:

  • Autotuning makes the most sense for inference. In training runs, you have a limited window (the lifetime of the training job) to get return on the investment you spent optimizing the model. In the serving regime, you can amortize over the entire lifetime of your model in inference, which is typically much longer. Therefore, expensive optimization modes like mode="max-autotune" are more likely to pay off!
  • Warm up inference processes before serving traffic to them. Because torch.compile is a just-in-time compiler, you will spend quite a bit of time compiling at startup (even if you hit the cache). If you have latency requirements, you will want to warm up a fresh process with a representative set of inputs so that you trigger all of the compilation paths you need to hit (see the sketch after this list). Caching will reduce compile time but not eliminate it.
  • Try skip_guard_eval_unsafe to reduce guard overhead. Dynamo guard overhead can be material in the inference case. If this is a problem, get a nightly and try skip_guard_eval_unsafe.
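
Here is a minimal warmup sketch tying the first two tips together. The stand-in convolutional model and the batch sizes are placeholders for whatever your service actually receives: compile with mode="max-autotune" and run representative inputs before the process starts taking traffic.

import torch
import torch.nn as nn

# Stand-in model purely for illustration; substitute your own.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
    nn.Flatten(), nn.Linear(16, 10),
).eval()
model = torch.compile(model, mode="max-autotune")

# Warm up every compilation path you expect in production, e.g. one per
# batch size if you are not relying on dynamic shapes.
with torch.no_grad():
    for batch_size in (1, 8, 32):
        model(torch.randn(batch_size, 3, 224, 224))
# Only after warmup should the process start accepting traffic.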

Open source examples: LLM serving on torch.compile is quite popular: vllm, sglang, tensorrt-llm, gpt-fast (this is technically not an E2E serving solution, but one of its primary reasons for existing is to serve as a starting point so you can build your own torch.compile based LLM inference stack on top of it). Stable diffusion models are also notable beneficiaries of torch.compile, e.g., diffusers.

Downsides:

  • Just-in-time compilation is a more complicated operational model. It would be better if you didn't have to warm up inference processes before serving traffic to them. Here, torch.compile has traded operational simplicity for ease of getting started. If you want to guarantee that compilation has already happened ahead of time, you have to commit instead to some sort of export-based flow (e.g., C++ GPU/CPU inference; see below).
  • Model and dependency packaging in Python is unaddressed. You need to somehow package and deploy the actual Python code (and all its dependencies) which constitute the model; torch.compile doesn't address this problem at all (while torch.export does). If you are running a monorepo and do continuous pushes of your infra code, it can be organizationally complicated to ensure people don't accidentally break model code that is being shipped to production--it's very common to be asked if there's a way to "freeze" your model code so that the monorepo can move on. But with Python inference you have to solve this problem yourself, whether the solution is torch.package, Docker images, or something else.
  • Caches are not guaranteed to hit. Do you have to recompile the model every time you restart the inference process? Well, no, we have an Inductor and Triton (and an in-progress AOTAutograd) cache which in principle can cache all of the cubins that are generated by torch.compile. Most of the time, you can rely on this to reduce startup cost to just Dynamo tracing the model. However, the caches are not guaranteed to hit: there are rarer cases where we don't know how to compute the cache key for some feature a model is using, or the compiler is nondeterministic in a way that means the cache doesn't hit. You should file bugs for all of these issues as we are interested in fixing them, but we don't give a categorical guarantee that after you've compiled your inference program once, you won't have to compile it again. (And indeed, under torch.compile's usage model, we can't, because the user code might be the root cause of the nondeterminism--imagine a model that randomly samples to decide what version of a model to run.)
  • Multithreading is currently buggy. It should, in principle, be possible to run torch.compile'd code from multiple threads in Python and get a speedup, especially when CUDA graphs or CPP wrapper is used. (Aside: Inductor's default compile target is "Python wrapper", where Inductor's individually generated Triton kernels are called from Python. In this regime, you may get in trouble due to the GIL; CUDA graphs and CPP wrapper, however, can release the GIL while the expensive work is being done.) However, it currently doesn't work; track the issue at https://github.com/pytorch/pytorch/issues/136833

Like above, but the compiler is load bearing

Scenario: In both the cases above, we assumed that we had a preexisting eager model that worked, and we just wanted to make it faster. But you can also use the compiler in a load bearing way, where the model does not work without the compiler. Here are two common cases where this can occur:

  1. Performance: A compiler optimization that yields an asymptotic or large constant-factor improvement in performance can make a naive eager implementation, which would otherwise have been hopelessly slow, perform well. For example, SimpleFSDP chooses to apply no optimizations to the distributed collectives it issues, instead relying on the compiler to bucket and prefetch them for acceptable performance.
  2. Memory: A compiler optimization that reduces the memory usage of a model can allow you to fit a model or batch size that would otherwise OOM. Although we don't publicly expose APIs for doing so, you can potentially use the compiler to do things like enforce a certain memory budget when doing activation checkpointing, without requiring the user to manually specify what needs to be checkpointed.

What to do: Unlike the previous cases, where you took a preexisting model and slapped torch.compile on it, this sort of use of the compiler is more likely to arise from a codevelopment approach, where you use torch.compile while you build your model and constantly check what the compiler does to the code you write. Some tips:

  • Don't be afraid to write your own optimization pass. Inductor supports custom FX optimization passes (a sketch of the general pattern follows below). torch.compile has done the work of getting your model into an optimizable form; you can take advantage of this to apply domain-specific optimizations that Inductor may not support natively.
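
Inductor's hooks for registering such passes are private config settings, so as a self-contained illustration of the general pattern, here is a sketch of a thin custom backend that runs a pass over the Dynamo-captured FX graph before handing it to Inductor's usual entry point. compile_fx is an internal API whose details may shift between releases; treat this as a sketch under those assumptions, not a recipe.

import torch
from torch._inductor.compile_fx import compile_fx  # internal Inductor entry point

def my_pass(gm: torch.fx.GraphModule) -> None:
    # Trivial "pass" that only inspects the captured graph; a real pass would
    # pattern match and rewrite nodes here.
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print("captured op:", node.target)

def inductor_with_my_pass(gm, example_inputs):
    my_pass(gm)
    gm.recompile()                         # pick up any graph edits
    return compile_fx(gm, example_inputs)  # hand off to Inductor as usual

@torch.compile(backend=inductor_with_my_pass)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))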

Open source examples. SimpleFSDP, as mentioned above. vLLM uses torch.compile to apply custom optimization passes. Although its implementation is considerably more involved than what you might reasonably expect a third party to implement, FlexAttention is a good example of a non-compiler feature that relies on the compiler in a load-bearing way for performance.

Downsides: Beyond the ones mentioned above:

  • You can no longer (easily) use eager as a baseline. This is not always true; for example, FlexAttention has an eager mode that runs everything unfused, which can still be fast enough for small experiments. But if you have an accuracy problem, it may be hard to compare against an eager baseline if eager OOMs! It turns out that it's really, really useful to have access to an eager implementation, so it's worth working harder to make sure that the eager implementation works, even if it is slow. (It's less clear how to do that with, e.g., a fancy global-optimization-based activation checkpointing strategy.)

Next time: Ways to use torch.export

  • November 5, 2024