ezyang's blog

the arc of software bends towards understanding

Hugo Migration

This blog has lived on WordPress since it was initially created during a social challenge at MIT to write a blog post a week or pay up with beer. I remember a very important piece of advice I had been given at that time: don’t fuck around with your blog authoring software, just do the minimum viable thing (use WordPress) and focus on writing posts.

It’s 2026 now, the world is different, and in particular the existence of coding agents means that this advice falls flat: it has never been easier to vibe code your own blog software and be done in an afternoon of token generation. Similarly, over the years, I had become increasingly unhappy with my WordPress setup (too hard to add images, ancient version of WordPress, Markdown has taken over the world so why am I still writing in ReST, I love scripts.mit.edu but I definitely don’t want to use it to host serious things). So I typed this into ChatGPT and Claude and asked them what I should migrate to.

I currently have a Wordpress blog whose 633 posts are written in ReST using rest-wordpress with some manual code edits, and a theme based on Ashley that I also customized. I’d like to migrate to another blogging solution. I care a lot about ensuring the URLs are preserved. To a lesser extent, I also care about the comments, although I’m willing to compromise here (e.g., an offline flow where I have to explicitly publish comments might be OK; I know static site is difficult to support comments; I also know that email newsletter is popular and I’d like to support this modality if possible. I don’t use a WYSIWYG editor. It’s on Wordpress 5.1.19. It would be nice to have a way for people. Some more niche things plugins I’ve used is WP LaTeX and Share a Draft but I’m willing to do a lossy conversion if necessary (I don’t use LaTeX that much now; it’s just important to make sure the old posts still format correctly). Many of my posts have images and I’d like an easier flow than my current flow (where I have to manually upload my images to my server and then hyperlink them into the post). What do you recommend?

It suggested Hugo, which I had played around with before in AI Blindspots, and I figured, “Why not, I’ll just ask Claude to do the migration.” A few hours, two Pro sessions’ worth of tokens, and some PHP export scripts later, the entire blog was moved over, no muss, no fuss. I live streamed a portion of the migration process, although there’s nothing that special about it.

I actually wasn’t going to write a blog post about this, but I saw Jeff Geerling’s blog had also made the front page of Hacker News. I too haven’t figured out how I am going to solve the comments problem in the new format; I also think I will figure out how to get an email newsletter going from the blog. Here’s to seeing if this can encourage you to use LLMs to make the jump for your own personal site!

The gap between a Helpful Assistant and a Senior Engineer

Let’s suppose you asked an AI coding agent to “implement a CLI calculator”. Imagine if, instead of only writing a short Python script, it also started building an automated test suite, a crash reporting mechanism and a telemetry subsystem. You’d be like, “What the fuck is this?”

But now let’s say that you were planning to release this project to users. It would be clearly negligent to not have an automated test suite. A crash reporting mechanism might be overkill for a simple calculator, but for more complicated CLIs interacting with the real world, it may not always be feasible to have a reproducer, in which case crash logs are essential. Similarly, a telemetry subsystem would be wildly inappropriate for an open source local-only calculator, but it could make sense for a networked application or a corporate tool whose users have all consented. One of the important functions of a senior engineer is to be able to evaluate the context a software project lives in and figure out if we need to do something, even if it isn’t explicitly asked for. This is in contrast to a helpful assistant, who is first and foremost obligated to follow the user’s instructions. This leads to a gap between a Helpful Assistant and a Senior Engineer.

In principle, you could prompt the LLM agent to act like a Senior Engineer. In fact, why stop at Senior, let’s tell the LLM to be a Staff Engineer! Imagine that scaling continues: what would you expect the LLM to do when instructed to act in this way? Well, imagine a human L7 engineer who has just been hired by a big tech company to head up some big, new, multi-year initiative. Will they say, “Sure, I can help with that!” and start busily coding away? Of course not: they will go out and start reviewing code, reading docs, talking to people, asking questions, shadowing oncalls, doing small starter tasks–they will start by going out and building context. Here, the “helpful assistant” frame for LLMs is limiting: sure, Claude might ask you a few questions to clarify the task upfront, but if your coding agent starts asking you about “relevant partner teams” and “org-wide priorities for this half” you are definitely going to raise an eyebrow.

What would it take for an LLM to be able to act like a Senior Engineer?

  • Perhaps prompting is all you need: you just need to write enough information about the surrounding context for a project, and once you feed in enough tokens, a smart model can infer the rest of the details you didn’t explicitly write down. This context would be bespoke for every project; you would have to redo this exercise every time you had a new project!

  • Perhaps you can instead prompt a model on how to operate agentically to get the context it needs. This prompt here might be more reusable. But the model may need to actually do wetwork (e.g., talk to humans) to get all of the information it needs. And remember the old saying: the more generic the advice is, the less useful it is. Specificity is king, which leads to…

  • Let’s say we solve continual learning. Instead of crafting the perfect prompt upfront, you could just drop the model in as an “embodied” software developer. It reads code, talks to people, does projects, and in doing so slowly develops its latent context, in the same way a human engineer does. Building context will often be bottlenecked in the same way it is for humans: you can’t get experience related to pushing a feature to production until you’ve actually pushed the feature to production (however long that takes).

But just like how you shouldn’t micromanage a Senior Engineer, all of these approaches involve fundamentally different expectations about what an AI coding agent should do, and so even if a model and scaffold are capable of doing these things, it is altogether another question if it will be asked to behave in this way. So let’s not take it as a foregone conclusion that METR task times will keep following the empirical trendline: I expect a phase transition when the context an LLM needs to do a good job exceeds the capability of scaffolding to provide on the fly.

Code review as human alignment, in the era of LLMs

I’ve recently been doing a lot of both submitting and reviewing pull requests to PyTorch that were authored with substantial LLM assistance. This is a big difference from earlier this year, when it was clear LLMs worked well for greenfield projects but the code was too hopelessly sloppy for a production codebase. Here are my merged PRs that mention claude code in their description; Jason Ansel has also had a similar experience (Meta only link, here is the list of issues he referenced in his writeup). There has already been increasing discourse (Simon Willison, LLVM) on how code review should adapt to this new era of LLMs. My contribution to this discourse is this: within teams, code review should change to being primarily a human alignment mechanism.

Here is a simple example: it is well known that LLMs are prone to generating overly defensive code: e.g., they will be constantly sprinkling try...catch everywhere or testing if a variable is some type when system invariants imply that it should always be that type. If someone sends me a PR with these problems, I am not commenting on these problems solely because I want them to be fixed. If that’s all I cared about, I could have just fed my comments directly to claude code. The real problem is that the human who was operating the LLM didn’t agree with me that this defensive code was bad, and the point of the review is to align them with me on what is overly defensive versus not. In the most trivial cases, maybe the engineer didn’t read the LLM output, in which case the remedy is to make them actually read the code. But sometimes real human work has to happen; for example, maybe there is a global system invariant that one has to understand to know if the defensiveness is necessary or not. If we agree about the global system invariants, there’s no reason the code review has to go through me: the original code author can just instruct the LLM to fix problems and keep me out of the loop until they have aligned the LLM output to themselves–at which point we should do the more expensive human to human alignment. The ideal is that I don’t need to ever write review comments about mechanical problems, because they have already been fixed by the original author ahead of time.
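To make this concrete, here is a toy illustration (made up, not from any real PR) of the kind of defensiveness I mean:

import torch

# Overly defensive: re-validates an invariant the call site already guarantees,
# and swallows errors that should be loud
def scale_defensive(t, factor):
    try:
        if not isinstance(t, torch.Tensor):
            t = torch.as_tensor(t)
        return t * factor
    except Exception:
        return t

# If the system invariant is "t is always a Tensor here", this is all you need
def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
    return t * factor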

Conversely, when I am putting up an LLM generated PR for human review, I am trying to transmit higher level information. How does the new code work? What do I need to know about the existing system to understand this code? This doesn’t even have to be in the PR description: if the LLM proposes a fix that I myself don’t understand, or seems difficult to understand, I will simply instruct it to try it a different way, until the resulting diff is obviously correct. Tokens are cheap: we should expect more out of the author of code, because the cost of generating these PRs has gone way down. Similarly, I am willing to throw out the code and start again; you don’t have to feel bad about wasting my time (I didn’t type it! I spent my time understanding the problem, and none of that is regretted.)

There is a lot of scaremongering about how engineers who don’t pick up AI tools will be left behind. My take on this is that there are a number of different skills that make up what it means to be a good software engineer, and LLM coding, even today, is clearly reweighting the relative importance of these skills. I care a lot more about your ability to read code, reason about the big picture, communicate clearly and have good taste, than I care about your ability to mechanically write code. There is an archetype of junior engineer who is not that good at coding but very good at the softer, higher level skills, and I think they will be very valuable in this new world order. Conversely, I think going forward I will have substantially less patience if I have to keep telling you the same things over and over, because I just don’t value raw “ability to code” as much anymore. My ideal state is like that with long time senior teammates: I can trust that they have made good low level decisions, and I can focus on understanding the bigger picture and updating my mental model of how the system works.

Today’s LLMs have no memory: they have to rediscover everything in the system from first principles every time they are run. The purpose of the humans, of the team, is to collectively maintain a shared vision of what, platonically, the system should do. I want code review to reconfigure itself around this purpose.

Learning to love mesh-oriented sharding

Famously, PyTorch and JAX don’t agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension on your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it’s easy enough to translate between the two representations (in easy situations, at least!) In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don’t claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.
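As a concrete illustration of the two spellings (a sketch; the module paths shown are from recent PyTorch and JAX releases and may differ across versions), consider a 2-D weight on a ("dp", "tp") mesh that is replicated over "dp" and sharded along tensor dim 0 over "tp":

# PyTorch (mesh-dim oriented): one placement per *mesh* dimension
from torch.distributed.tensor import Replicate, Shard
placements = [Replicate(), Shard(0)]  # "dp" -> replicate, "tp" -> shard tensor dim 0

# JAX (tensor-dim oriented): one entry per *tensor* dimension; mesh axes
# not mentioned (here "dp") are implicitly replicated
from jax.sharding import PartitionSpec as P
spec = P("tp", None)  # tensor dim 0 sharded over "tp", dim 1 unsharded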

Closed versus open. I am going to make a precise technical claim: JAX sharding is closed, whereas PyTorch sharding is (in principle) open. Here, what I mean by closed/open refers to the capability for users to extend a system: traditional ADTs are closed (you can’t add another constructor to an ADT), whereas object-oriented classes are open (you can define a new subclass of a class). Now, technically JAX sharding is open: jax.sharding.Sharding is a base class that is intended to be subclassed, but to do this you have to define things like _to_xla_hlo_sharding, which is as good as not being supported. The regular class everyone uses, NamedSharding, consists of a mesh and a tuple of mesh axes, with no obvious extension points. I also offer for the defense this unanswered forum post: https://github.com/jax-ml/jax/discussions/23703

In contrast, PyTorch sharding is in principle extensible: the sharding is expressed as a list of Placement, a class which is subclassed to define custom shardings. The extensibility of Placement isn’t really well supported (for example, there’s no way of conveniently adding extra sharding propagation rules for new placements), but it works well enough that both internally and externally there are implementations of weird placements (internally, StridedShard and NormPartial… and technically all of the non-sum reductions supported by Partial as well as uneven sharding; externally, see RaggedShard and InterleavedShard).

Why does mesh-dim oriented sharding support extensibility in this way? The key is that mesh-oriented sharding is very imperative in nature: you can think of the list of placements as a sequence of transformations you apply to the tensor from left-to-right. Concretely, given the current local tensor (as produced by all of the placements you handled for the mesh dims before the one you’re currently processing), run an invertible function to split this tensor along the current mesh dimension. This gives you a bunch of new local tensors which you recursively continue sharding with the rest of the mesh dims. The invertibility of the function is the only real constraint on what function you can provide (since you need to be able to reassemble the shards back into the original full tensor), but otherwise your choice of function is unconstrained. It is in this sense that Placement is morally extensible.
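Here is a toy, local-only sketch of that imperative reading (the names and the even-chunk split are my own, not DTensor’s internals): placements are processed left to right, and each one applies an invertible split of the current local tensor along its mesh dimension.

import torch

def apply_placements(full, mesh_shape, shard_dims):
    """Return the local tensor at each mesh coordinate, treating each
    placement as 'evenly chunk the current local tensor along shard_dim'."""
    locals_ = {(): full}
    for mesh_size, shard_dim in zip(mesh_shape, shard_dims):
        nxt = {}
        for coord, t in locals_.items():
            # the invertible function for this mesh dim: an even chunk
            for i, piece in enumerate(torch.chunk(t, mesh_size, dim=shard_dim)):
                nxt[coord + (i,)] = piece
        locals_ = nxt
    return locals_

# Shard dim 0 over mesh dim 0, then shard dim 0 again over mesh dim 1
shards = apply_placements(torch.arange(8), (2, 2), [0, 0])
# shards[(0, 0)] is tensor([0, 1]); shards[(1, 0)] is tensor([4, 5])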

When designing systems, it is not an unambiguous good to make the system more flexible. Closed systems like JAX’s mean you don’t have to worry about hilariously complicated situations like what if you unevenly shard on the same dimension multiple times (do you have any guarantees on the local sizes of tensors being somewhat balanced?) But sometimes, the use case demands a greater degree of expressivity (in the same way that manual memory management allows you to do more than you can conveniently do in a GC’ed language.)

How expressive does Sharding have to be? One of the primary value propositions of DTensor is that it specifies a standardized representation for saying how a tensor is sharded across your cluster. It’s very good to have this information, because it prevents accidents, like forgetting that a tensor dimension is sharded so you do a reduction on that dimension without first doing a collective and you get subtly wrong results that take weeks to debug. It’s better to have a system that is correct but slow, than it is to have a system that is fast but incorrect.

Being able to express all distributed states is not a terminal goal. There are lots of situations in distributed optimizations where you temporarily need to put the system in a state where it is very difficult to describe exactly how to interpret data across nodes. For example, when you implement ring attention, to avoid communications when performing softmax, you instead perform online softmax. It’s quite difficult to say what the “placements” of the running quantities in online softmax are. In this case, we shouldn’t overly stress ourselves with defining a placement: we should just use local_map or shard_map and absolve ourselves of needing to actually say exactly how data is laid out at any given point in time. But the key is that we should only do this in local regions of code; if we give up and local_map our entire model, we might as well have just not written our code with DTensor at all. So we should seek additional expressivity when it is needed to express how data is being communicated across system boundaries.

Here are some classic examples from LLM training where you need a little bit of extra expressivity, starting with simple cases and becoming more complicated:

  1. Suppose you are applying FSDP to a model, where the parameter sizes haven’t been chosen with parallelism in mind; in particular, the size of your cluster doesn’t evenly divide the parameter count. It can be convenient to allow for an uneven sharding to happen in this case, so that the user doesn’t have to manually take care of padding out their tensor so that it can be allgathered.
  2. Say you do a matrix multiply between two tensors which are sharded on the contraction dimension. A reduction is required to communicate the local results into the final tensor. However, sometimes, it can be profitable to delay this reduction, since it can be coalesced into a later reduction. This requires the ability to express that a tensor has a pending reduction on some mesh axis.
  3. If you have both FSDP and row-wise TP, if your FSDP implementation naively shards on the first dim of your weight tensor, you need to ensure that the TP sharding occurs before the FSDP sharding (so that when you undo your FSDP sharding, you have the expected TP sharding ready to go for TP.) This requires the ability to express the order of sharding in a non-standard order (right-to-left, as is supported by list of mesh axes aka PartitionSpec), OR the ability to express that the FSDP is a weird “strided” shard where you don’t have contiguous data, instead you have stripes of data that will then be further sharded by the TP sharding.
  4. Suppose you have a tensor going into a matrix multiply which is not sharded on batch (because you’re naughty and you’re following the standard PyTorch convention of not actually expressing batch-sharding in DTensor) but is sharded on sequence. If you want to treat both batch and sequence as “batch” for the matmul, in PyTorch, this typically requires flattening these two dimensions into a single flat batch dimension. However, this cannot be done with a plain Shard, as there is no native Placement that can represent a flattened (Replicate, Shard); it does work with StridedShard (or InterleavedShard, which is the same thing.) More generally, it is extremely irritating that DTensors cannot reliably have view operations applied to them (that would be supported on plain tensors), and you need weird shard types to be able to handle many view operations.
  5. Traditional FSDP2 assumes that there’s not any requirement for how parameters are placed on nodes; but things like block-wise quantization and structure-aware optimizers need the ability to place a block/parameter on a single device, so that you have access to all the data you need. This won’t be a standard sharding; the shards will be ragged.

I think it’s a worthy use of complexity budget to search for a system design that can handle all of these things, especially since PyTorch’s existing mesh-oriented sharding is already tantalizingly close to supporting this.

Why is adding a new Placement to PyTorch hard? I tend to think, fundamentally, that a mesh-oriented sharding strategy can support arbitrary Placement subclasses. So why does this not work so well in PyTorch? I think there really only are two issues:

  • There is no official way to retroactively add extra sharding propagation rules to existing operators. What I always tell people is that a sharding propagation rule is simply a mathematical equality, saying that map(f, in_placement(x)) == out_placement(f(x)). Mathematical equalities are infinitely compositional: you can always add more true equalities to your system without compromising correctness. But there isn’t actually a way to do this.
  • Many sharding propagation rules are written assuming only Shard exists. Placement provides an is_shard method to test if something is a shard (as opposed to replicate/partial), and sharding propagation rules often assume that if this is True, you specifically have a standard, even Shard, as if it was the only sharding in the universe. This means that rules are often secretly buggy when custom Placements are added. StridedShard, in particular, naughtily advertises that it is_shard(), which means that we will happily allow for it to contract with a plain Shard, leading to this bug: https://github.com/pytorch/pytorch/issues/166598 Now, to be clear: often rules WILL work for arbitrary sharding subclasses; for example, if an input dimension is a batch dimension, it doesn’t matter how the data is sliced up, your operation is functorial over that dimension. Will Constable has been working on refactoring our sharding rules to distinguish the “it’s a batch dimension” situation from the “I specifically need an even sharding” and “I need the shardings on these two inputs to be exactly the same kind of sharding” situations.

With these two issues fixed, and a bit of careful design on what the overrideable API on Placement is for subclasses, I think we can have a very good extensibility story for shardings.

Draw high dimensional tensors as a matrix of matrices

I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors where it is important to make clear how to identify each of the dimensions in the representation. Common strategies I’ve seen people use in this situation include printing a giant list of 2D slices (what the default PyTorch printer will do) or flattening the tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy that I like that makes it easy to identify all the axes of the higher dimensional tensor: draw it as a matrix of matrices.

Here are some examples, including the easy up-to-2D cases for completeness.

0D: torch.arange(1).view() :

0

1D: torch.arange(2) :

0  1

2D: torch.arange(4).view(2, 2) :

0  1
2  3

3D: torch.arange(8).view(2, 2, 2) :

0  1    4  5
2  3    6  7

4D: torch.arange(16).view(2, 2, 2, 2) :

 0  1    4  5
 2  3    6  7

 8  9   12 13
10 11   14 15

5D: torch.arange(32).view(2, 2, 2, 2, 2):

 0  1    4  5  :  16 17   20 21
 2  3    6  7  :  18 19   22 23
               :
 8  9   12 13  :  24 25   28 29
10 11   14 15  :  26 27   30 31

The idea is that every time you add a new dimension, you alternate between stacking the lower dimension matrices horizontally and vertically. You always stack horizontally before stacking vertically, to follow the standard row-major convention for printing in the 2D case. Dimensions always proceed along the x and y axis, but the higher dimensions (smaller dim numbers) involve skipping over blocks. For example, a “row” on dim 3 in the 4D tensor is [0, 1] but the “row” on dim 1 is [0, 4] (we skip over to the next block.) The fractal nature of the construction means we can keep repeating the process for as many dimensions as we like.

In fact, for the special case when every size in the tensor is 2, the generated sequence of indices form a Morton curve. But I don’t call it that, since I couldn’t find a popular name for the variation of the Morton curve where the radix of each digit in the coordinate representation can vary.
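If you want to generate these pictures yourself, here is a throwaway sketch of the construction (my own code, nothing official): render along dim 0 and alternate between horizontal and vertical stacking, so that horizontal always comes first.

import torch

def render(t):
    if t.dim() == 0:
        return [f"{int(t):3d}"]
    if t.dim() == 1:
        return [" ".join(f"{int(v):3d}" for v in t)]
    blocks = [render(sub) for sub in t]                 # render along dim 0
    if t.dim() % 2 == 1:
        # odd dim count: stack the blocks side by side
        width = max(len(row) for b in blocks for row in b)
        blocks = [[row.ljust(width) for row in b] for b in blocks]
        return ["   ".join(rows) for rows in zip(*blocks)]
    # even dim count: stack the blocks top to bottom, blank line between
    out = []
    for i, b in enumerate(blocks):
        if i > 0 and len(b) > 1:
            out.append("")
        out.extend(b)
    return out

print("\n".join(render(torch.arange(16).view(2, 2, 2, 2))))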

Knowledge check. For the 4D tensor of size (2, 2, 2, 2) arranged in this way, draw the line(s) that would split the tensor into the pieces produced by torch.split(x, 1, dim), for each possible dimension dim = 0, 1, 2 and 3. Answer under the fold.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

dim=0

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=0)]
[tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15])]

     0  1    4  5
     2  3    6  7
   ----------------
     8  9   12 13
    10 11   14 15


dim=1

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=1)]
[tensor([ 0, 1, 2, 3, 8, 9, 10, 11]), tensor([ 4, 5, 6, 7, 12, 13, 14, 15])]

     0  1 |  4  5
     2  3 |  6  7
          |
     8  9 | 12 13
    10 11 | 14 15

dim=2

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=2)]
[tensor([ 0, 1, 4, 5, 8, 9, 12, 13]), tensor([ 2, 3, 6, 7, 10, 11, 14, 15])]

     0  1    4  5
   ------- -------
     2  3    6  7

     8  9   12 13
   ------- -------
    10 11   14 15

dim=3

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=3)]
[tensor([ 0, 2, 4, 6, 8, 10, 12, 14]), tensor([ 1, 3, 5, 7, 9, 11, 13, 15])]

     0 |  1    4 |  5
     2 |  3    6 |  7

     8 |  9   12 | 13
    10 | 11   14 | 15

So you want to control flow in PT2

With contributions from Richard Zou.

PT2’s dominant internal representation, FX graphs, do not directly support control flow (if statements, while loops): they only represent straight-line basic blocks. Most of our graph capture mechanisms are tracing based (fx.symbolic_trace, make_fx, Dynamo), which means that we expect to be able to linearize all conditionals we encounter into a straight line program. Sometimes, you want to work with code that has control flow while working with the compiler stack. There is no silver bullet; instead, there are a lot of different options with different tradeoffs.

Regional compilation

We have a perfectly good general purpose language that supports control flow: Python. To handle control flow, compile only regions/submodules of your program that have no internal control flow, and then string them together with standard Python control flow constructs. PT2 compiled regions are compositional with non-compiled regions, so “it just works.”
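A minimal sketch of what this looks like (the module structure is made up):

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pre = torch.nn.Linear(16, 16)
        self.big = torch.nn.Linear(16, 16)
        self.small = torch.nn.Linear(16, 16)

    def forward(self, x):
        x = self.pre(x)
        if x.sum() > 0:          # eager Python control flow between regions
            return self.big(x)
        return self.small(x)

m = Model()
m.pre = torch.compile(m.pre)     # each region gets its own graph
m.big = torch.compile(m.big)
m.small = torch.compile(m.small)
out = m(torch.randn(2, 16))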

Pros:

  • Simple: requires no major model changes
  • Universal: it always works (including data dependent flow, calling into third-party libraries, making an HTTP request, anything!)

Cons:

  • You will not get a full graph this way; you will only get graphs for each region. In particular, you will not be able to do truly global optimizations, nor will you be able to serialize a self-contained Python-less representation of the entire model
  • It can sometimes be inconvenient to structure your program so all the regions you want are compilable. Suppose you have this call graph between modules: A -> B -> C. C is compilable; A is compilable except for its call to B, which is what does the control flow. It’s easy to compile C, but you can’t directly compile A, as it has a B-shaped bit that can’t be compiled. What to do? If you split A so it is pipelined as A1, B, A2, you can then compile A1 and A2, but not B. Dynamo also supports “graph breaks” to automatically perform this split for you, in which case you just disable compilation on B, but graph break generated graphs can be difficult to reason about as the inputs to A2 are implicitly inferred.

Link: Reducing torch.compile cold start compilation time with regional compilation

Multiple graphs dispatched with guards

When the control flow is controlled by arguments that are known ahead of time (not data-dependent), you can also compile at the top level and get the flattened straight-line program for the particular branches taken in this case. Because Dynamo is a symbolic bytecode interpreter, it can automatically determine which inputs were used as part of control flow, and generate guards to validate that we would take the same paths again. If those values change, we will recompile the program at the new values. We dispatch between all the different unrollings of the program we have generated.
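A small sketch of the behavior you can expect (the exact recompile behavior depends on your specialization and dynamic-shape settings):

import torch

@torch.compile
def f(x, n: int):
    for _ in range(n):       # unrolled at trace time for the observed n
        x = x * 2
    return x

x = torch.randn(4)
f(x, 3)   # compiles a graph with the loop unrolled 3 times, guard: n == 3
f(x, 3)   # guard hit: reuses the compiled graph
f(x, 5)   # guard miss: recompiles with the loop unrolled 5 times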

Pros:

  • Simple: requires no major model changes
  • You get a full graph for a particular unrolling of loops / conditionals, so global optimizations are possible

Cons:

  • Doesn’t work with data-dependent shapes.
  • You will end up with a graph for every unrolling; for example, if you have a loop that ranges from 1 to 32, you will end up with 32 different graphs. This will increase compile time.

Black box via custom operator

An FX graph just calls operators. An operator can internally have whatever control flow it wants. So you can always black box a problematic region of your model into an operator and preserve compilation for everything else.
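Here is a sketch using the torch.library.custom_op API (the op name and the loop inside it are made up):

import torch

@torch.library.custom_op("mylib::iterative_halve", mutates_args=())
def iterative_halve(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()
    while out.abs().max() > 1.0:   # data-dependent control flow, invisible to PT2
        out = out / 2
    return out

@iterative_halve.register_fake
def _(x):
    # shape/dtype of the output, so the compiler can trace through the call
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def model(x):
    return iterative_halve(x) + 1

model(torch.randn(8))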

Pros:

  • You get a single, full graph that works for all possible branches

Cons:

  • A custom operator only supports inputs/outputs that fall inside our type system, which means you can only pass simple types like Tensor, int, bool (or pytree-able containers containing these things). There is some in progress work to relax this to allow more opaque types.
  • You have to explicitly declare all the inputs/outputs for the custom operator. This can be tiresome if the black boxed region represents a Module, since all the parameters also have to be directly passed in as well. The larger the region you black box, the bigger the arguments are.
  • You don’t actually get to see the inside of the custom operator from the outside graph, so no optimization over both inside and outside of the custom operator is possible. (Of course, you can always special case this operator in a pass on the outer graph.)
  • There are some bugs related to doing another torch.compile region inside of a custom operator, although these are workaroundable: https://github.com/pytorch/pytorch/issues/151328

Conditional operators / Unroll to max iterations

Do you really, really need a conditional? If you’re doing an if-branch, can you instead rewrite it so that you run both branches and use torch.where to select between the results? If you’re doing a while-loop, can you unroll it to the max number of iterations and rely on dynamic shapes to make it a no-op when you’re done and running extra iterations? Basically, this option is to rewrite your model so it doesn’t have Python-level control flow anymore (the conditional can be done either host-side or GPU-side).
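A sketch of the torch.where rewrite for the if-branch case:

import torch

def f_branchy(x):
    if x.sum() > 0:            # data-dependent Python branch: graph breaks
        return x * 2
    return x - 1

@torch.compile(fullgraph=True)
def f_select(x):
    pred = x.sum() > 0         # stays on device as a boolean tensor
    return torch.where(pred, x * 2, x - 1)

f_select(torch.randn(4))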

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model
  • For unrolling, if you are close to being CPU-dispatch bound, unrolling and running with zero size could push you over the brink (as zero size dispatches are still not free)
  • For conditional operators, unconditionally running both branches increases the compute you need to do, which can be bad if you are compute-bound.

Control flow HOP

torch has special structured control flow operators that avoid unrolling large loops or needing to execute both branches of a control flow statement. If you’re familiar with JAX, these are very similar to the JAX equivalents. They have specific constraints that allow them to be directly compilable by torch.compile. For example, torch.cond accepts two functions (a true_fn and a false_fn) for the two branches and requires that outputs of each function must have the same properties (e.g. shape, dtype).

So far, we have several “higher-order” operators (HOPs), including cond, while_loop and scan (whose semantics are given below).

These are relatively new, have been used in torch.export for inference, but have not been battle tested for training or performance.

The semantics of these control flow operators are as follows:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)

def while_loop(cond_fn, body_fn, carried_inputs):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val

def scan(combine_fn, init, xs, length=None):
    carry = init
    ys = []
    for x in xs:
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, torch.stack(ys)
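And a small usage sketch of cond under torch.compile (both branches must return outputs with matching shapes and dtypes):

import torch

def true_fn(x):
    return x * 2

def false_fn(x):
    return x - 1

@torch.compile(fullgraph=True)
def f(x):
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

f(torch.randn(4))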

Pros:

  • You get a single, full graph that works for all possible branches
  • You are able to optimize inside and outside of the control flow

Cons:

  • You have to rewrite your model.
  • The control flow HOPs are structured: they have specific constraints on the functions (true_fn, false_fn (cond) or body_fn (while_loop)) that can be passed to them. One such constraint is that these functions may not mutate any of their inputs. This may make rewrites difficult because you have to think about code in a “functional”, JAX-like way.
  • Still WIP and they have some quirks especially for training. For example, the backward pass of torch.scan currently requires re-computing the forward pass (instead of just saving intermediates from each iteration of scan).

CFG over FX graphs

If FX graphs give you basic blocks, you can use them as building blocks for a language that does support conditionals, stringing the basic blocks together into a control flow graph. In fact, Helion, a kernel DSL, does exactly this, as it is common to need to directly write data-dependent conditionals and loops when writing kernels (it otherwise uses all PyTorch API functions, similar to conventional FX graphs). To do this, you would need to write your own Python frontend that parses Python directly to generate the CFG. TorchScript also does this, but the TorchScript frontend is unmaintained and we don’t recommend using it (and it also doesn’t generate FX graphs by default.)

Pros:

  • You get a single graph that works for all possible branches
  • You are able to optimize inside and outside of control flow
  • In principle, you can write exactly the control flow you want

Cons:

  • You have to write the frontend, we don’t have one ready for you (TorchScript is not it, your princess is in another castle)
  • If your language looks too much like Python and too general purpose, prepare to get on the endless treadmill of feature requests for adding “just one more Python feature” (can we have lists? dataclasses? etc etc) in the frontend (it is more tractable for Helion, as it’s not a general purpose language.)

The Parallelism Mesh Zoo

When training large scale LLMs, there is a large assortment of parallelization strategies which you can employ to scale your training runs to work on more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you’ve got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)

tl;dr

  • DP, FSDP: ["dp"]
  • HSDP: ["dp_replicate", "dp_shard"]
  • DP+TP, DP+TP+SP: ["dp", "tp"]
  • DP+UlyssesSP: ["dp", "sp"] (verl)
  • DP+CP: ["dp", "cp"]
  • DP+CP+TP: ["dp", "cp", "tp"]
  • PP+DP+…: ["pp", "dp", ...] (torchtitan), ["dp", "pp", ...] (Megatron)
  • PP+DP+CP+TP+EP: ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"] (torchtitan)

Prologue: Why device mesh? Before we jump into the zoo, why do we have multi-dimensional meshes in the first place? One intuition is that the dimensions of the device mesh are a reflection of the physical constraints of networking between GPUs (there’s a reason why all of the scaling books talk extensively about how the networking for GPUs works; you can’t reason about what parallelization strategy you should use without knowing about this!) Let’s imagine you have 1024 NVIDIA GPUs. You don’t want to treat these 1024 GPUs as an undifferentiated blob of GPUs. Physically, these GPUs are grouped into nodes of eight which have much faster NVLink connections compared to cross-node communication, which is done on a slower InfiniBand connection. Intuitively, you will want to do something different depending on whether you’re doing intra-node communication or inter-node communication.

The device mesh imposes structure on this collection of GPUs. A mesh is typically specified as a tensor size (e.g., (128, 8)) as well as string axis names a la named tensors (e.g., ["dp", "tp"]), and is simply an N-D tensor over a range of GPU indices (typically [0, 1, 2, 3, ...] for GPUs, and a mostly ascending but occasionally permuted sequence for TPUs). We typically think of 2D and 3D tensors as grids and cubes, but I find it is more helpful (especially in higher dimensions) to think of the device mesh as imposing some self-similar (fractal) structure on the GPUs. In the simplest 2D mesh that accounts for intra versus inter node communication, GPUs are first organized into nodes on the inner-most dimension, and then the nodes are collected together in the outer-most dimension to form the cluster. (The self-similar nature of the nodes is important because it tells us how communication occurs across the cluster: to communicate over the outer-most mesh dimension, all the GPU 0s on each node talk to each other, all the GPU 1s, etc.) This is only the very simplest mesh we can create, however; with more complicated parallelization strategies we may impose extra levels of structure, e.g., we may organize nodes into pods of two and four, or we might further divide the eight GPUs of a single node. In other words, the mesh tells us which GPUs communicate with which other GPUs. This is important to know, because when I want to parallelize my model, I am making choices about how to shard tensors across my GPUs. The mesh tells me which GPUs have the other shards of my tensor; in other words, they are who I have to communicate with when I am doing a computation that requires information about the full tensor and cannot be done with the local shards only.
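In PyTorch, constructing this simple two-level mesh looks like the following (a sketch; it assumes a distributed job already launched with 1024 ranks, 8 GPUs per node):

from torch.distributed.device_mesh import init_device_mesh

# 128 nodes on the outer dim, 8 GPUs per node on the inner dim
mesh = init_device_mesh("cuda", (128, 8), mesh_dim_names=("dp", "tp"))
dp_group = mesh["dp"].get_group()  # cross-node peers with the same local GPU index
tp_group = mesh["tp"].get_group()  # the 8 GPUs on this rank's node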

In the zoo, when we talk about a parallelism strategy, we will talk about how it typically relates to other parallelization strategies in the model, and the device mesh will tell us if it is orthogonal to other parallelisms (a new dimension), multiplexed with another strategy (a reused dimension) or perhaps a completely different hierarchy of communication (multiple meshes in the same model that don’t factor into one another).

Without further ado, here is the zoo!

Data parallelism (DP). Data parallelism predates the concept of device meshes, since you don’t actually need any nontrivial mesh structure to do data parallelism: if you are only doing data parallel, you just shard your input on the batch axis for however many devices you have. This sharding propagates through forwards and backwards until you allreduce to compute the final global gradient for a parameter. If you did make a 1D device mesh (this is useful to think about, because most higher dimensional parallelisms will include some form of data parallelism), you’d probably name your mesh ["dp"], ["ddp"] or perhaps ["batch"].

Let’s talk briefly about how people tend to name device mesh axes. In the PyTorch world, it’s most common to name the axis after the parallelism that it is responsible for, so either “dp” or “ddp” (you really shouldn’t call it ddp, but the DataParallel taboo in PyTorch is very real!) The batch name is common in JAX, and is very natural there because when you annotate the sharding of your input, you need to say for each tensor dimension what mesh dim it is sharded over. So when you shard the batch dimension over the batch mesh dim, it looks just like you’re labeling the batch dimension of your tensor as batch, e.g., P("batch", None). (This situation doesn’t happen in PyTorch because shardings of a tensor are specified per device mesh dim, but that’s a story for another day!)

Fully-sharded data parallel (FSDP). This is best understood as an augmentation over DP where weights are also sharded over all GPUs and you just all-gather weights before performing operations (and reduce-scatter gradients in backwards). Because this all-gather is also among all devices, you don’t need another axis in your mesh, and your mesh might also be called ["dp"] in this case, even though you’re actually doing FSDP. Occasionally, you’ll see people name their mesh ["fsdp"] in this case.

Hybrid sharded data parallel (HSDP). HSDP is an extension of FSDP where you shard weights (FSDP) up to the point where you can’t actually do a giant all-gather/reduce-scatter over every GPU, and then replicate these shards to cover the rest of your cluster (DP). It’s also amenable to fault tolerance techniques that make the modeling assumption that it’s OK to lose samples of your batch if a replica fails (you won’t model this with device mesh though!). This is probably the first time you will encounter a 2D device mesh (indeed, the DeviceMesh tutorial in PyTorch specifically uses hybrid sharding as its motivating example), since HSDP doesn’t require any extra model changes on top of FSDP. There are a few common ways to name the mesh axes for HSDP. One way to think about it is that it is FSDP on the inner dimension and DP on the outer dimension, in which case you would say ["dp", "fsdp"]. Another way is to think about what happens to parameters at the various layers of the mesh: the inner dimension shards, while the outer dimension replicates, so you would say ["replicate", "shard"] or perhaps ["dp_replicate", "dp_shard"] to make it clear that you are still doing data parallelism across both of these device mesh dims (in particular, when you split your batches, you split on both the dp_replicate and dp_shard dims–although, to get the final gradients, you can do the reduction hierarchically by first doing a reduce-scatter on “dp_shard” and then doing an allreduce on “dp_replicate”).

Tensor parallelism (TP). Depending on who you ask, tensor parallelism is either about letting you reduce your effective batch size for training or moving you towards reducing the memory usage of activations in your model. In the “reduce effective batch size” framing, the idea behind TP is that you can only scale up DP until your cluster is as large as your batch size. From a modeling perspective, it can be undesirable to have a batch size that is too large, so you can’t just keep increasing your batch size to get more parallelism. Instead, TP allows us to get some extra scaling by sharding over the feature dimension of our matrix multiplies¹ (you can shard over either the columns or the rows of your weight matrix, so we will frequently specify if a TP Linear is column-wise or row-wise; in attention, column-wise linear effectively parallelizes the attention computation over attention heads). The communication needed to do TP is fairly exposed (unless you’re doing async tensor parallel), so you typically want to keep the communications for it within a single node. This leads to this classic 2D device mesh for DP+TP: ["dp", "tp"] (or, if you’re a JAXer, you might write ["batch", "model"], where model is used to indicate the inner feature dimension of the model weights being parallelized over.) When someone says 2D parallelism, they’re usually referring to this combo of parallelisms (although I do not recommend using this term–as you can see, it is obviously ambiguous!) Note that tp is the inner mesh dimension, since it benefits the most from the high bandwidth network between GPUs on a single node.

You don’t have to stop with DP+TP, however. If you’re using FSDP with tensor parallelism (remember, “dp” can mean FSDP!), intra-node TP doesn’t improve the amount of inter-node FSDP communication you have to do: however much TP you do, within one TP node you only have one slice of the model and have to talk to everyone else to get their slices. You could solve this by expanding TP to also cross nodes, but in practice mixed intra/inter-node collectives are a lot slower than pure inter-node collectives. This limits the scaling you can get from TP, and so if you’re still hitting limits on FSDP, it can still be useful to apply HSDP to avoid running collectives that are too large. In that case, you’d end up with a mesh like ["dp_replicate", "dp_shard", "tp"].

Sequence parallelism (SP). For this section, we specifically take the definition of sequence parallelism from the Ultrascale Playbook (as distinguished from context parallelism). Although we said that TP is the first step towards reducing the memory usage of activations², if you literally implement DP+TP based on my descriptions above, you will still end up with more memory spent on activations than you want, because there are still parts of the model around the FFN, like the LayerNorm, that need the full hidden dimension to compute mean and variance³. To reduce the memory usage in these segments, you need to shard on something else. So typically what you will see is that the model will alternate between TP (hidden dimension is sharded) and SP (sequence dimension is sharded). Consequently, if you look at the device mesh for a model using DP+TP+SP, it will typically still look like ["dp", "tp"], and instead the tp dimension is multiplexed to be used both for TP and SP. Because TP and SP never occur at the same time, you don’t need a separate dimension for them.

Ulysses sequence parallelism. Ulysses sequence parallelism from DeepSpeed Ulysses is another sequence parallelism strategy that is implemented by verl (because verl is forked so often, it shows up quite prominently if you are looking for examples of init_device_mesh on GitHub code search). It aims to alleviate memory pressure from extremely long sequences, so sequences are sharded on input, and only when attention needs to be computed is an alltoall issued to re-shard on the attention heads rather than the sequence (doing another alltoall to restore the sequence sharding after the attention is done). Importantly, this means it competes with TP for sharding on the attention heads, which is why you also see people use it to replace TP in MoE models, since it has much less communication than TP (at the cost of having to replicate the attention weights). In verl, you will just see a device mesh ["dp", "sp"] when you are using their FSDP backend (which is what supports Ulysses).

Context parallelism (CP). Context parallelism is another form of “sequence” parallelism. Like Ulysses sequence parallelism, sequences are sharded on input; the difference, however, is instead of using an alltoall to re-shard on attention heads, you just do a (distributed) attention on the entire context. You can do this the easy way by just using allgather to get the full context (as was done in llama4) or you can use a fancy kernel like ring attention, which carefully overlaps communication and computation when performing attention. A popular implementation of context parallelism lives in Megatron, which doesn’t directly use PyTorch’s native DeviceMesh abstraction but has an analogous HyperCommGrid. The mesh we see here will be something like ["dp", "cp"] or more commonly ["dp", "cp", "tp"]. Notice that we can have a dedicated mesh dim for CP: CP operates very similarly to SP outside of the attention calls (as it is just plain data parallelism when there is no cross-token dependency), but because it never shards on attention heads, it doesn’t compete with TP and can be used completely orthogonally to TP (TP shards hidden, CP shards sequence).

CP has a pretty interesting interaction with FSDP. Both DP and CP shard the input data (on batch and sequence respectively). It’s pretty common when you do FSDP to just shard over both “dp” (“dp_shard” in HSDP) and “cp”. In torchtitan, we create a flattened mesh dim “dp_shard_cp” specifically for FSDP sharding (a flattened mesh dim is what happens if you take your mesh and “forget” about some of the structure; e.g., if you were to do an all-gather, you just all-gather over all the flattened axes). In the HSDP world, “dp_cp” is still a useful concept because this is the combination of axes you want to all-reduce over to compute, e.g., the global average loss.

Pipeline parallelism (PP). Pipeline parallelism is kind of an ugly duckling and people tend to hate on it because you have to rewrite your models to introduce pipeline stages, and you can’t really use things like DTensor with it (unless you do really strange things like how the GSPMD paper “supports” pipeline parallelism–the general consensus is automatic parallelism does not like PP). PP still goes in the device mesh, because it affects how you are organizing your GPUs, but, for example, torchtitan solely uses it to set up PGs for doing the point-to-point communications. I’ve seen both ["dp", "pp", ...] and ["pp", "dp", ...] for meshes with PP, but the order probably doesn’t make too much of a difference as you are likely solidly inter-node at this point. Pipeline parallelism bandwidth use is very low, and latency can be covered up as you can immediately start processing the next batch after triggering an asynchronous send of the previous batch.

Expert parallelism (EP). EP is its own kettle of fish. Expert parallelism only applies over the expert computation of the model, but within this region, we are not sharding parameters as FSDP conventionally sees it: we will commonly have the entire expert’s weights on our node. torchtitan’s WIP expert parallelism implementation, when it has ALL parallelisms on, would look like ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], where dp_shard has been split into two mesh dimensions (DP shard modulo EP, and DP shard in EP). dp_shard_mod_ep is conventionally one, but when it is not it represents further FSDP-style sharding of expert weights inside of the expert region (there’s some complication here if you have shared experts along-side your EP-sharded experts). But then dp_shard_in_ep, cp and optionally tp are combined together to give you the expert parallel dimension. It’s actually more intuitive to imagine that you have two distinct meshes: ["pp", "dp_replicate", "dp_shard", "cp", "tp"] and ["pp", "dp_shard_mod_ep", "ep", "tp"]. The keen-eyed may also notice that there is no intrinsic reason the tp mesh size inside and outside of the expert parallel region has to be the same, but this is not easily done if you have to have a single global device mesh for everything. In fact, there is a WIP PR to have two meshes, one for inside the expert region and one for outside: https://github.com/pytorch/torchtitan/pull/1660

Conclusion. The general concept behind mesh parallelism is that you can compose parallelization strategies without too much fuss. Indeed, the point of using, e.g., TP to improve scaling is precisely that it lets you cover your device space without having to expand DP beyond the batch size you want. However, as you can see from these concrete examples, it’s not always quite as simple as just stacking all of the parallelisms one on top of another. In the end, all the device mesh is doing is creating PGs behind groups of devices as defined by the mesh, so if you want some weird setup where you’re swapping between two device meshes, PyTorch’s general philosophy has been to say, have fun!

Thanks to Horace He, Tianyu Liu and Natalia Gimelshein for helping fact check this post. Any remaining errors are mine!


  1. One more subtlety I want to point out: while we tend to think of TP as sharding the feature dimension of parameters, when we “propagate” this sharding through the network, other intermediate tensors end up getting sharded on the TP dimension as well. In particular, in a transformer block, you will typically have a column-wise linear followed by a row-wise linear, and the intermediate activation will be temporarily sharded on the TP dimension before the row-wise linear runs. ↩︎

  2. I am very carefully using “activation memory” here and not total memory, because total memory usage (what you actually care about) is also a function of peak memory usage, which is subject to transient peaks such as when FSDP does an all-gather to collect parameters. In fact, even without SP, TP will improve your peak memory usage, because unlike FSDP, it’s not necessary to all-gather the full weight matrix to actually perform the matrix multiply. TP’s peak memory usage occurs when it all-gathers activations. ↩︎

  3. You will get a little improvement between the column-wise and row-wise linear, since the activations there are sharded. You can turn this into a big improvement by using selective activation checkpointing and forcing recomputation of activations that aren’t sharded! (Plain activation checkpointing tends not to work so well because of the all-gather of the activations.) ↩︎

You could have invented CuTe hierarchical layout (but maybe not the rest of it?)

CuTe is a C++ library that aims to make dealing with complicated indexing easier. A key part of how it does this is by defining a Layout type, which specifies how to map from logical coordinates to physical locations (CuTe likes to say layouts are “functions from integers to integers.”) In fact, CuTe layouts are a generalization of PyTorch strides, which say you always do this mapping by multiplying each coordinate with its respective stride and summing them together, e.g., i0 * s0 + i1 * s1 + .... Although NVIDIA’s docs don’t spell it out, CuTe’s generalization here is actually very natural, and in this blog post I’d like to explain how you could have invented it (on a good day).

First, a brief recap about strides. PyTorch views allow us to reinterpret the physical layout of a tensor in different ways, changing how we map logical coordinates into physical locations. For example, consider this 2-D tensor:

>>> torch.arange(4).view(2, 2)
tensor([[0, 1],
        [2, 3]])
>>> torch.arange(4).view(2, 2).stride()
(2, 1)

The physical memory reads 0, 1, 2, 3, and if I want to know what the value at coordinate (0, 1) is (row 0, col 1), I compute 0 * 2 + 1 * 1, which tells me I should read out the value at index 1 in physical memory. If I change the strides, I can change the order I read out the physical locations. For example, if I transpose I have:

>>> torch.arange(4).view(2, 2).T
tensor([[0, 2],
        [1, 3]])
>>> torch.arange(4).view(2, 2).T.stride()
(1, 2)

The physical memory hasn’t changed, but now when we read out coordinate (0, 1), we compute 0 * 1 + 1 * 2, which tells me I should read the value at index 2 (which is indeed what I see at this coordinate!)

PyTorch also allows us to “flatten” dimensions of a tensor, treating them as a 1D tensor. Intuitively, a 2-D tensor flattened into a 1-D one involves just concatenating all the rows together into one line:

>>> torch.arange(4).view(2, 2).view(-1)
tensor([0, 1, 2, 3])

We should be able to do this for the transpose too, getting tensor([0, 2, 1, 3]), but instead, this is what you get:

>>> torch.arange(4).view(2, 2).T.view(-1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

The dreaded “use reshape instead” error! The error is unavoidable under PyTorch striding: there is no stride we can select that will cause us to read the elements in this order (0, 2, 1, 3); after all, i0 * s0 is a pretty simple equation, we can’t simultaneously have 1 * s0 == 2 and 2 * s0 == 1.

Upon learning this, an understandable reaction is to just shrug, assume that this is impossible to fix, and move on with your life. But today, you are especially annoyed by this problem, because you were only trying to flatten N batch dimensions into a single batch dimension so that you could pass it through a function that only works with one batch dimension, with the plan of unflattening it when you’re done. It doesn’t matter that this particular layout is inexpressible with strides; you aren’t going to rely on the layout in any nontrivial way, you just care that you can flatten and then unflatten back to the original layout.

Imagine we’re dealing with a tensor of size (2, 2, 2) where the strides for dim 0 and dim 1 were transposed as (2, 4, 1). It should be OK to flatten this into a tensor (4, 2) and then unflatten it back to (2, 2, 2). Intuitively, I’d like to “remember” what the original sizes and strides are, so that I can go back to them. Here’s an idea: let’s just store the original size/stride as a nested entry in our size tuple. So instead of the size (4, 2), we have ((2, 2), 2); and now analogously the stride can simply be ((2, 4), 1). When I write (2, 2) as the “size” of a dimension, I really just mean the product 4, but there is some internal structure that affects how I should index its inside, namely, the strides (2, 4). If I ask for the row at index 2, I first have to translate this 1D coordinate into a 2D coordinate (1, 0), and then apply the strides to it like before.
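Here is a minimal sketch of that idea in plain Python (not CuTe): remember the (2, 2) sub-size and sub-stride of the flattened dimension, translate the 1D row index back into a 2D coordinate, and then apply the strides as usual.

import torch

base = torch.arange(8)
t = base.view(2, 2, 2).transpose(0, 1)   # size (2, 2, 2), stride (2, 4, 1)

# "Flattened" view: size ((2, 2), 2), stride ((2, 4), 1)
sub_size, sub_stride = (2, 2), (2, 4)

def flat_get(row, col):
    # Unflatten the 1D row index into a 2D coordinate (lexicographically, as in
    # PyTorch), then apply the remembered strides plus the column stride of 1.
    i0, i1 = divmod(row, sub_size[1])
    return base[i0 * sub_stride[0] + i1 * sub_stride[1] + col * 1]

# Agrees with materializing the flatten via reshape (which forces a copy):
expected = t.reshape(4, 2)
assert all(flat_get(r, c) == expected[r, c] for r in range(4) for c in range(2))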

Well, it turns out, this is exactly how CuTe layouts work! In CuTe, sizes/strides are hierarchical: a size is actually a tree of ints, where the hierarchy denotes internal structure of a dimension that you can address linearly (in fact, everything by default can be addressed in a 1-D linear way, even if it’s an N-D object.) The documentation of Layout does say this… but I actually suffered a lot extracting the high-level intuition of this blog post, because CuTe uses co-lexicographic ordering when linearizing (it iterates over coordinates (0,0), (1,0), (2,0), etc. rather than in the more normal lexicographic order (0,0), (0,1), (0,2)). This leads to some truly deranged example code where they print a 2D matrix in conventional lexicographic ordering, and then turn around and say, “But wait, if I have the layout take care of translating the 1D coordinate into an ND coordinate, it is colexicographic!!”:

> print2D(s2xh4)
  0    2    1    3
  4    6    5    7
# sure, why not?

> print1D(s2xh4)
  0    4    2    6    1    5    3    7
# wtf???

In any case, if you want to engage with the documentation, s2xh4 is the important example to pay attention to for understanding the nested semantics. However, note the example is smeared across like five sections and also you need to know about the co-lexicographic thing to understand why the examples print the way they do.
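To make the co-lexicographic convention concrete, here is a short Python sketch, assuming s2xh4 is the shape (2, (2, 2)) with stride (4, (2, 1)) layout from the docs: the 1D coordinate is unflattened with the leftmost index varying fastest, and only then are the strides applied.

def offset(i, j, k):
    # strides (4, (2, 1)) applied to the coordinate (i, (j, k))
    return i * 4 + j * 2 + k * 1

def colex(n):
    # co-lexicographic unflattening of shape (2, (2, 2)): i varies fastest, then j, then k
    i, n = n % 2, n // 2
    j, k = n % 2, n // 2
    return i, j, k

print([offset(*colex(n)) for n in range(8)])
# [0, 4, 2, 6, 1, 5, 3, 7] -- the print1D(s2xh4) output above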

State of torch.compile for training (August 2025)

The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. Nothing in here is something you couldn’t already have learned elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large-scale training runs.

First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs for both inference and training workloads. Speedups from 1.5-2x compared to eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).

What is torch.compile’s functionality?

The headline functionality of torch.compile is a decorator you can attach to a function to compile it:

@torch.compile()
def f(x, y):
    ...

Here are some non-functional properties of compile which are important to know:

  • Just-in-time compilation. We don’t actually compile the function until it is called for the first time, and execution blocks until compilation completes. There is both local and remote caching to skip compilation cost when you rerun the model. (Ahead-of-time compilation is possible for inference with AOTInductor, and is being worked on for training.)
  • Compositional with Eager. PyTorch’s original success comes from the extreme hackability of eager mode, and torch.compile seeks to preserve this. The function can be as big or as small a part of your training loop as you like; compiled functions compose with autograd, DDP, FSDP and other PyTorch subsystems. (This composition is sometimes imperfect, e.g., in the case of double backwards (not supported), tensor subclasses (requires specific support from the subclass), autograd (differentiating with respect to intermediates returned from a compiled region does not work).) If compilation doesn’t work on a region, you can disable it entirely with torch.compiler.disable() and fall back to eager.
  • Gradient updates are delayed to the end of compiled regions. This arises because PyTorch eager autograd does not support streaming gradients incrementally from a large backward node. (This can be solved by using compiled autograd, but this requires that the entirety of your backwards be compileable.)
  • Graphs may be recompiled. We aggressively specialize on all non-Tensor arguments/globals used in the function to ensure we always generate straight-line computation graphs with no control flow. If those arguments/globals change we will recompile the graph. (Recompilations can be banned with torch._dynamo.config.error_on_recompile = True.)
  • Static by default, recompile to dynamic shapes. We aggressively specialize all sizes to static. However, if we discover that a size varies over time, on the first recompile we will attempt to generate a single compiled region that handles dynamic shapes. We are not guaranteed to be able to compile a model with dynamic shapes. (You can use mark_dynamic to force an input shape to be dynamic, and you can use mark_unbacked to error if we specialize; see the sketch after this list.)
  • Graph breaks transparently bypass non-capturable code. By default, if the compiler encounters a line of code that it is not able to handle, it will trigger a graph break, disabling compilation for that line of code, but still attempting to compile regions before and after it. (This behavior can be banned with fullgraph=True.)
  • Function calls are inlined and loops are unrolled by default. If you have many copies of a Transformer block in your model, your compile time will scale with the number of Transformer blocks. (You can reduce compile time by doing “regional compilation” (https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), where you only compile the Transformer block instead of compiling the entire model.)
  • NOT bitwise equivalent with eager PyTorch. The biggest divergence with eager PyTorch is that when float16/bfloat16 operations are fused together, we do not insert redundant down/up-conversions. (This can be disabled with torch._inductor.config.emulate_precision_casts = True; you can also rewrite eager code to perform operations in higher precision with the understanding that torch.compile will optimize it. XLA has a similar config, xla_allow_excess_precision, which JAX enables by default.) However, we may also make decisions to swap out, e.g., matmul implementations, and there may also be slight divergences that arise from differences in reduction ordering that are unavoidable when compilation occurs. We support ablating the graph capture frontend separately from the compiler backend to help diagnose these kinds of problems.
  • Distributed collectives and DTensor can be compiled, but are unoptimized by default. We are able to capture c10d collectives and also programs that handle DTensors, but we don’t apply optimizations to collectives by default. (There are experimental optimizations that can be enabled, but this is active work in progress.) We generally do not expect to be able to trace through highly optimized distributed framework code.
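A minimal sketch of several of the knobs above in one place (assuming a recent PyTorch; defaults and private config paths may vary by version):

import torch

torch._dynamo.config.error_on_recompile = True          # ban silent recompiles
torch._inductor.config.emulate_precision_casts = True   # closer numerics to eager

@torch.compile(fullgraph=True)   # error on graph breaks instead of splitting the graph
def f(x, y):
    return (x @ y).relu()

x, y = torch.randn(8, 16), torch.randn(16, 4)
torch._dynamo.mark_dynamic(x, 0)   # treat the batch dimension as dynamic up front
out = f(x, y)                      # just-in-time compiled on this first call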

State of advanced parallelism

For large-scale training runs, torch.compile faces stiff competition from (1) PyTorch-native distributed frameworks which embrace eager mode and implement all optimizations by hand (e.g., megatron), (2) custom “compiler” stacks which reuse our tracing mechanisms (e.g., symbolic_trace and make_fx) but implement their desired passes by hand, and (3) JAX, which has always been XLA-first and is years ahead in compile-driven parallelism techniques.

Here is where we currently are for advanced parallelism (with an emphasis on comparing with JAX):

  • DTensor, a “global tensor” abstraction for representing sharded tensors. DTensor is a tensor subclass which allows us to represent tensors which are sharded over an SPMD device mesh. The shape of a DTensor reflects the global shape of the original full tensor, but it only stores locally a shard of the data according to the placement. Here are some important details:

    • Shard placements. Unlike JAX placements, DTensor placements are “device mesh” oriented; that is to say, you conventionally specify one placement per device mesh dimension, and Shard(i) indicates that the ith dimension of a tensor is sharded along that mesh dimension. This is the opposite of JAX, which is “tensor” oriented. For example, given a 2-D mesh [“dp”, “tp”], a tensor with the DTensor placement [Replicate, Shard(0)] (or {“dp”: Replicate, “tp”: Shard(0)} with named device mesh axes) would correspond to a JAX placement of P(“tp”, None). The reason for this is that DTensor supports a Partial placement, which indicates that an axis on the device mesh has a pending reduction. Partial shows up ubiquitously from matrix multiplies, and it isn’t associated with any particular tensor axis, making it more convenient to represent in a device-mesh oriented formulation. The tradeoff is that device-mesh oriented placements don’t naively support specifying a sharding order: e.g., suppose I want to shard a 1-D tensor on tp and then dp; in JAX I’d represent this as P((“tp”, “dp”),), but this order cannot be disambiguated from [Shard(0), Shard(0)], and in fact DTensor always forces left-to-right sharding. There is currently a proposal to extend our sharding specification to support ordering to bring us to parity with JAX expressiveness, but it is not yet implemented. (For what device-mesh oriented placements look like in code, see the short sketch after this list.)
    • Autograd. DTensor is directly differentiable; we run autograd on programs that have DTensors (as opposed to desugaring a DTensor program to one with regular Tensors and differentiating that). This allows the sharding strategy of a primal and its corresponding tangent to diverge. This is at parity with JAX.
    • Python subclass of Tensor. Unlike JAX, DTensor is a separate subclass from Tensor. However, Tensor and DTensor interoperate fine; a Tensor can simply be thought of as a DTensor that is replicated on all dimensions. DTensor is implemented in Python, which makes it easy to modify and debug but imposes quite a bit of overhead (for example, FSDP2 does not directly accumulate gradients into DTensor, because with thousands of parameters, performing detach and add operations on DTensor is a bottleneck). Still, despite this overhead, DTensor was designed for good eager performance, and it extensively caches the results of sharding propagation so that in the fastpath, it only needs to look up which redistribute it should perform and then directly dispatch to the local eager operation. However, this caching strategy means that overhead can be quite high for workloads with dynamic shapes, as the cache requires exact matches of all input shapes.
    • Compilation. DTensor is compilable by torch.compile, and doing so will desugar it into its underlying collectives and eliminate any eager mode DTensor overhead (even if you do not perform any other optimizations). However, DTensor with dynamic shapes in compile is not well supported, see http://github.com/pytorch/pytorch/issues/159635 (we don’t think this is currently on the critical path for any critical use cases, so a relatively junior engineer has been chipping away at it).
    • Greedy propagation. Because DTensor must work in eager mode, it only implements greedy shard propagation, where for every eager operation we greedily pick whatever output shard minimizes the collective costs of an operation. It is work in progress to support backward propagation of sharding with the assistance of a compiler-like framework.
    • Operator coverage. DTensor requires sharding propagation rules to work for operations. If a sharding propagation rule is not implemented, DTensor will fail rather than trigger an inefficient allgather to run the operator under replication. We don’t currently have full coverage of all operators, but important operators for transformer models like llama3 are all covered (sharding rules are defined here). You can write custom shardings for user defined operators.
    • Jagged sharding. We do not support a “jagged sharding” concept which would be necessary for expert parallelism with imbalanced routing. However, we believe that our existing sharding rules could largely be reused to support such an idea. As dynamism would only be exposed in the local tensor for the jagged shard, jagged shards don’t suffer from the dynamic shapes problems mentioned in the compilation section.
    • Ecosystem. We are committed to DTensor as the standard representation for sharded tensors, and DTensor is integrated with checkpointing, FSDP2, SimpleFSDP, AutoParallel, torchtitan, among others.
  • Functional collectives. If you don’t like DTensor, we also support “functional collectives”, which are non-mutating versions of collective operations that can be used to manually implement SPMD operations in a compiler-friendly way without needing DTensor. (In fact, if you use traditional collective APIs and compile them, we will silently translate them into functional collectives for compiler passes.) When compiled, functional collectives don’t necessarily force allocation of the output buffer as they can be re-inplaced. Importantly, functional collectives currently do NOT support autograd, see https://discuss.pytorch.org/t/supporting-autograd-for-collectives/219430

  • Graph capture. There are two particularly popular graph capture mechanisms which people have used to perform distributed optimizations separate from model code. All graph capture mechanisms produce FX graphs, a simple Python basic-block IR representation with no control flow that is entirely unopinionated about what actual operator set can occur in the graph.

    • Symbolic_trace. This was the original graph capture mechanism and is quite popular, despite its limitations. It is implemented entirely with Python operator overloading and will give you exactly whatever operations are overloadable in the graph. We consider this largely a legacy pipeline as you are unable to trace code involving conditionals on shapes and you end up with a graph that has no useful metadata about the shapes/dtypes of intermediate values. For example, PiPPY, a legacy stack for performing pipeline parallelism, was built on top of symbolic_trace graph capture.
    • make_fx/torch.export. This graph capture mechanism works by actually sending (fake) tensors through your program and recording ATen operators. There are a number of different variants: e.g., whether it is a Python tracing approach ala JAX jit or uses sophisticated bytecode analysis ala Dynamo; similarly, there are various levels of IR you can extract (pre-dispatch, post-dispatch; also, operators can be decomposed or kept as single units). Our compiler parallelism efforts are built on top of this capture mechanism, but there is nothing stopping you per se from writing your own graph pass on top of this IR. In practice, this can be difficult without PyTorch expertise, because (1) integrating a traced graph into PyTorch’s autograd system so it can interoperate with other code is quite complicated to do in full generality, and (2) the exact operator sets you get at various phases of compilation are undocumented, in practice very tied to the Inductor lowering stack, and poorly documented with respect to preventing operators from getting decomposed before your pass gets to them.
  • Not an SPMD compiler by default. torch.compile does not assume the program being compiled is SPMD by default, which means it will not do things like drop unused collectives (you can change this behavior with a config flag). Additionally, the default mode of use for torch.compile is to compile in parallel on all nodes, which means care has to be taken to ensure that every instance of the compiler compiles identically (only one rank recompiling, or compilers making different decisions, can lead to NCCL timeouts). We ultimately think that we should compile a program once and send it to all nodes, but as this is not currently implemented, the general approach people have taken to solve this problem is to either (1) eliminate all sources of divergent behavior across ranks, e.g., don’t allow the compiler to look at the actual size for dynamic inputs when making compilation decisions, or (2) introduce extra collectives into the compiler to communicate decisions that must be made consistently across all ranks.
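As a concrete illustration of the device-mesh oriented placements described above, here is a minimal DTensor sketch (assuming it is launched under torchrun with 8 ranks; module paths differ slightly across PyTorch versions):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

# 2-D device mesh ["dp", "tp"] over 8 ranks (2 x 4)
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

w = torch.randn(1024, 1024)
# One placement per mesh dimension: replicate over "dp", shard tensor dim 0 over "tp"
dw = distribute_tensor(w, mesh, placements=[Replicate(), Shard(0)])

print(dw.shape)             # global shape: torch.Size([1024, 1024])
print(dw.to_local().shape)  # local shard:  torch.Size([256, 1024]) on each rank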

Our vision for the future of advanced parallelism, spearheaded by the in-progress SimpleFSDP and AutoParallel, is that users should write single-node programs that express mathematically what they want to do. These are then transformed into efficient distributed programs in two steps: (1) first, collectives are inserted into the graph in a naive way (i.e., simply to express what the sharding of all intermediates should be), and (2) the collectives are optimized to handle scheduling concerns such as pre-fetching and bucketing. AutoParallel sets a GSPMD-style goal of automatically determining a good enough sharding for a program (it should be able to rediscover data parallel, tensor parallel, even expert parallel!), while SimpleFSDP sets a smaller goal of just inserting collectives in the pattern that FSDP would mandate, and then writing FSDP-specific optimization passes to recover FSDP2’s performance. It is very common to write domain-specific optimizations; for example, async tensor parallelism is also implemented as a pass that detects TP patterns and rewrites them to async TP operations. Unlike JAX, which started with a very generic solver and has needed to add more manual escape hatches over time, PyTorch started by writing all of the distributed patterns exactly by hand, and we are only recently adding more automatic mechanisms as an alternative to doing everything by hand.

State of optimization

torch.compile performs many optimizations, but here are some particularly important ones to know about:

  • Inductor. Inductor is our backend for torch.compile that generates Triton kernels for PyTorch programs. It has very good coverage of PyTorch’s operator set and can do fusions of pointwise and reduction operations, including in the patterns that typically occur for backwards. It is also able to fuse pointwise operations into matmuls and autotune different matmul backends (including cuBLAS, CUTLASS and Triton) to select the best one for any given size. When people talk about torch.compile speeding up their programs, they are conventionally talking about Inductor; however, you don’t have to use torch.compile with Inductor; for example, you could run with AOTAutograd only and skip Inductor compilation.
  • CUDA graphs. Inductor has built-in support for CUDA graphing models. Compared to applying CUDA graphs manually, we can give better soundness guarantees (e.g., catching a forgotten input buffer copy, or CPU compute inside the CUDA graph region). torch.compile CUDA graphs is typically used with Inductor, but we also offer an eager-only cudagraphs integration (which is less well exercised).
  • Automatic activation checkpointing. With torch.compile, we can globally optimize the memory-compute tradeoff, much better than the activation checkpointing APIs that eager PyTorch supports (which require the user to manually specify what they want checkpointed or not). However, some folks have reported that it can be quite miserable tuning the hyperparameter for AC; we have also found bugs in it.
  • FP8 optimizations. One big success story for traditional compilation was adding support for a custom FP8 flavor: with torch.compile, no manual kernels had to be written for the variant. This has since been upstreamed to torchao.
  • Flex attention. Flex attention usage continues to grow, with 632 downstream repo users in OSS (vs 125 in Jan ‘25). It has been used to enable chunked attention, document masking and context parallelism in llama family models. It is a really good research tool, although sometimes people complain about slight numerical differences. (A minimal example appears after this list.)
  • Helion. Helion is an actively developed project aiming to go beta in October this year which offers a higher level interface for programming Triton kernels that looks just like writing PyTorch eager code. It relies heavily on autotuning to explore the space of possible structural choices of kernels to find the best one. It is not production ready but it is worth knowing that it is coming soon.
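As a flavor of the FlexAttention API mentioned above, here is a minimal causal-masking sketch (assuming a recent PyTorch with flex_attention available and a CUDA device):

import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # Mask out future positions by sending their scores to -inf
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# flex_attention is typically wrapped in torch.compile so the score_mod is fused
# into a single generated attention kernel.
compiled = torch.compile(flex_attention)
out = compiled(q, k, v, score_mod=causal)   # shape (2, 8, 128, 64)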

State of compile time

torch.compile is a just-in-time compiler and as such, in its default configuration, compilation will occur on your GPU cluster (preventing you from using the GPUs to do other useful work!) In general, most pathological compile times arise from repeated recompilation (often due to dynamic shapes, but sometimes not). In Transformer models, compile time can also be improved by only compiling the Transformer block (which then only needs to be compiled once, instead of being compiled N times, once for each Transformer block in the model).
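For example, regional compilation looks roughly like this, assuming a hypothetical model built from identical blocks; because every block shares the same code, the compiled artifact can be reused across blocks:

import torch
import torch.nn as nn

class Block(nn.Module):   # hypothetical stand-in for a Transformer block
    def __init__(self, d):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.ff(x)

model = nn.Sequential(*[Block(64) for _ in range(12)])

# Compile each block in place instead of wrapping the whole model in torch.compile;
# compile time then scales with one block rather than twelve.
for block in model:
    block.compile()

out = model(torch.randn(8, 64))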

We don’t think caching is an ideal long-term solution for large scale training runs, and we have been working on precompile to solve the gap here. Precompile simply means making compilation an ahead-of-time process that produces a binary you can directly run from your training script to get the compiled model. The compilation products are built on top of our ABI-stable interface (developed for AOTInductor), which allows the same binaries to target multiple PyTorch versions, even though PyTorch the library does not offer ABI compatibility from version to version.

How do I get started?

The most typical pattern we see for people who want to make use of torch.compile for large-scale training is to fork torchtitan and use this codebase as the basis for your training stack. torchtitan showcases PyTorch native functionality, including torch.compile–in effect, it shows you how to use features in PyTorch together in a way that lets you do large-scale training. From there, swap out the components you are opinionated about and keep the things you don’t care about.

Vibe coding case study: ScubaDuck

A lot of strong engineers that I know haven’t really taken a serious look at AI coding; they’ve used LLMs to ask questions or write simple scripts and appreciate that it is a useful tool, but haven’t actually tried building a nontrivial application entirely from scratch in vibe coding style (here, I use the term in its original meaning: when you do AI coding without carefully reviewing the output). This is understandable: if you’re not working on a green field project, there aren’t that many opportunities to write code in this style–standard practice for established projects is that someone else needs to review all of the code you write: this is a bad match for vibe coding! So in this post, I want to give a concrete case study of a nontrivial system that was entirely vibe coded (ScubaDuck), to argue the following claims:

  1. AI coding can be done on a manager’s schedule: you don’t need continuous blocks of coding time and context-switching is considerably less harmful. ScubaDuck was implemented in three days of part time work, where all of the work happened when the baby was napping.
  2. AI coding substantially lowers the cost of doing projects in tech stacks you are less familiar with. ScubaDuck is mostly JavaScript UI code, which is not something I write on a day-to-day basis.
  3. AI coding is an unlock for “sidequests”: support software that’s ancillary to your main task that is nice to have, but not essential. If previously you would have decided the cost outweighed the benefit, AI coding reduces the cost, so you should redo those calculations.
  4. Vibe coding works and can produce working software. ScubaDuck is an existence proof that vibe coding is a viable strategy for generating JavaScript UI code (NB: I don’t claim vibe coding will work for all domains, nor do I claim this is the only domain for which it works. Hopefully you can also build some intuition for where it is more or less likely to work). You will not one-shot it (ScubaDuck was 150 prompts in the end), but if you are prompting the LLM to also generate tests, you can reliably fix issues without causing regressions to existing code.
  5. Vibe coding is good for situations where buggy software is low impact; be on the lookout for ways to engineer this sort of situation. ScubaDuck is a read-only interface, where the only downside to being buggy is you can’t issue the queries you want to issue.

Update: You can see all of my prompts and the resulting agent trajectories at scubaduck-prompts.

What is ScubaDuck?

ScubaDuck is a discount implementation of Meta’s internal Scuba realtime database system. You can read more about what exactly this is on GitHub, but it’s not so important for the purposes of this post: the key details you need to know about ScubaDuck are that it consists of a Python server that exposes an API to perform queries against a DuckDB database, and an HTML and JavaScript frontend application which implements the forms for building these queries and the rendering of the output data. Both the forms and output data rendering have nontrivial JavaScript enhancements: some form inputs are chip inputs and support autocomplete, and the time series view is an SVG chart. All of these components were coded from scratch, so the project has no third-party JavaScript dependencies.
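To give a sense of the shape of the server side, here is a hypothetical sketch (not the actual ScubaDuck code) of a Flask endpoint that runs a query against DuckDB and hands rows back to the frontend as JSON:

import duckdb
from flask import Flask, jsonify, request

app = Flask(__name__)
con = duckdb.connect("events.duckdb")   # hypothetical database file

@app.route("/api/query", methods=["POST"])
def query():
    payload = request.get_json()
    limit = int(payload.get("limit", 100))
    # The real server builds its SQL from all of the form inputs (filters,
    # group by, aggregates); this just shows the overall flow.
    rows = con.execute("SELECT * FROM events LIMIT ?", [limit]).fetchall()
    return jsonify(rows)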

So on the one hand, this project is pretty simple. There are no stringent performance or uptime requirements, and it’s a pretty standard server-client program that the LLM has seen millions of times before (this is good!) On the other hand, the exact behavior of the frontend UI is quite intricate and would be very difficult to one-shot in a single prompt. Indeed, as I was coding and testing the application, I frequently ran into situations that I didn’t anticipate in my original specification, and that I had to ask Codex to refine. Another way to put it is that ScubaDuck has a relatively simple functional specification (although this too was not one-shot), but I did a lot of polishing of small behaviors so that the interface behaved in the way that I expected Scuba to behave. Here, it was helpful that I had a very clear idea of what I wanted (since I’ve used Scuba quite a lot at work).

Going into ScubaDuck, I had a pretty good sense that this project should be a good fit for LLMs. HTML, JavaScript and Python are all extremely high resource languages, and I’d heard lots of people raving about how good LLMs were at transforming wireframes and mockups into fully functional websites. It is also fully self contained and straightforward-ish to test (only “ish” because you do have to use something like Playwright to actually test the frontend UI, which honestly is a slog. But fortunately, the LLM can write the tests for you!) One design decision I made, which I didn’t originally anticipate but worked out in the end, was the decision to not use any third-party JavaScript libraries. This was by accident: Python has no native way of bundling third-party JavaScript, but I wanted the tool to work offline. I wasn’t sure if you could vibe code an SVG charting library from scratch, but apparently you can and it’s quite easy!

Agent setup

ScubaDuck was implemented with OpenAI Codex in the cloud (not the CLI tool). Codex’s cloud offering requires you to initialize a hermetic environment which the coding agent can execute commands in. It’s pretty well known now that AI coding agents work much better if they are able to run the code they write and see if it worked or not, so this is quite an important part of the process. Unfortunately, getting this set up took some time-consuming trial and error. I had a fairly detailed initial prompt, and what I would do was submit it to Codex, watch it fail, read over the trajectory (the agent logs) to see what happened (Codex wanted to use npm! Codex couldn’t download something from the internet! Codex tried to use a package that wasn’t available!), and then fix whatever environment misconfiguration had caused it to fail, or edit AGENTS.md to instruct it not to do some behavior. According to my history, the first day of the project was spent unsuccessfully trying to get the project set up, and my first successful Codex PR only happened on May 19.

At the end of setup, I had the following:

  1. A pyproject.toml with exactly the dependencies I wanted to be used (duckdb, flask and python-dateutil), a lockfile for it (since I was using uv) and my preferred configuration for various tools (pytest, ruff). I’m a big fan of pytest-xdist for vibe coded projects, since you can prompt the LLM to write tests that will work when run in parallel and it does a pretty good job at this. Later I’d also add a pyright configuration, though initially I left it out because I saw Codex doing some strange things on account of duckdb being untyped, and I didn’t want to debug it at the time (the fix, by the way, is instructing the LLM to define stubs as necessary in this case.)
  2. An AGENTS.md file with some basic instructions to try to get Codex to stop doing things I saw it doing in the initial trajectories that I didn’t want it to do. Nothing fancy: just, if you see Codex do something bad, tell it not to do it in AGENTS.md. A good example of this is the instruction “There are no nested AGENTS.md files, this is the only agents file”: Codex is post-trained to look for nested AGENTS.md files, but you can save a few tool calls if you tell it there aren’t any. (Note: folklore for Claude 3.7 is that instruction following for this sort of rule was not great. Word on the street is that both Codex and Claude 4 are substantially better at this. Extra note: For uv users, another notable instruction in AGENTS.md is how to activate the venv, since at time of writing I couldn’t get Codex to make this happen automatically.)
  3. A setup script for the environment. This took the most debugging, because Codex runs all Internet access through a proxy and sometimes it works imperfectly.

After I got my initial prompt to generate a first draft of the application, I was able to begin vibe coding in earnest.

The Human-Agent loop

The basic vibe coding loop works like this:

  1. Interact with the application and find things that are broken
  2. Prompt the LLM to fix them
  3. Repeat

For example, after the very first PR, some very mild poking around immediately revealed the bugs fixed in #2:

There’s a race condition in the current test logic for matching against table contents in run_query. Specifically, if there were previously valid results in lastResults, and for some reason Dive doesn’t do anything, then we will still see the old results. The testing framework should explicitly clear lastResults before attempting an interaction.

…and #3:

Filter functionality does not work. We will first add a failing test, and then fix it. The failing test should click “Add Filter”, then select “user” as the field, and then add an “alice” chip (by typing alice in the text box and pressing ENTER). Then when we dive, we should see two alice rows. Right now, NO request is issued at all when we click Dive. Diagnose and then fix the problem.

Prompt the agent to write tests. It’s very helpful to prompt the agent to generate tests for whatever bugs it’s fixing. For frontend code, I decided to use Playwright to write these tests. An example in #11:

def test_header_and_tabs(page: Any, server_url: str) -> None:
    page.goto(server_url)
    page.wait_for_selector("#order_by option", state="attached")

    header = page.text_content("#header")
    assert "sample.csv" in header
    assert "events" in header

    assert page.is_visible("#settings")
    assert page.is_hidden("#columns")
    page.click("text=Columns")
    assert page.is_visible("#columns")
    cols = page.locator("#column_list li").all_inner_texts()
    assert "timestamp" in cols
    assert "event" in cols
    page.click("text=View Settings")
    assert page.is_visible("#settings")

    btn_color = page.evaluate("getComputedStyle(document.querySelector('#dive')).backgroundColor")
    assert "rgb(0, 128, 0)" == btn_color

    sidebar_overflow = page.evaluate("getComputedStyle(document.querySelector('#sidebar')).overflowY")
    view_overflow = page.evaluate("getComputedStyle(document.querySelector('#view')).overflowY")
    assert sidebar_overflow == 'auto'
    assert view_overflow == 'auto'

This is kind of a very manual and somewhat brittle test. I doubt I would have had the patience to manually write and debug this. But the LLM will happily do it. There is also a fear that the test might not actually be testing anything for real. But because these tests are all generated from bug reports, you can prompt the model to write the test first, check that it’s failing, and then fix the problem. I didn’t instruct this in the above prompt, but I did instruct it for #155, for example. The trajectory shows that the LLM demonstrates it can repro the bug before fixing the issue:

I’ll add a case-insensitive test for “samples” to check server behavior with lowercase input.

(scubaduck) root@5b69743466e4:/workspace/scubaduck# python - <<'EOF'
> import json
> from scubaduck import server
>
> app = server.app
> client = app.test_client()
> payload = {
>     "table": "events",
>     "start": "2024-01-01 00:00:00",
>     "end": "2024-01-03 00:00:00",
>     "graph_type": "table",
>     "order_by": "samples",  # lowercase
>     "order_dir": "DESC",
>     "limit": 10,
>     "columns": [],
>     "group_by": ["user"],
> }
> rv = client.post("/api/query", data=json.dumps(payload), content_type="application/json")
> print('status', rv.status_code)
> print('data', rv.get_json())
> EOF
status 400
data {'error': 'Unknown column: samples'}
(scubaduck) root@5b69743466e4:/workspace/scubaduck#

I’ve found the bug happens when “Samples” is lowercase. I’ll adjust the query to handle case insensitivity and add a test for it. Time to modify server.py to fix the issue!

Beware BC-breaking changes. The LLM-generated tests worked well when I didn’t need to change behavior in a BC-breaking way. Example prompt:

Change time column to default to none unconditionally. (I want to default it to a time column if we find a good candidate but this threshold needs tuning and I don’t have the tuning right now.)

This hit a Codex timeout:

This attempt to update the default time column led to extensive frontend test failures that require deeper changes. The environment made it difficult to fully diagnose the Playwright test behaviors within the time available, so the work could not be completed.

In standard software engineering practice, when this happens, decouple the BC-compatible and BC-breaking changes!

Make it so that Time Column can be set to (none). When it is set this way, the Start/End fields are hidden and we don’t apply a filter on time range. (#115)

and then later, instead of defaulting the time column to none, I added a heuristic to pick a column that looked like time, which picked the same column that all of the existing tests already expected.

Refactors have to be split up. Codex’s timeout means that you can’t ask it to do too much in one go. Here’s a prompt that timed out:

scubaduck/index.html has gotten a bit long. Let’s split out some of the JS code into dedicated JS files for their functionality. Also setup the necessary Flask scaffolding to serve these JS files. I think splitting out these specific components would be good:

  • Dropdown implementation
  • Sidebar resizing
  • JS controlling the View Settings (e.g., updateDisplayTypeUI, as well as one off interactions on form elements, columns handling, filter handling, the actual Dive implementation (including query updating), reading in defaults from query string)
  • Table rendering (e.g., formatNumber, sorting)
  • Chip input implementation
  • Chart rendering (showTimeSeries)

Make changes to AGENTS.md or README.md describing the structure so you can quickly find where the components you need are

I eventually did manage the refactor by prompting Codex to individually move out the pieces I wanted to extract one-by-one. This is a place where I think Claude Code probably would have performed better.

Parallelizing tasks. As you can see from the lengths of my prompts, it does take a while to write a good prompt; you’re basically writing a bug report with enough detail that the LLM can repro it and then fix it. So sometimes I would be bottlenecked on prompt writing. However, sometimes the prompts were quite short. In those cases, Codex encourages you to submit more tasks that can run in parallel. I found this worked well, and I’d sometimes have as many as five instances going (once again, rate limited by discovering problems, making designs and typing prompts!) One irritation is when the tasks end up conflicting with each other. Sometimes the conflicts are easy to fix, but if it feels nontrivial, it’s often better to just ask Codex to redo one of the PRs on latest main after the other has landed. To avoid merge conflicts, it helps to have only one “main feature” agent going at any time, and then ask the agent to do random bugfixes in parallel with it. Once you have no more tasks to get running, you can go do something else while you wait for the agents to finish (manager schedule!)

Prompting

As a reminder, I’ve posted all of my prompts (including the ones that failed) at scubaduck-prompts, and I think it’s helpful to skim through them to get a flavor of what I was asking the LLM. But to summarize, what did I spend most of my time prompting Codex to do? My general vibe (ahem) is that I spent most of my time doing minor enhancements, where I instructed Codex to make some part of the program work slightly differently, in a way that had been left unspecified by the previous prompt. The metaphor I had in my head while I was working on the project was like that of a sculptor chiseling away marble: in the beginning, anything is possible, but as I kept prompting, I continuously narrowed down the space of possible programs until I had exactly the one I wanted. One big thing I want to note is that Codex rarely needed to make updates to my tests; for the most part, tests that were added never got taken away, because I never “changed my mind”. I suspect that the vibe coding process would have been rockier if I was having to change behavior frequently.

One of the things that surprised me the most about the process was how easy it was to implement a line chart in SVG with Codex. My first prompt resulted in a chart that looked broken on the test data:

We’re going to add a new View type, to go along with Samples and Table: Time Series. Time Series supports all the fields that Table supports, and a few more:

  • X-axis: Main group by dimension, e.g., the x-axis on time series view. This is our custom dropdown selector, but only time columns are populated here. It should prefer a default setting from the following list, most preferred first: “time”, “timestamp”
  • Granularity: Choose the time interval between data points on the chart. For example, a granularity of 1 hour means there will be a data point every 60 minutes that is aggregated with the chosen Aggregate function over the data for the granularity period before point. This is a plain drop down. The valid values are: Auto, Fine, 1 second, 5 seconds, 10 seconds, 30 seconds, 1 minute, 4 minutes, 5 minutes, 10 minutes, 15 minutes, 30 minutes, 1 hour, 3 hours, 6 hours, 1 day, 1 week, 30 days. The semantics of the Auto setting is that it sets the interval to whatever would result in maximum 100 buckets (if there are not enough data points for that many buckets, it just picks the finest time interval that makes sense), and Fine which sets the interval to 500 buckets.
  • Fill Missing Buckets: This is a dropdown. For now, it has the settings “Fill with 0 (Per Series)” (default), “Connect (Per Series)” and “Leave blank”.

Additionally, the default setting of Limit is 7, as it controls how many elements from group by will be plotted (the actual number of lines plotted could be a multiple of this, as we will plot every selected Column).

Unlike Samples and Table, we will instead display a line chart in the right panel. To plot the line chart, we will implement it by hand with JS and SVG, similar to how highcharts implements it. We will not use any third party dependencies. Lines will be plotted as paths, no smoothing, no dots for individual data points. Each series (as generated by group by) should be plotted with a different color, assigned using a best practices color palette for graph design. There should be a rendering of x-axis and y-axis; the x-axis should have slanted labels to aid readability. When we mouse over the chart, a vertical line should snap to the center of the time bucket that we are closest to. We should also display a crosshair on all of the series showing us their values at that data point, and highlight the closest point we are on, and increase the thickness of the series that point is on. To the left of the graph (still in the right panel), there should be a legend. The legend looks like this:

[GROUP BY VALUE] [AGGREGATE]
[First Column name, with series color]
[Number of samples for the first column]
[Second Column name, with series color]
[Number of samples for the second column]
... for all columns
----
... for all group by values (up to the limit)

So for example, if I group by user, I might see:

Alice AVG
value
4 (samples)

The highlighted series (which has a thicker line) should also be highlighted in the legend.

This was kind of terrifying, because I initially thought I didn’t have a good way to test the SVG outputs. But after doing some regular old-fashioned debugging and reading the code (yes, this part not vibe coded), I figured out the problem, and also realized that Playwright can test that an SVG path is not just entirely straight. After the initial bugs were fixed, I mostly had to add missing features like x-axis/y-axis and interactivity features (amusingly, Codex ignored most of the instructions in the latter half of the prompt, giving only the barest bones legend. I suspect this was because I had some files which were too long). My general take after this was that JS chart libraries are going to become obsolete: it’s much easier to vibe code a bespoke implementation and then customize the heck out of it.

Conclusion

ScubaDuck was implemented in about 150 Codex prompts. As you can see from the sample prompts above, the prompts are recognizably programming, they just happen to be in plain English. This is a big help, because I never had to keep track of the nest of callbacks and state machines for implementing complex UI elements in JavaScript. I had to be fluent in what I wanted my program to do, and be a good QA tester for the application, discovering new problems that needed to be fixed, but I did not have to worry at all about the vagaries of SVG DOM elements or pixel position computation minutiae. It’s hard to say how long it would have taken to code this by hand, but I think reproducing a UI that’s been in production for years at Meta in three (part-time) days is pretty good!

Despite having done a bit of AI coding before, I also learned a bit from working with Codex. Codex made it blindingly clear that the parallel modality (and subsequent conflict resolution) is important. It made me adjust up my estimation of the capability of LLMs to write raw HTML/JS and evoked a future where people vibe code components in place of taking on a third-party dependency. I was very appreciative of no-rate-limit Codex (though I doubt it’s going to last). It also reminded me how difficult it will be to set up agent environments for “real” projects (like PyTorch).

Hopefully, this case study has given you some ideas for things to try. Go forth and vibe code, responsibly!