February 3, 2026

I recently received the following question about vibe-coding for tlparse, a structured log parser for torch.compile (slightly edited for ease of reading):
Hi Ed, I have some thoughts on vibe-coding tlparse and I’m curious what your opinions are. I think it’s fairly easy to vibe code and improve tlparse, but it’s hard to find people who know Rust or HTML or JavaScript well to properly review the vibe-coded stuff. The Rust PRs are not impossible to review, but the JavaScript ones are certainly hard… YOLO-landing PRs that we can’t review is certainly bad, I guess one option is just not landing any changes, and tell people to vibe-code themselves…?
I wonder if you have any opinion on this? I saw one of your BE week proposals is for a more vibe-coding friendly tlparse. Do you think we should just not attempt to review or land any front-end features (which most likely we cannot review), and all-in on the “custom vibe-coded” frontend route?
Oh boy, do I have opinions!
When is it acceptable not to review code? I find this is the wrong question to ask, because code review is only a proxy for the underlying measure we actually care about: does it matter if this code is good? And to answer this question, we have to understand whether we are in a high stakes or low stakes situation.
Here are some things that suggest high stakes:
- It does destructive/irreversible actions
- It is used by many users (human or computer)
- It covers a large surface area
- It constitutes BC surface
- It handles money, private data, secrets, lives
- It runs automatically
- There are no automated tests / it is tested in prod
Here are some things that suggest low stakes:
- It is personal use only
- It is disposable or exploratory
- The output is easy to verify
- It is trivial to rollback to old versions (or better yet, you can run as many versions as you want at the same time)
- It comes with no warranty
- It is a starting point, rather than a complete product
- No side effects, no persistent data
Many problems won’t neatly fall into one bucket or another, but it’s still helpful to know what the high stakes aspects are, because you can spend more effort (e.g., doing code review) on those aspects and less on the low stakes things. Also, being aware of why something is high stakes can push you towards restructuring things so that a problem becomes lower stakes.
Let’s run this exercise for tlparse. tlparse is generally a low stakes project: it takes structured logs and generates HTML files for viewing them. You can (in principle) run newer or older versions of it, and as a program with no side effects it is basically as simple as these things get.
However, there are some things that are high stakes about it. It is used extremely widely internally at Meta; if it were broken, we would likely get a message within a day from someone who was relying on it to diagnose problems in training jobs (including production SEVs.) The way it is deployed internally is as a classic executable which is automatically rolled out as new versions are published; most users don’t know how to run an old version of it–so while in principle it can be rolled back trivially, in practice you would need to instruct users on how to do so. Finally, although it doesn’t do that much (just generate HTML from a structured log), there is a large variety of actual logs in production, which makes it difficult to comprehensively test. A memorable example from the past is someone adding syntax highlighting to the tool. I reverted this because it caused extremely long files to take a long time to parse, and it also made it more difficult to grep for specific lines in the generated folder.
How can you lower the stakes? In the case of tlparse, here are some ideas:
- Have a separation between prod (the stable stuff that is shown by default) and experimental (the wacky untested stuff that needs some dogfooding to see if it works)
- Make it easy to have multiple versions of tlparse; update broke something, just go back to your favorite version
- Don’t deploy tlparse via a single rollout mechanism; have it as a local app that people can run / vibe code on without a deployment step
- Improve testing to ensure features don’t break
- Keep it simple: don’t add lots of features, so the surface area of things to test stays small
Another trap is to think of code review as an all or nothing option: “if I haven’t done a careful line-by-line review of this code, I might as well just not review it at all.” There are lots of ways to evaluate LLM generated code, and some of these don’t involve reviewing the code at all.
First, let’s talk about evaluating LLM generated code in high stakes situations. The bar here is that your LLM generated code should be indistinguishable from the code you would have written yourself: the LLM just helped you type faster. This is a very high bar: imagine a mathematical proof written by someone other than you. Can you just read the proof and understand it? Typically not: the only way to a real, durable understanding is to work through the steps of the proof yourself. It is difficult, intellectually taxing work–and you cannot offload this to the LLM, because the LLM isn’t the owner of the code, you are the owner of the code! I find in a situation like this it is best to steer the LLM very closely during authoring: I should have a clear idea of what I want written, the LLM is just typing, and I am pausing and correcting it when it does things I don’t want. If you have the LLM run by itself for an hour, you’ll end up with code that is not yours, and you will have to spend the effort to make it yours. It is much easier, with both LLMs and regular colleagues, to own code if you’re involved at every step of its conception.
A lower stakes situation is when the exact details of the code don’t matter, but you’re still thinking architecturally about the problem. I am often in this mode when I’m doing exploration: I don’t care about the exact things the LLM is typing, but I do care that it roughly has the right shape. After you’ve used LLMs a bunch, you get a sense for what kinds of mistakes they do and don’t make. You can just skip reviewing all the things that you know LLMs generally won’t mess up, and just look for the big strokes. To do this, however, you need to have a good, high level understanding of how things should work. If there’s a big pile of JavaScript code and you know absolutely nothing about how the DOM works, this isn’t going to work! And if there is a way to do something without reams of JavaScript, you should absolutely steer the LLM toward that approach.
In a situation where you are vibe coding and not looking at the generated code at all, you still have a very important job. You need to actually QA the feature and see if it actually works. I find for my pure vibe coding projects, this is the most time consuming part: I can ask the LLM to do anything, but then actually checking if it works (and transmitting feedback to the LLM) takes up most of my time. You can ask the LLM to write tests, but then you have to check if the tests are actually testing the important thing (to be clear, reviewing LLM generated tests is very high leverage.) The golden city in the distance is the LLM being able to QA the tool for you, but for something like tlparse HTML, we are still in the very early days of this kind of sophisticated browser use. (Claude has come a long way here, though, and I expect it to get better.)
Let’s look at some concrete PRs:
- Add create symbols logs to compilation metrics artifact - I agree with the outcome (an approval) here. It is not super intrusive; the new indexes are a little messy, but this is unlikely to cause problems with other parts of the system.
- Add style to provenance tracking - Big JavaScript changes, likely all vibe coded. To get out of draft, there needs to be evidence of dogfooding to show it actually works. There will likely be bugs. It might be worth getting this into a lower stakes setting to iterate more quickly on real data.
It is perhaps unsatisfying that you have to evaluate things on a case-by-case basis. But hopefully this post gives you some framework to think about it.
February 3, 2026

A central thesis of sharding in types is that the backward sharding can be directly computed from the forward sharding. This is not true for DTensor today, e.g., as seen in sum to Partial, and it leads to confusion where users cannot easily predict what the shardings of the tensors in their program are. The question now arises: given a forward sharding, what should its backward sharding be? There are some easy cases to fill in:
- If forwards is sharded, backwards is sharded.
- If forwards is partial, backwards should be replicate (the easiest example to see this is local loss backwards: your local loss is partial over DP, but when you generate a ones tensor as grad_loss, this should clearly be a replicated tensor).
But if your forwards is replicate, what should your backwards be? One answer to this question, adopted by JAX, is that the backwards should also be replicate. From a theoretical standpoint, this semantics is easily motivated by the following degenerate example: if my forwards computation is replicated across all my nodes (e.g., everyone is performing the same compute on exactly the same tensor), then the gradient should clearly be replicated across all nodes (and this is the only choice that allows us to avoid communications entirely). However, this does lead to an irritating problem where if you have a forwards that is replicated, and want your backwards to be partial, you need to introduce an entirely new forwards sharding (JAX calls it “reduced”) to indicate this. JAX chose to preserve replicate-to-replicate for backwards compatibility reasons.
The purpose of this post is to argue that Replicate Forwards, Partial Backwards is a better default (concretely, in JAX, swap the default so that all axes are “reduced” by default, and you have to explicitly make some “not” reduced–not to be confused with unreduced!) To see this, let’s look carefully at the sharding of a gated MLP with DP and TP (without the w1-w3 concatenation) in JAX. Here is the end-to-end example, done with einsums for ease of differentiation, written as explicitly as possible (the reshards can be elided in JAX, but they play an important role for DTensor erasure):
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update('jax_num_cpu_devices', 4)
jax.set_mesh(jax.make_mesh((2, 2), ('dp', 'tp')))

def mlp(x, w1, w3, w2):
    print(f"{jax.typeof(x)=}, {jax.typeof(w1)=}, {jax.typeof(w3)=}, {jax.typeof(w2)=}")
    # !!! ATTENTION !!!
    rx = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
    rw1 = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'}))
    rw3 = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'}))
    rw2 = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))
    print(f"{jax.typeof(rx)=}, {jax.typeof(rw1)=}, {jax.typeof(rw3)=}, {jax.typeof(rw2)=}")
    h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
    print(f"{jax.typeof(h1)=}")
    h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
    print(f"{jax.typeof(h3)=}")
    h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
    print(f"{jax.typeof(h)=}")
    out = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
    print(f"{jax.typeof(out)=}")
    return out

seq = 4
batch = 8
hidden = 16
intermediate = 32

x = jax.device_put(
    jnp.ones((seq, batch, hidden), dtype=jnp.float32),
    jax.P(None, 'dp', None)
)
w1 = jax.device_put(
    jnp.ones((hidden, intermediate), dtype=jnp.float32),
    jax.P(None, 'tp')
)
w3 = jax.device_put(
    jnp.ones((hidden, intermediate), dtype=jnp.float32),
    jax.P(None, 'tp')
)
w2 = jax.device_put(
    jnp.ones((intermediate, hidden), dtype=jnp.float32),
    jax.P('tp', None)
)

mlp(x, w1, w3, w2)
If you take the prints and annotate them inline in the program, it looks like this:
x: f32[seq, batch@dp, hidden]
w1: f32[hidden, intermediate@tp]
w3: f32[hidden, intermediate@tp]
w2: f32[intermediate@tp, hidden]
rx: f32[seq, batch@dp, hidden]{R:tp} = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
rw1: f32[hidden, intermediate@tp]{R:dp} = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'}))
rw3: f32[hidden, intermediate@tp]{R:dp} = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'}))
rw2: f32[intermediate@tp, hidden]{R:dp} = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))
h1: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbh,hi->sbi", x, rw1)
h3: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbh,hi->sbi", x, rw3)
h: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
out: f32[seq, batch@dp, hidden]{U:tp} = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
Each reshard corresponds to a situation where we have a no-op in forwards, but an all-reduce in backwards. Why does JAX’s primal-cotangent rule for sharding in types imply this? I have three arguments.
The intuitive argument. In a DP+TP gated MLP, you expect to need to do a TP all-reduce on grad_input (because as you leave the sharded TP region you need to aggregate gradients from the TP shards), and you need to do the traditional DP all-reduces on all the parameters. In PyTorch, the DP all-reduces are typically handled by the DDP/FSDP wrapper outside of this code, but when we accept JAX semantics, grad_w1: f32[hidden, intermediate@tp] (it’s replicated over dp!), so we are obligated to ensure the all-reduce occurs before we exit this region.
The peephole argument. Let’s just look at one specific backwards and work it out by hand.
# Recall:
rw1: f32[hidden, intermediate@tp]{R:dp}
h1: f32[seq, batch@dp, intermediate@tp]
# Therefore: (reduced->unreduced)
grad_rw1: f32[hidden, intermediate@tp]{U:dp}
grad_h1: f32[seq, batch@dp, intermediate@tp]
# Recall:
h1: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbh,hi->sbi", x, rw1)
# Einsum backwards says:
grad_rw1: f32[hidden, intermediate@tp]{U:dp} = jnp.einsum("sbh,sbi->hi", x, grad_h1)
# Contraction is on replicated 's' and sharded 'b', so the result is unreduced on dp axis
The conversion to ‘reduced’ in forwards turns into an all-reduce to compute grad_w1: f32[hidden, intermediate@tp]. If we want to be extremely explicit about our code, we are obligated to convert w1 to “reduced”, so that its backward is “unreduced” as is implied by einsum backwards. By the way, inside of a shard_map region, a very similar thing occurs; as w1 is invariant in DP, but x is varying in DP, we must pcast w1 from invariant to varying to get correct gradients.
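In shard_map terms, that conversion is just a pcast of the weight. Here is a minimal sketch using the pcast spelling that appears later on this page; the names w1_local, x_local, and h1_local are hypothetical per-device shards, not from the original example:

# Inside a shard_map over ('dp', 'tp'): w1_local is invariant over dp, while x_local
# varies over dp, so we cast w1_local to varying before the contraction. This is a
# no-op in forwards, and it is exactly what forces the dp all-reduce on grad_w1
# in backwards.
rw1_local = jax.lax.pcast(w1_local, 'dp', to='varying')
h1_local = jnp.einsum("sbh,hi->sbi", x_local, rw1_local)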
The exhaustive argument. We can write out the full backwards (for brevity, I use g_ instead of grad_ for the gradients):
g_out: f32[seq, batch@dp, hidden]{R:tp}
# out: f32[seq, batch@dp, hidden]{U:tp} = jnp.einsum("sbi,ih->sbh", h, rw2, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
g_h: f32[seq, batch@dp, intermediate@tp] = einsum("sbh,ih->sbi", g_out, rw2)
g_rw2: f32[intermediate@tp, hidden]{U:dp} = einsum("sbi,sbh->ih", h, g_out, out_sharding=jax.P('tp', None, unreduced={'dp'}))
# h: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
g_silu_h1: f32[seq, batch@dp, intermediate@tp] = einsum("sbi,sbi->sbi", g_h, h3)
g_h3: f32[seq, batch@dp, intermediate@tp] = einsum("sbi,sbi->sbi", g_h, silu(h1))
g_h1: f32[seq, batch@dp, intermediate@tp] = silu_backward(g_silu_h1, h1)
# h3: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbh,hi->sbi", x, rw3)
g_rx_from_h3: f32[seq, batch@dp, hidden]{U:tp} = einsum("sbi,hi->sbh", g_h3, rw3, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
g_rw3: f32[hidden, intermediate@tp]{U:dp} = einsum("sbh,sbi->hi", rx, g_h3, out_sharding=jax.P(None, 'tp', unreduced={'dp'}))
# h1: f32[seq, batch@dp, intermediate@tp] = jnp.einsum("sbh,hi->sbi", x, rw1)
g_rx_from_h1: f32[seq, batch@dp, hidden]{U:tp} = einsum("sbi,hi->sbh", g_h1, rw1, out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))
g_rw1: f32[hidden, intermediate@tp]{U:dp} = einsum("sbh,sbi->hi", rx, g_h1, out_sharding=jax.P(None, 'tp', unreduced={'dp'}))
g_rx: f32[seq, batch@dp, hidden]{U:tp} = g_rx_from_h1 + g_rx_from_h3
g_w2: f32[intermediate@tp, hidden] = reshard(g_rw2, P('tp', None))
g_w1: f32[hidden, intermediate@tp] = reshard(g_rw1, P(None, 'tp'))
g_w3: f32[hidden, intermediate@tp] = reshard(g_rw3, P(None, 'tp'))
g_x: f32[seq, batch@dp, hidden] = reshard(g_rx, P(None, 'dp', None))
You can individually verify that each unreduced gradient is implied by the einsum in question.
The upshot. The real point of this example is to see that the “intuitive” sharding type for the arguments of mlp actually forces a lot of communications in backwards, because we must make the gradients replicated, and that implies all-reduces. This can actually result in a suboptimal communication pattern: when both TP and SP are being used, the unreduced grad_input can be delayed all the way to the forwards all-gather between the SP and TP region. In backwards, we can directly do a reduce-scatter rather than doing an all-reduce and then later throwing out most of the result in the all-gather backwards. (Arguably, this isn’t a huge deal if you have a compiler like XLA, since you would expect it to know how to optimize the comms here, but the whole point of sharding-in-types is to give more control over when comms occur.)
A better type signature for this function is mlp(rx, rw1, rw3, rw2), where all of these arguments are reduced (rx on tp, and rw{1,3,2} on dp). Now the reshards can be controlled by the user; you can do exactly the same communication pattern as our original implementation, or you can delay them until later. And the best way to encourage people to write their code this way is to have replicate forwards imply partial backwards. (P.S. It is still useful to have another variant of replicate which really does have a replicate backwards. I don’t have a good name for it, but it could occasionally be used to do an all-reduce early before fan-out would imply you have to do multiple all-reduces.)
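To make the suggestion concrete, here is a sketch of what that signature looks like, reusing the definitions from the example above: the reshards move out of mlp and into the caller, who can keep this exact communication pattern or delay the conversions to a better spot.

def mlp_reduced(rx, rw1, rw3, rw2):
    # All arguments are already reduced (rx over tp, rw{1,3,2} over dp);
    # the body is unchanged from the original mlp.
    h1 = jnp.einsum("sbh,hi->sbi", rx, rw1)
    h3 = jnp.einsum("sbh,hi->sbi", rx, rw3)
    h = jnp.einsum("sbi,sbi->sbi", jax.nn.silu(h1), h3)
    return jnp.einsum("sbi,ih->sbh", h, rw2,
                      out_sharding=jax.P(None, 'dp', None, unreduced={'tp'}))

# The caller now controls where (and whether) the reduced conversions happen.
rx = jax.reshard(x, jax.P(None, 'dp', None, reduced={'tp'}))
rw1 = jax.reshard(w1, jax.P(None, 'tp', reduced={'dp'}))
rw3 = jax.reshard(w3, jax.P(None, 'tp', reduced={'dp'}))
rw2 = jax.reshard(w2, jax.P('tp', None, reduced={'dp'}))
mlp_reduced(rx, rw1, rw3, rw2)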
Thanks Natalia Gimelshein, Tianyu Lu, and Ailing Zhang for detailed discussions that helped me reach this position. Thanks Twitter for giving this position a sanity check. Any mistakes are my own.
Patrick Toulme requested the HLO and Shardy MLIR for the JAX program. Here they are, generated from this script. The raw output is here. Below, I have posted annotated versions, courtesy of Claude.
Full pre-partition HLO. The most interesting thing is you can see the sharding constraints added which trigger reductions:
%16 = sdy.sharding_constraint %15 <@mesh, [{"tp"}, {}], unreduced={"dp"}> : tensor<32x16xf32>
%27 = sdy.sharding_constraint %26 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32>
%29 = sdy.sharding_constraint %28 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32>
%33 = sdy.sharding_constraint %32 <@mesh, [{}, {"tp"}], unreduced={"dp"}> : tensor<16x32xf32>
%35 = sdy.sharding_constraint %34 <@mesh, [{}, {"dp"}, {}], unreduced={"tp"}> : tensor<4x8x16xf32>
%36 = stablehlo.add %29, %35 : tensor<4x8x16xf32>
...
# ========== REDUCE WEIGHT GRADIENTS (dp reduction) ==========
# Reduce g_w2 across dp dimension
%37 = sdy.sharding_constraint %16 <@mesh, [{"tp"}, {}]> : tensor<32x16xf32>
# Reduce g_w3 across dp dimension
%38 = sdy.sharding_constraint %27 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# Reduce g_w1 across dp dimension
%39 = sdy.sharding_constraint %33 <@mesh, [{}, {"tp"}]> : tensor<16x32xf32>
# ========== REDUCE g_x (tp reduction) ==========
# Reduce g_x across tp dimension
%40 = sdy.sharding_constraint %36 <@mesh, [{}, {"dp"}, {}]> : tensor<4x8x16xf32>
# ========== RETURN ==========
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
return %12, %40, %39, %38, %37 : tensor<4x8x16xf32>, tensor<4x8x16xf32>, tensor<16x32xf32>, tensor<16x32xf32>, tensor<32x16xf32>
Full post-partition HLO. The most notable thing to observe here is that the DP all-reduces have been bucketed together into a single all-reduce, which is what you’d expect any self-respecting DP implementation to do. Also interestingly, the TP reduction is actually done first before summing the gradients together on both paths; you could probably also sum the gradients together first before all-reducing.
# ============================================================================
# MAIN ENTRY POINT
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
# ============================================================================
ENTRY %main.11_spmd (param.5: f32[4,4,16], param.6: f32[16,16], param.7: f32[16,16], param.8: f32[16,16], param.9: f32[4,4,16]) -> (f32[4,4,16], f32[4,4,16], f32[16,16], f32[16,16], f32[16,16]) {
# ========== BACKWARD: All-reduce weight gradients across TP ==========
# Sum weight gradients across tensor parallel devices (replica_groups={{0,2},{1,3}})
# Returns: (g_w1, g_w3, g_w2)
%all-reduce.1 = (f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) all-reduce(%dot.27, %dot.28, %dot.29), channel_id=3, replica_groups={{0,2},{1,3}}, use_global_device_ids=true, to_apply=%region_2.5, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
# line 16 backward: reshape g_x_h3
%bitcast.8 = f32[4,4,16]{2,1,0} bitcast(%dot.18), metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# Extract weight gradients from all-reduce tuple
# line 15 backward: g_w1 (reduced)
%get-tuple-element.4 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=5}
# line 16 backward: g_w3 (reduced)
%get-tuple-element.6 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/transpose" stack_frame_id=8}
# line 18 backward: g_w2 (reduced)
%get-tuple-element.8 = f32[16,16]{1,0} get-tuple-element(%all-reduce.1), index=2, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose/jit(forward_and_backward)/transpose(jvp(sbi,ih->sbh))/transpose" stack_frame_id=11}
# ========== BACKWARD: All-reduce input gradients across TP ==========
# Sum input gradients across tensor parallel devices (replica_groups={{0,1},{2,3}})
# Returns: (g_x_h3, g_x_h1)
%all-reduce = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}) all-reduce(%bitcast.8, %bitcast.11), channel_id=1, replica_groups={{0,1},{2,3}}, use_global_device_ids=true, to_apply=%region_0.1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# Extract input gradients from all-reduce tuple
# line 16 backward: g_x_h3 (reduced across TP)
%get-tuple-element = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=0, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=8}
# line 15 backward: g_x_h1 (reduced across TP)
%get-tuple-element.2 = f32[4,4,16]{2,1,0} get-tuple-element(%all-reduce), index=1, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general/jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/dot_general" stack_frame_id=5}
# ========== BACKWARD: Sum input gradients from both paths ==========
# line 15+16 backward: g_x = g_x_h1 + g_x_h3
%wrapped_add = f32[4,4,16]{2,1,0} fusion(%get-tuple-element, %get-tuple-element.2), kind=kLoop, calls=%wrapped_add_computation, metadata={op_name="jit(forward_and_backward)/transpose(jvp(sbh,hi->sbi))/add_any" stack_frame_id=5}
# ========== FINAL OUTPUT ==========
# Returns: (out, grad_x, grad_w1, grad_w3, grad_w2)
ROOT %tuple.6 = (f32[4,4,16]{2,1,0}, f32[4,4,16]{2,1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}, f32[16,16]{1,0}) tuple(%bitcast.4, %wrapped_add, %get-tuple-element.4, %get-tuple-element.6, %get-tuple-element.8)
}
February 1, 2026

DTensor has famously terrible eager mode performance; for example, this paper measured a 35-60% slowdown in end-to-end training performance when running with DTensor versus without it (with DTensor operations taking at least 7x longer than actually running the computation for real). While it is possible to alleviate some of this slowdown via optimizations (in the paper, veScale shows that fast bypass of sharding propagation, improved cache lookups and C++ code can take dispatch overhead down to 30us), this is still too high for some settings.
- You could eliminate this overhead by rewriting your code without DTensors at all, but this gives up the benefits of expressing your code in global SPMD form.
- You could eliminate the overhead by using a compiler or CUDA graphs, but this requires your code to be traceable.
Is there a way to have our cake (global SPMD) while eating it too (eager mode with no DTensor)? veScale proposed Static Eager mode as a way of eliminating DTensor at runtime, by observing that DTensor placements are largely static in a program, which means that you can just drop DTensor at runtime as long as you manually insert any communication that would have occurred if you had run with DTensor (veScale does extra work under the hood to add hooks for these communications). However, it is quite difficult for researchers to understand how to insert redistributes / gradient redistributes. In this blog post, I want to reimagine their system under the following constraint: what if you could just erase the DTensors, without having to add any hooks at all? Spoiler alert: JAX-style sharding in types without implicit conversions is enough to get the job done.
First, let’s describe the problem more carefully. Typically, a desirable property for type systems is that the types can be erased before execution without changing the runtime behavior of a program. In the case of DTensor, we want to erase all of the placements in a program, with the hope that we can still run everything we need to without them. Actually, most of the time in DTensor this will work, because the ideal situation for DTensor is that you just run the operation as-is on the local tensors. The problem is, of course, redistribute, which needs to know what the original placement of a DTensor is to issue the collectives to get it into the new desired placement. Even worse, to detect if an implicit redistribution would occur, you need to determine that the input placements are illegal and work out how to insert redistributes to make them legal.
The explicit redistribute problem is easy enough to fix: a user could specify a specific collective (which, with DTensors, would be checked for consistency against the actual input/output placements), or we could ask the user to specify the input placements so that we can still compute the sequence of collectives to get to the output placement. To avoid implicit redistributions in forwards, one can simply ensure you insert explicit redistributes anywhere they are needed; to avoid implicit redistributions in backwards, you need a type system that guarantees that the backwards collectives correspond precisely to forward collectives. You need two ideas from the JAX sharding type system: first, the backward gradient placement should always be computable from the forward primal placement, so you can always reason locally about whether comms are needed (as you always know the placement of grad_output.) Second, you need enough vocabulary (e.g., reduced/unreduced) to ensure that these forced backward gradient placements don’t lead to communication inefficiencies.
Once your user program has a DTensor erasure property, you now have code that can be run with either plain Tensor or DTensor. Running with DTensor is akin to running a (dynamic) type checker on the program: in explicit mode it will error if implicit redistributes occur, and it will also error if you messed up and claimed that something was in some placement when it was not. Running with Tensor, you just elide all of the shard checking and trust that the user program was written correctly. This works if the user program doesn’t branch; if you are worried about inexhaustive testing, you could run under both DTensor and torch.compile, where guards and Dynamo can help you identify whether you have actually exercised all potential inputs.
The resulting code you write is very similar to a Megatron-style training framework, but with the twist that you can check that you inserted all of your collectives by running with DTensors rather than Tensors. More generally, this is an interesting pattern for gradual compilation; your program can be entirely run in eager mode, and the compiler is relegated to a sideband static analysis tool (akin to an optional type checker), which still can be an essential part of the workflow for catching and debugging problems.
Sum should not cast to Partial. In classic PyTorch DTensor, a sum on a sharded dimension performs no communications and produces a Partial placement. This behavior cannot be supported under DTensor erasure. Let us reason through it:
input: Shard(0)
sum = input.sum(0)
sum: Partial()
grad_sum = torch.ones_like(sum)
grad_sum: Replicate() # if we suppose Partial() <-> Replicate()
grad_input = grad_sum.expand(input.shape)
grad_input: Replicate()
We see that the backwards formula for sum is a plain eager operation that expands grad_sum. This operation will produce a Replicate tensor. But the primal-cotangent sharding mapping specifies that grad_input should be a Shard(0) tensor; a (free) redistribute from Replicate to Shard(0) is required. In ordinary DTensor, we discover this redistribution happens later when there is a shard-replicate interaction; however, with DTensor erasure, there is no way to discover this, and we must ban this sum.
To resolve the problem, we simply need to distinguish between two distinct APIs for sum. Standard torch.sum should only ever do a local summation and error out if the reduction is across a sharded dimension. Cast to partial (aka lax.pcast(to='unreduced') in JAX) is a separate function that only works on sharded tensors. The distinct function can now be associated with a custom autograd function that triggers the re-sharding in backwards.
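As a rough illustration of the second API, here is a minimal local-SPMD sketch of such a custom autograd function. This is a hypothetical helper, not the actual DTensor implementation: the forward does only the local summation (leaving the result "partial", i.e., a pending reduction the caller must perform), and the backward produces the sharded gradient directly, so no implicit redistribute is ever needed.

import torch

class SumToPartial(torch.autograd.Function):
    # Hypothetical sketch: sum over a dimension that is sharded across ranks,
    # leaving the result partial/unreduced. No communication is issued here.
    @staticmethod
    def forward(ctx, local_x, dim):
        ctx.dim = dim
        ctx.local_len = local_x.shape[dim]
        return local_x.sum(dim)  # local contribution only

    @staticmethod
    def backward(ctx, grad_out):
        # The gradient of the global sum is a broadcast of grad_out to the global
        # shape; under sharding-in-types grad_input must be Shard(dim), so we only
        # materialize the local slice (the "free" Replicate -> Shard(dim) redistribute).
        shape = list(grad_out.shape)
        shape.insert(ctx.dim, ctx.local_len)
        return grad_out.unsqueeze(ctx.dim).expand(*shape), None

On each rank, out_local = SumToPartial.apply(x_local, 0) gives the local contribution; an explicit all-reduce (or the cast-to-partial bookkeeping in DTensor) then decides when the pending reduction actually happens.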
Broadcasts over shards require conversion to ‘reduced’. Above, I claimed that JAX’s explicit mode for global SPMD only issues collectives when explicitly asked to do so. But this isn’t completely true. When I do an operation between a sharded and replicated tensor (under JAX semantics), this can result in an implicit all-reduce in backwards:
import jax
import jax.numpy as jnp

jax.config.update('jax_num_cpu_devices', 4)
jax.set_mesh(jax.make_mesh((4,), ('dp',)))

batch = 8
hidden = 16
out_dim = 4

x = jax.device_put(
    jnp.ones((batch, hidden), dtype=jnp.float32),
    jax.P('dp', None)
)
w = jax.device_put(
    jnp.ones((hidden, out_dim), dtype=jnp.float32),
    jax.P(None, None)
)
grad_out = jax.device_put(
    jnp.ones((batch, out_dim), dtype=jnp.float32),
    jax.P('dp', None)
)

def forward(x, w):
    return jnp.einsum("bh,ho->bo", x, w)

def backward(x, w, grad_out):
    _, vjp_fn = jax.vjp(forward, x, w)
    return vjp_fn(grad_out)[1]  # gradient w.r.t. w

compiled = jax.jit(backward).lower(x, w, grad_out).compile()
print(compiled.as_text())
This prints:
%add.clone (x.3: f32[], y.1: f32[]) -> f32[] {
%x.3 = f32[] parameter(0)
%y.1 = f32[] parameter(1)
ROOT %add.1 = f32[] add(%x.3, %y.1)
}
%fused_computation (param_0.1: f32[4,16]) -> f32[16,4] {
%param_0.1 = f32[4,16]{1,0} parameter(0)
%transpose.7 = f32[16,4]{0,1} transpose(%param_0.1), dimensions={1,0}, metadata={op_name="jit(backward)/transpose(jvp(bh,ho->bo))/transpose" stack_frame_id=3}
ROOT %copy.1 = f32[16,4]{1,0} copy(%transpose.7), metadata={op_name="jit(backward)/transpose(jvp(bh,ho->bo))/transpose" stack_frame_id=3}
}
ENTRY %main.0_spmd (param.1: f32[2,16], param: f32[2,4]) -> f32[16,4] {
%param = f32[2,4]{1,0} parameter(1), sharding={devices=[4,1]<=[4]}, metadata={op_name="grad_out"}
%param.1 = f32[2,16]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="x"}
%dot = f32[4,16]{1,0} dot(%param, %param.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_name="jit(backward)/transpose(jvp(bh,ho->bo))/dot_general" stack_frame_id=3}
%all-reduce = f32[4,16]{1,0} all-reduce(%dot), channel_id=1, replica_groups=[1,4]<=[4], use_global_device_ids=true, to_apply=%add.clone, metadata={op_name="jit(backward)/transpose(jvp(bh,ho->bo))/dot_general" stack_frame_id=3}
ROOT %transpose_copy_fusion = f32[16,4]{1,0} fusion(%all-reduce), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(backward)/transpose(jvp(bh,ho->bo))/transpose" stack_frame_id=3}
}
We can see the original program has no collectives, but the HLO has an all-reduce. Actually, this is just the classic all-reduce you have to do to the gradients of all parameters when doing DP, so this isn’t exactly surprising. This is bad for DTensor erasure, though, because the way JAX knows to insert the collective here is by noticing that the backwards of the linear produces an unreduced result, but sharding-in-types demands that the gradient be replicated.
Now, we could fix the particular case of DP by arguing that DP sharding is special and it’s the responsibility of the DP framework to know that a reduction is necessary (this is how torchtitan on DTensor classically operates: we don’t represent DP directly in the DTensor, and FSDP is responsible for actually doing the all-reduces and grad scaling.) A more theoretically sound solution, however, is to simply say that einsum should forbid broadcasting a replicated tensor with a sharded tensor (even though in forwards this can be done without any communication); instead, in this case, you must have a reduced tensor on the sharded mesh axis, so that the gradient is an unreduced tensor on that mesh axis. A conversion from replicate to reduced will trigger the all-reduce in backwards.
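Here is a minimal sketch of what that looks like for the DP example above (same mesh and shapes; the reduced= keyword is the one used throughout this page): reshard w before the contraction, which is a no-op in forwards but makes the backward all-reduce correspond to an explicit conversion.

def forward_explicit(x, w):
    # Convert the replicated weight to 'reduced' over dp: a no-op in forwards, but its
    # cotangent is now unreduced over dp, so the all-reduce in backwards is implied by
    # this explicit conversion rather than inserted silently.
    rw = jax.reshard(w, jax.P(None, None, reduced={'dp'}))
    return jnp.einsum("bh,ho->bo", x, rw)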
We can do an analysis on einsum to see that the broadcast situation is the only situation where these extra reductions can occur. Recall from Computing sharding with einsum that when we compute the gradient for an input, we interchange the indices for that input with the indices for the output in the einsum formula. We can do a case-by-case analysis for every valid input sharding to see what communication happens in the backwards:
- Shard("batch"), Shard("batch") -> Shard("batch"): interchanging an input with the output still results in a batch pass-through, no comms.
- Shard("contract"), Shard("contract") -> Partial(): interchanging an input with the output results in Shard("contract"), Replicate() -> Shard("contract") (recall that the cotangent type for unreduced is reduced, which is replicate-like), no comms.
- Shard("broadcast"), Replicate() -> Shard("broadcast"): there are two cases here:
  - Shard("broadcast"), Replicate() -> Shard("broadcast") (grad of the first input): this is the same as forwards, no comms.
  - Shard("broadcast"), Shard("broadcast") -> Partial() (grad of the second input): contraction over a sharded dimension produces partial, yes comms.
So this really is the only situation in einsum that has this problem. (Pointwise ops that broadcast would also have this problem.)
January 28, 2026

Conventionally, a type system is something that classifies values into data types like float32 or int64. However, fancy type systems go beyond data types, allowing us to talk about potentially arbitrary invariants on data; for example, if we were to talk about the “type” of an array, it would cover not only its data type, but also its shape, e.g., f32[40, 20]. JAX’s type system of abstract values (avals) goes further than just data types and shapes and is equipped to reason about sharding related invariants. However, this type system is poorly documented, especially recent additions like reduced/unreduced axes (circa June 2025). In this blog post, I want to give a consolidated description of the sharding related aspects of JAX’s typing in explicit sharding mode, as of 2026. Disclaimer: I am not a JAX developer, and there may potentially be mistakes in this presentation; please let me know about errors on Twitter. I will assume that you have some knowledge about how to work with JAX sharding in the frontend; please refer to Distributed arrays and automatic parallelization, Explicit sharding and Manual parallelism with shard_map for a refresher on these topics.
Warning: All of this discussion is about Explicit sharding. The avals do not store accurate sharding information in the (default) auto mode. For example:
import jax
import jax.numpy as jnp
from jax.sharding import AxisType
jax.config.update('jax_num_cpu_devices', 2)
mesh1 = jax.make_mesh((2,), ('i',), axis_types=(AxisType.Auto,))
jax.set_mesh(mesh1)
array1 = jax.device_put(jnp.ones((8,)), jax.P("i"))
print(array1.aval.sharding)
mesh2 = jax.make_mesh((2,), ('i',), axis_types=(AxisType.Explicit,))
jax.set_mesh(mesh2)
array2 = jax.device_put(jnp.ones((8,)), jax.P("i"))
print(array2.aval.sharding)
will print:
NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Auto,), device_kind=cpu, num_cores=None), spec=PartitionSpec(None,))
NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Explicit,), device_kind=cpu, num_cores=None), spec=PartitionSpec('i',))
The new jax.make_mesh API defaults to Explicit, but the direct jax.sharding.Mesh constructor defaults to Auto. Beware!
The name of a type in JAX is AbstractValue. The important concrete subclass of this type is ShapedArray, whose type I have simplified below:
class ShapedArray(AbstractValue):
    shape: tuple[int, ...]
    dtype: dtype
    weak_type: bool  # when True, don't force type promotion
    sharding: NamedSharding
    vma: frozenset[MeshAxisName]
    memory_space: MemorySpace  # see https://github.com/jax-ml/jax/pull/30556
You can see ordinary things like shape and dtype, as well as some weirder things. In this post, we are going to focus on sharding and vma. Here are simplified definitions for NamedSharding and PartitionSpec:
MeshAxisName = str

class NamedSharding:
    mesh: Mesh
    spec: PartitionSpec

class PartitionSpec(tuple[None | MeshAxisName | tuple[MeshAxisName, ...], ...]):
    unreduced: frozenset[MeshAxisName]
    reduced: frozenset[MeshAxisName]
Let us now describe these types in more detail. For each, we will ask:
- What does it mean?
- Is the type applicable for global SPMD or local SPMD (or both?)
- How does the type propagate across operations?
- How is the type transformed by autograd (what is the mapping from primal to cotangent type?)
PartitionSpec is the most user visible concept. In this section, we’ll ignore unreduced/reduced for now. Without those fields, it is simply a tuple with one entry per dimension of the array (in PyTorch I often refer to this as tensor-oriented sharding, as opposed to mesh-oriented sharding). For each array dimension, you specify which mesh axes shard it. There can be zero (None), one ("i") or many (("i", "j")) mesh axes sharding a dimension; sharding is applied from left to right. You recover the global view (i.e., the one whose shape is described in ShapedArray) by stacking the arrays distributed across those mesh axes. PartitionSpec is commonly abbreviated as just P in JAX code. When you print the type of an array in explicit mode, JAX will inline the partition spec into the shape: e.g., float32[8,16@tp] implies a PartitionSpec of P(None, "tp"). There’s a much longer description with pretty pictures at Distributed arrays and automatic parallelization.
PartitionSpecs propagate according to shard propagation rules, which must be defined on a per-operator basis. If you want to perform an operation without performing communication, it must be the case that running that operation locally on the sharded tensors and then stacking the results (per the output sharding) would be the same as stacking the arrays first (per the input sharding) and then running the operation globally. The output sharding is not always the same as the input sharding, and the shard propagation rule is also responsible for computing this output sharding.
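Here is a small sketch of these rules in practice (explicit mode, two devices on a mesh axis 'i'; the expected outputs in the comments are my reading of the rules above, not captured output):

import jax
import jax.numpy as jnp

jax.config.update('jax_num_cpu_devices', 2)
jax.set_mesh(jax.make_mesh((2,), ('i',)))

x = jax.device_put(jnp.ones((8, 4)), jax.P('i', None))   # sharded on dim 0
w = jax.device_put(jnp.ones((4, 4)), jax.P(None, None))  # replicated

print(jax.typeof(x * 2))   # elementwise op: sharding passes through, still 8@i on dim 0
print(jax.typeof(x @ w))   # contraction over a replicated dim: output stays sharded, no comms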
How does PartitionSpec interact with autograd? The PartitionSpec of primals and cotangents matches: a value that is replicated in forwards, will also be replicated in backwards; similarly, if it is sharded in forwards, it will be sharded in backwards. (Reduced/unreduced will be an exception to this rule, discussed below.)
I want to take a moment to discuss a subtlety of PartitionSpec with respect to shard_map. As I discussed in Global vs Local SPMD, inside of a shard_map region, it’s not really well defined what the global shape of an array is: you only have to specify a PartitionSpec on the inputs and outputs. By default, inside of a shard_map, the shape is the local shape of an array, the mesh says your axis is Manual, and the PartitionSpec says there is no sharding on the array anymore (perhaps confusingly, since None here in the global SPMD view would imply it’s replicated, but that’s not at all the meaning here).
As a small example:
import jax
import jax.numpy as jnp

jax.config.update('jax_num_cpu_devices', 2)
jax.set_mesh(jax.make_mesh((2,), ('i',)))

x = jax.device_put(jnp.ones((8,)), jax.P("i"))

@jax.shard_map(out_specs=jax.P("i"))
def f(x_local):
    print(jax.typeof(x_local))
    print(jax.typeof(x_local).sharding)
    return x_local

f(x)
prints:
float32[4]{V:i}
NamedSharding(mesh=AbstractMesh('i': 2, axis_types=(Manual,), device_kind=cpu, num_cores=None), spec=PartitionSpec(None,))
The PartitionSpec is along for the ride, but it no longer contains sharding information for the manual axes (as you would expect.) (Wondering about the {V:i}? See the next section.) Another subtlety is that JAX allows for mixing Manual mode with not-Manual mode, via axis_names. So you can actually potentially see nontrivial PartitionSpec inside of a shard_map region! The footnote has a worked example.
VMA is short for “varying manual axes”. The motivation for this feature is described in Efficient transposition of replication-inducing collectives and the feature has been around long enough that the shard_map docs have a section about it.
Unlike PartitionSpec, VMA is a shard_map only concept; as its name suggests, it only applies to Manual axes. VMA tracks whether values are varying or invariant across a given mesh dimension. Actually, we can think of VMA as an approximation of PartitionSpec. PartitionSpec tells us exactly how to stack the local arrays to form a global array when they vary across a mesh dimension; if a mesh axis does not appear in the PartitionSpec (the relevant array dims are None), it instead says the array is invariant across that mesh dimension (no stacking needed, everything is the same!) With VMA, we don’t know how to stack the arrays (because the whole point of local SPMD is that the global view is undefined), but we do know if they are the same or different across a mesh dimension. You don’t need to track VMA for non-Manual axes, because PartitionSpec subsumes it. In the print of jax.typeof, the {V:i} indicates all of the varying mesh axes (V is for varying). When all axes are invariant, we just elide this from the print.
Because VMA is an approximation of PartitionSpec, its propagation rules are simpler as well. For non-communication ops, it is simply required that the VMA of all inputs match exactly. If one input is invariant while another is varying, JAX will insert an implicit conversion from invariant to varying to make the VMA match. In Megatron via shard_map, I have a worked example of this, where VMA is how JAX triggers an all-reduce on gradients from the TP region into a replicated grad_input.
How does VMA interact with autograd? The VMA of primals and cotangents matches: a value that is varying in forwards, will also be varying in backwards; similarly, if it is invariant in forwards, it will be invariant in backwards. If you are skeptical that forcing this constraint is an efficient thing to do, you would be right! (See reduced/unreduced below.)
Collective operations are more complex regarding VMA, because we must reason about what happens to the variance of the local tensors before and after the collective. At the bottom of this section there is a table of how all the collectives affect variance. Beyond the standard collectives, VMA also introduces the necessity to represent a conversion from invariant to varying (which is a no-op in forwards and an all-reduce in backwards.) As a type system nerd, I think this is a very cool use of type systems to rule out illegal programs (that forget to do required all-reduces in backwards!)
One side note: VMA is not actually required for correctness, and you can disable this type system with check_vma=False. It was introduced as a way to make it possible to write more efficient programs that were impossible to write without it. Without VMA, JAX conservatively assumes that all axes are potentially varying, and it will insert extra collectives to ensure you get the correct result. By actually modeling VMA, we can tell when something is invariant and potentially skip this collective.
The unreduced and reduced fields in PartitionSpec are quite new and are not documented in the public JAX documentation. However, we think of them as quite important when doing work with explicit sharding (for example, Wanchao, one of the original authors of DTensor, has told me the need to represent Partial placements is one of the big reasons why DTensor has mesh-oriented placements rather than tensor-oriented placements.) Unlike PartitionSpec/VMA, these fields apply for both Explicit and Manual axes.
It is easiest to first describe what unreduced means. Unreduced means there is a pending reduction (summation) on a device mesh axis that is necessary to get the global view of the array. Outside of shard_map, the most common way to generate an unreduced array in JAX is to use jnp.einsum with out_sharding that specifies unreduced. For example:
import jax
import jax.numpy as jnp
jax.config.update('jax_num_cpu_devices', 2)
jax.set_mesh(jax.make_mesh((2,), ('i',)))
x = jax.device_put(jnp.ones((4, 8,)), jax.P(None, "i"))
y = jax.device_put(jnp.ones((4, 8,)), jax.P(None, "i"))
u = jnp.einsum('bx,bx->b', x, y, out_sharding=jax.P(unreduced={'i'}))
print("u", jax.typeof(u))
prints:
float32[4]{U:i}
As I described in Computing sharding with einsum, when you do a contraction on two dimensions that are sharded on the same mesh axis, this can be done locally as long as you remember that there is a pending reduction (aka, partial/unreduced) that you need to do across that mesh axis to get the final result. Prior to the existence of unreduced in JAX, it wasn’t possible to express that you wanted no communication: the output sharding could only express replicated/sharded states, so you were going to end up doing an all-reduce or reduce-scatter.
You can also have unreduced axes in shard_map. For example, you can cast a varying array into an unreduced array, and then trigger the reduction. (Warning: unreduced doesn’t work with shard_map without jax.jit, see issue #34684)
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update('jax_num_cpu_devices', 2)
jax.set_mesh(jax.make_mesh((2,), ('i',)))

x = jax.device_put(jnp.ones((8,)), jax.P("i"))

@jax.jit
@jax.shard_map(out_specs=jax.P(None))
def f(x_local):
    u = lax.pcast(x_local, to='unreduced', axis_name='i')
    print(jax.typeof(u))
    return lax.psum(u, axis_name='i')

f(x)
prints:
float32[4]{U:i}
How does unreduced propagate? Alas, this once again is something you have to define on a per operator basis, like in sharding propagation rules. One rule of thumb is that linear functions can always propagate unreduced, since linearity means f(x + y) == f(x) + f(y), but at time of writing JAX just writes all the unreduced rules out by hand.
How does unreduced interact with VMA and PartitionSpec? It is mutually exclusive from varying/sharded. If an array is unreduced on a mesh axis, it cannot also be varying or sharded on that mesh axis. (Technically, one might argue that if something is unreduced on a mesh axis, it is obviously varying on that axis, but varying and unreduced need to be treated differently for AD purposes, so it’s best to keep these distinct.)
Inside of shard_map, there are a few functions for working with unreduced:
- lax.pcast will let you directly convert a varying axis into an unreduced axis, declaring your intent to do a reduction later. However, you can’t do the “no-op” pcast from unreduced to varying, because this is not well-defined from a global semantics perspective: in general it is not meaningful to define a function as “take the input, decompose it into x + y, and then run an arbitrary function on x and y individually.”
- The existing reduction collectives like lax.psum and lax.psum_scatter will accept axes that are varying as well as unreduced, and do the obvious thing.
So what is reduced? As described in the original PR to make unreduced + AD work, reduced is like replicate, but it causes the cotangent (gradient) to be unreduced (and vice versa). Remember that in sharding with types, the cotangent sharding is always a function of the primal sharding. When you have a replicated primal, it is ambiguous whether you want the cotangent to be replicated or unreduced, so JAX introduces a new type (reduced) to let you distinguish the two cases just by looking at the primal type. The rule in JAX is replicate goes to replicate, and reduced goes to unreduced (and vice versa). Like unreduced, reduced is tracked both in Explicit and Manual mode.
How can you interact with reduced shardings? Unlike unreduced, you can directly device_put some data as reduced (since reduced is the same as replicated), or jax.reshard a tensor to a reduced placement. A transition from invariant/replicated to reduced is a no-op in forwards but triggers an all-reduce in backwards; similarly, you can pcast reduced to varying, which is a no-op for both forwards and backwards (invariant to varying forces an all-reduce immediately, since invariant’s cotangent is invariant, but reduced can delay the all-reduce since its cotangent type is unreduced). Separately, inside of a shard_map, lax.all_gather can be instructed to directly go to reduced (which will result in a reduce-scatter in backwards.)
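A small sketch of the transitions described above (two devices on a mesh axis 'i', as elsewhere on this page; a sketch under those assumptions, not captured output):

# device_put directly as reduced: the data is stored as if replicated, but its
# cotangent will be unreduced over 'i'.
w = jax.device_put(jnp.ones((4, 4)), jax.P(None, None, reduced={'i'}))

# reshard from replicated to reduced: a no-op in forwards, an all-reduce in backwards.
r = jax.device_put(jnp.ones((4, 4)), jax.P(None, None))
rw = jax.reshard(r, jax.P(None, None, reduced={'i'}))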
How does reduced propagate? The propagation rule is simple: reduced is a statement about a set of mesh axes (not array dims), so we simply pass through the set of reduced axes whenever we do an operation. If an operation is N-ary, it’s required that all inputs are reduced on the same mesh axes. Unlike replicate, inputs cannot be replicated or sharded on a reduced mesh axis, and JAX will force you to add conversions to make it typecheck. (It’s actually not clear to me that there isn’t a good default choice for implicit conversions here, but it’s certainly a lot safer to not allow it to start!)
January 27, 2026

Global SPMD (also known as the “global view”, exposed by code using DTensor or jax.Array) refers to writing multi-device code as if it was on a single device, with an orthogonal mechanism for expressing how these full tensors are distributed over multiple devices (this mechanism can be implicit or explicit, e.g., as seen in this table).
Local SPMD (also known as “per-device view”, and exposed by local_map and shard_map, and also traditional PyTorch distributed code operating on plain Tensors, e.g., Megatron-style) refers to writing code from the “local” view on a single device, with explicit collectives when communicating across devices.
The big question I want to address in this post is, how do I pick between these two modes? Conventional wisdom in the JAX ecosystem is that you should default to global SPMD (either via auto or explicit sharding mode), and drop down to manual local SPMD if the compiler isn’t giving you the correct communications and you want to do it by hand. I want to give a more nuanced version of this take.
First, there is nothing about global SPMD that precludes fine-grained control over when collectives happen. JAX doesn’t directly support this kind of mode, but it’s not difficult to imagine how it would work: take JAX with explicit mesh axes, but instead of only erroring out on ambiguous communication, error out when any implicit communication happens (e.g., you must explicitly call reshard to trigger communication). We actually added an explicit mode to DTensor along these lines, although it currently doesn’t work super well because we lack some other important aspects of JAX sharding in types.
For me, the more important difference is that global and local SPMD are actually different semantics. An obvious divergence is that in local SPMD, there isn’t any source of truth about what the “global” view of the Tensor is: the local tensor just exists, you know that there are different versions of it on the other nodes. You don’t know how you’re supposed to stack these tensors together to get a global tensor: you typically only know this at the boundary of the local region using out_specs / out_placements. And even if you knew how to stack the tensors together, local SPMD has different semantics than global SPMD, as the exact computation you perform depends on how exactly the local tensor is sharded. You’re not doing an operation on the global tensor: you’re chunking the tensor, running the operation on each chunk, and then stacking it back together. The whole point of sharding propagation in global SPMD is to figure out if this is equivalent to running the operation on the full tensor, and there are many cases when it is not.
If you are not thinking carefully about your distributed computation, local SPMD can be a source of bugs. It is common to write distributed code where certain parallelisms are enabled or disabled. If you do a reduction over an axis, when that axis is replicated the result is replicated, but when it is sharded you will end up with a partial reduction that has to be accounted for in some other way. If you forget, the code will work when the parallelism is turned off and silently break when the parallelism is turned on. A bug like this is horrible enough that frameworks invest in ways to deal with this situation.
This is perhaps the reason why Megatron is sometimes considered unfriendly for experimentation. Everything is written in local SPMD (as it doesn’t use DTensor), and if you want to experiment on something new you must upfront resolve all of the interactions with parallelism in the implementation of your code. This is all doable, but it can be pretty confusing if you are not a parallelism expert and easy to get wrong.
There is a flip side to this, however: if you are thinking carefully about your parallelism and are carefully orchestrating your local compute with your communications, it is much more natural to write things in local SPMD style. The local SPMD style only gives you operations that can be efficiently computed (since they are always local) and doesn’t force you to say what the global interpretation of a Tensor is (especially when it’s unnatural, like online softmax.) So once you get out of the experimentation phase and are working on efficiency, if you need some nontrivial communication pattern, it would be pretty normal to switch from global SPMD to local SPMD. But there’s also a lot of pedestrian modules that don’t need anything fancy, and it is better to keep them in global SPMD in that case.
In the PyTorch ecosystem, there are some more low level reasons why you might prefer local SPMD over global SPMD. The most prominent is DTensor’s eager overhead. Many parts of DTensor are implemented in Python rather than C++, and on the first invocation we must compute shard propagation rules, which is entirely in Python and quite expensive. It is possible to get reasonable performance with DTensor: if you torch.compile you can eliminate the overhead entirely, CUDA graphs also work, and FSDP2 shows that careful, minimal use of DTensor can still have acceptable CPU overhead. But this is perhaps one of the big reasons why distributed code with plain Tensors remains quite popular today.
January 26, 2026

In Computing sharding with einsum, we worked an example of Megatron style tensor parallelism where we discover that the ordinary backwards formula for linear results in a pending reduction on grad_input, even though the input was replicated and no communications happened in forwards. In Megatron, which is implemented with plain Tensors and manual collectives, you just have to know that this reduction is necessary and manually insert it with a custom autograd function.
If we wanted to write a similar explicit-collective-style Megatron implementation in JAX, we would use shard_map. Like in Megatron, you have to call a function which is a no-op in forwards and an all-reduce in backwards. However, JAX has built this function into its core library, and has an interesting name for it: jax.lax.pcast(..., to='varying') (previously known as pvary, although apparently this is deprecated now.) Why is this a “cast”? The answer is that JAX’s shard_map actually comes with an optional type system (enabled with check_vma=True) which rejects your program if you forget to insert an all-reduce in your backwards!
Let’s see how this works. As a reminder, our shapes are:
input: [sequence, batch, in_features]
weight: [in_features, out_features]
output: [sequence, batch, out_features]
We can describe the input and output sharding of these in JAX style (tensor-dim oriented), which is what we would feed into the in_specs and out_specs of shard_map:
input: P(None, None, None)
weight: P(None, "tp")
output: P(None, None, "tp")
Although on the boundaries of shard_map we can say exactly what the sharding of the input/output tensors are (e.g., how to reassemble them back into full tensors), on the inside of a shard_map this is not a well-defined question: you only have the local tensors and can do whatever you want with them before you reassemble them back into global tensors with shardings.
However, when check_vma=True, JAX will still keep track of something weaker than sharding: whether or not the tensors are varying (i.e., different) across a mesh dimension. This is conventionally notated as dtype[local_shape]{varying axes}, e.g., f32[3]{i} means that this tensor varies across the mesh axis i (the braces are omitted when nothing varies). Let’s write down the local shapes of our input/output tensors with variance information (note that the sharded tensor dimensions have shapes that are divided by the mesh they are sharded over):
input: f32[sequence, batch, in_features] # nothing varying
weight: f32[in_features, out_features/tp]@{tp} # varies over tp dim
output: f32[sequence, batch, out_features/tp]@{tp} # varies over tp dim
The point of a type system is that you have typing rules that say whether or not an operation is legal between two types, rejecting programs that are ill-typed. In particular, in jaxpr it’s illegal to have an operation like matrix multiply between two tensors with differing VMA: you have to insert a cast to make the VMAs match before you can do the operation. Actually, JAX will typically insert these casts implicitly for you, but for clarity we’re going to insert the cast explicitly here:
import jax
import jax.numpy as jnp
jax.config.update('jax_num_cpu_devices', 2)
jax.set_mesh(jax.make_mesh((2,), ('tp',)))
sequence, batch, in_features, out_features = 4, 2, 8, 16
input = jnp.ones((sequence, batch, in_features))
# NB: the full weight, shard_map will automatically shard weight for us given the
# input spec
weight = jnp.ones((in_features, out_features))
@jax.shard_map(in_specs=(jax.P(None, None, None), jax.P(None, "tp")), out_specs=jax.P(None, None, "tp"))
def colwise_linear(input, weight):
    print('input', jax.typeof(input))
    print('weight', jax.typeof(weight))
    input = jax.lax.pcast(input, "tp", to="varying")
    print('pcast_input', jax.typeof(input))
    output = jnp.einsum("sbi,io->sbo", input, weight)
    print('output', jax.typeof(output))
    return output
output = colwise_linear(input, weight)
This prints:
input float32[4,2,8]
weight float32[8,8]{V:tp}
pcast_input float32[4,2,8]{V:tp}
output float32[4,2,8]{V:tp}
It’s a little difficult to see why the program “doesn’t typecheck” when you omit the pcast, because even if you don’t pcast, JAX will implicitly insert it for you; but you can verify that the cast happens anyway by inspecting the HLO with and without the explicit pcast (this is left as an exercise to the reader; an LLM can one-shot this program transformation). The type system here is also not that invasive: local operations simply propagate variance (varying to varying, invariant to invariant), and you only need a small menagerie of collective operations to take you between varying and invariant.
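If you want to try the exercise, one way to dump the lowered IR (building on the colwise_linear example above; the exact textual output depends on your JAX version) is via jit’s lowering API:
# Dump the lowered HLO so you can diff the program with and without the explicit pcast.
print(jax.jit(colwise_linear).lower(input, weight).as_text())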
Although JAX’s type system here is a bit difficult to explain from first principles, it seems quite beneficial as it makes it impossible to forget backwards reductions that are required when you do operations between sharded and replicated tensors. We’re hoping to bring a similar capability to PyTorch DTensor in the near future, introducing a new “LTensor” subclass that is able to track metadata along these lines.
January 25, 2026Mental arithmetic in grade school (e.g., memorizing your times tables) is typically justified on the grounds that facility in basic calculations makes it easier to focus on higher-level problems that require being able to do these manipulations. When working on DTensor, I have also found it important to be able to quickly calculate what shardings you get when you do matrix multiplies on sharded tensors. Without being able to do this quickly and accurately, working through examples becomes a slog. I’ve also found that while diagrammatic approaches (e.g., drawing a matrix and slicing it into shards) are intuitive, they are slow and unwieldy to do calculations with.
Recently, I’ve found that working on sharding with einsum is nice and efficient, and I hope to persuade you to do it this way when you need to reason about sharding! This post somewhat overlaps with Sharded Matrices and How to Multiply Them, but with some different emphasis and some different notation.
Einstein summation is a compact way of representing many multi-dimensional linear algebra operations, including matrix multiplies. It is nice because you don’t have to puzzle through the abstruse differences between matrix multiply operations like @, torch.matmul, torch.bmm, torch.mm: for any “matrix multiply”, as long as you know the input and output shapes of your tensors, you can directly write out an einsum equation. For example, classic matrix multiply as you see it in math has a signature like mm(x: f32[A, B], y: f32[B, C]) -> f32[A, C]. In einsum notation, you would simply write torch.einsum("ij,jk->ik", x, y): each of the indices lines up exactly with the input sizes. As another example, in nn.Linear, your weight has shape (out_features, in_features). You don’t have to remember how to set up the transposition; just write torch.einsum("bi,oi->bo", input, weight).
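As a quick sanity check (a sketch with made-up sizes), the einsum spelling agrees with the functional form of nn.Linear:
import torch
import torch.nn.functional as F
batch, in_features, out_features = 3, 5, 7
input = torch.randn(batch, in_features)
weight = torch.randn(out_features, in_features)        # nn.Linear's weight layout
out_functional = F.linear(input, weight)               # input @ weight.T
out_einsum = torch.einsum("bi,oi->bo", input, weight)  # no transpose bookkeeping
assert torch.allclose(out_functional, out_einsum, atol=1e-6)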
A useful piece of terminology that pops up for einsum is a contraction dimension. This is any index that appears in the input tensors but not the output tensors. The ones that show up in both inputs and outputs are free dimensions: if the free dimension is in all inputs it’s a batch dimension, and if it’s missing from some inputs we will broadcast those tensors.
Do you always forget how exactly you should transpose your tensors in the backward formula for matrix multiply? As long as you aren’t doing weird things in your einsum (e.g., no repeated indices, every input index is paired with another index), there is a very simple way to compute backwards: keep every input constant except the one you want to compute the gradient for, and swap its index set with the output index set.
For example, linear is "bi,oi->bo" for (input, weight -> output). Then we have:
grad_input = torch.einsum("bo,oi->bi", grad_output, weight)
grad_weight = torch.einsum("bi,bo->oi", input, grad_output)
Intuitively, the reason this works is because reverse-mode AD actually is just transposing the linear function defined by our einsum, and transposed matrix multiplies can be implemented by just reading off its shapes.
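Here is a small autograd check of the swap rule for the linear example (a sketch with arbitrary sizes):
import torch
batch, in_features, out_features = 3, 5, 7
input = torch.randn(batch, in_features, requires_grad=True)
weight = torch.randn(out_features, in_features, requires_grad=True)
output = torch.einsum("bi,oi->bo", input, weight)
grad_output = torch.randn_like(output)
output.backward(grad_output)
# Swap each input's index set with the output's index set:
grad_input = torch.einsum("bo,oi->bi", grad_output, weight)
grad_weight = torch.einsum("bi,bo->oi", input, grad_output)
assert torch.allclose(input.grad, grad_input, atol=1e-5)
assert torch.allclose(weight.grad, grad_weight, atol=1e-5)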
Now that we’re thinking in terms of einsum formulas, all we need is the sharding rule for einsum. The sharding rule tells us under what situations we can perform a matrix multiply by simply doing matrix multiplies on the local shards, producing the output matrix under some output placement.
There are not too many rules. Taking "abi,aoi->abo" as a running example, we can write down these valid placements for a particular mesh dimension (I’ve replaced numeric dim indices with the einsum character index for readability):
- If everything is replicated, the output is replicated: Replicate(), Replicate() -> Replicate()
- If a batch dimension is sharded, the output batch dimension is also sharded: Shard("a"), Shard("a") -> Shard("a")
- If a free dimension is sharded, the output free dimension is sharded, but any broadcasted input must be replicated: Shard("b"), Replicate() -> Shard("b")
- If a contraction dimension is sharded, we will have a pending reduction: Shard("i"), Shard("i") -> Partial()
You can look at Computation With Sharded Arrays for a more detailed explanation for each of these cases.
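If it helps, here is a toy implementation of these rules for a single mesh dimension (my own sketch, not DTensor’s actual propagation code); each input’s placement is either "R" for replicated or the einsum character that is sharded on that mesh dimension:
def einsum_output_placement(equation, placements):
    # placements: one per einsum input; "R" means replicated on this mesh dim,
    # otherwise the einsum index character that is sharded on this mesh dim.
    lhs, out = equation.replace(" ", "").split("->")
    ins = lhs.split(",")
    sharded = {p for p in placements if p != "R"}
    if not sharded:
        return 'Replicate()'              # everything replicated
    assert len(sharded) == 1, "sketch only handles one sharded index per mesh dim"
    c = sharded.pop()
    for spec, p in zip(ins, placements):
        # any input that mentions c must be sharded on it; broadcasted inputs stay "R"
        assert (c in spec) == (p == c), "placements don't match one of the four rules"
    if c in out:
        return f'Shard("{c}")'            # batch or free dimension sharded
    return 'Partial()'                    # contraction dimension sharded
eq = "abi,aoi->abo"
print(einsum_output_placement(eq, ["R", "R"]))   # Replicate()
print(einsum_output_placement(eq, ["a", "a"]))   # Shard("a")
print(einsum_output_placement(eq, ["b", "R"]))   # Shard("b")
print(einsum_output_placement(eq, ["i", "i"]))   # Partial()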
In 2019, Xiaolin Li asked this question about CopyToModelParallelRegion in Megatron:
Why the backward function of _CopyToModelParallelRegion calls reduce fuction? Can somebody share the mathematical proof?
Let’s answer Xiaolin’s question. In Megatron, ColumnParallelLinear is defined as:
input: [sequence, batch, in_features]
weight: [in_features, out_features]
output: [sequence, batch, out_features]
In einsum notation, this is torch.einsum("sbi,io->sbo", input, weight).
On the TP mesh dimension, we have this sharding:
input: Replicate()
weight: Shard("out_features")
output: Shard("out_features")
Let us assume that grad_output: Shard("out_features"). Let’s compute the placements of grad_weight and grad_input. First the derivative formulas:
grad_input = torch.einsum("sbo,io->sbi", grad_output, weight)
grad_weight = torch.einsum("sbi,sbo->io", input, grad_output)
So we see:
grad_input: Partial() # o is sharded and a contraction dim
grad_weight: Shard("out_features") # o is sharded and a free dim
We see that grad_input has a pending reduction, and if downstream backwards is expecting to receive replicated tensors, we must trigger an all-reduce (e.g., in Megatron this all-reduce is manually triggered by _CopyToModelParallelRegion; if you use DTensor, it will just propagate the Partial() until a redistribution to Replicate() is required.)
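To double-check numerically, here is a sketch that fakes tp=2 by chunking the tensors, and confirms that grad_input only comes out right after summing (i.e., all-reducing) the per-shard partials:
import torch
sequence, batch, in_features, out_features, tp = 4, 2, 8, 16, 2
weight = torch.randn(in_features, out_features)
grad_output = torch.randn(sequence, batch, out_features)
# Reference: the unsharded grad_input
grad_input_full = torch.einsum("sbo,io->sbi", grad_output, weight)
# Fake tp=2: shard weight and grad_output along out_features
w_shards = weight.chunk(tp, dim=1)
go_shards = grad_output.chunk(tp, dim=2)
# Each rank computes grad_input from only its local shards: a Partial()
partials = [torch.einsum("sbo,io->sbi", go, w) for go, w in zip(go_shards, w_shards)]
assert not torch.allclose(partials[0], grad_input_full)           # a single shard is wrong
assert torch.allclose(sum(partials), grad_input_full, atol=1e-5)  # the all-reduce fixes it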
In sequence parallel, we will shard the sequence dimension of an input, but not the weight. Let’s say we have a learnable scaling factor:
input: [sequence, batch, hidden]
weight: [hidden]
output: [sequence, batch, hidden]
In einsum notation, this is torch.einsum("sbh,h->sbh", input, weight).
On the SP mesh dimension, we have this sharding:
input: Shard("sequence")
weight: Replicate()
output: Shard("sequence")
Then we have:
grad_input = torch.einsum("sbh,h->sbh", grad_output, weight)
grad_weight = torch.einsum("sbh,sbh->h", input, grad_output)
So we see:
grad_input: Shard("sequence") # s is sharded and a free dim
grad_weight: Partial() # s is sharded and a contraction dim
Here, we must do an all-reduce over grad_weight to get the true replicated gradient.
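The same kind of numerical check works here (a sketch faking sp=2 by chunking the sequence dimension):
import torch
sequence, batch, hidden, sp = 8, 2, 4, 2
input = torch.randn(sequence, batch, hidden)
grad_output = torch.randn(sequence, batch, hidden)
grad_weight_full = torch.einsum("sbh,sbh->h", input, grad_output)
# Each rank only sees its chunk of the sequence dimension: a Partial()
partials = [
    torch.einsum("sbh,sbh->h", i, go)
    for i, go in zip(input.chunk(sp, dim=0), grad_output.chunk(sp, dim=0))
]
assert torch.allclose(sum(partials), grad_weight_full, atol=1e-5)  # all-reduce over the SP group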
Notice that this example is very similar to the tensor parallelism one, but the roles of input and weight have been swapped!
January 4, 2026This blog has lived on WordPress since it was initially created during a social challenge at MIT to write a blog post a week or pay up with beer. I remember a very important piece of advice I had been given at that time: don’t fuck around with your blog authoring software, just do the minimum viable thing (use WordPress) and focus on writing posts.
It’s 2026 now, the world is different, and in particular the existence of coding agents means that this particular advice falls flat: it has never been easier to vibe code your own blog software and be done in an afternoon of token generation. Similarly, over the years I had become increasingly unhappy with my WordPress setup (too hard to add images, ancient version of WordPress, Markdown has taken over the world so why am I still writing in ReST, I love scripts.mit.edu but I definitely don’t want to use it to host serious things). So I typed this into ChatGPT and Claude and asked them what I should migrate to.
I currently have a Wordpress blog whose 633 posts are written in ReST using rest-wordpress with some manual code edits, and a theme based on Ashley that I also customized. I’d like to migrate to another blogging solution. I care a lot about ensuring the URLs are preserved. To a lesser extent, I also care about the comments, although I’m willing to compromise here (e.g., an offline flow where I have to explicitly publish comments might be OK; I know static site is difficult to support comments; I also know that email newsletter is popular and I’d like to support this modality if possible. I don’t use a WYSIWYG editor. It’s on Wordpress 5.1.19. It would be nice to have a way for people. Some more niche things plugins I’ve used is WP LaTeX and Share a Draft but I’m willing to do a lossy conversion if necessary (I don’t use LaTeX that much now; it’s just important to make sure the old posts still format correctly). Many of my posts have images and I’d like an easier flow than my current flow (where I have to manually upload my images to my server and then hyperlink them into the post). What do you recommend?
It suggested Hugo, which I had played around with before in AI Blindspots, and I figured, “Why not, I’ll just ask Claude to do the migration.” A few hours, two pro sessions’ worth of tokens, and some PHP export scripts later, the entire blog was moved over, no muss, no fuss. I live streamed a portion of this migration process, although there’s nothing that special about it.
I actually wasn’t going to write a blog post about this, but then I saw Jeff Geerling’s blog had also made the front page of Hacker News. I too haven’t figured out how I am going to solve the comments problem on the new format; I also still intend to figure out how to get an email newsletter going from the blog. Here’s hoping this encourages you to use LLMs to make the jump for your own personal site!
January 4, 2026Let’s suppose you asked an AI coding agent to “implement a CLI calculator”. Imagine if, instead of only writing a short Python script, it also started building an automated test suite, a crash reporting mechanism, and a telemetry subsystem. You’d be like, “What the fuck is this?”
But now let’s say that you were planning to release this project to users. It would be clearly negligent to not have an automated test suite. A crash reporting mechanism might be overkill for a simple calculator, but for more complicated CLIs interacting with the real world, it may not always be feasible to have a reproducer, in which case crash logs are essential. Similarly, a telemetry subsystem would be wildly inappropriate for an open source local-only calculator, but it could make sense for a networked application or a corporate tool whose users have all consented. One of the important functions of a senior engineer is to be able to evaluate the context a software project lives in and figure out if we need to do something, even if it isn’t explicitly asked for. This is in contrast to a helpful assistant, who is first and foremost obligated to follow the user’s instructions. This leads to a gap between a Helpful Assistant and a Senior Engineer.
In principle, you could prompt the LLM agent to act like a Senior Engineer. In fact, why stop at Senior, let’s tell the LLM to be a Staff Engineer! Imagine that scaling continues: what would you expect the LLM to do when instructed to act in this way? Well, imagine a human L7 engineer who has just been hired by a big tech company to head up some big, new, multi-year initiative. Will they say, “Sure, I can help with that!” and start busily coding away? Of course not: they will go out and start reviewing code, reading docs, talking to people, asking questions, shadowing oncalls, doing small starter tasks–they will start by going out and building context. Here, the “helpful assistant” frame for LLMs is limiting: sure, Claude might ask you a few questions to clarify the task upfront, but if your coding agent starts asking you about “relevant partner teams” and “org-wide priorities for this half” you are definitely going to raise an eyebrow.
What would it take for an LLM to be able to act like a Senior Engineer?
Perhaps prompting is all you need: just write enough information about the surrounding context for a project, and once you feed in enough tokens, a smart model can infer the rest of the details you didn’t explicitly write down. But this context would be bespoke for every project; you would have to redo the exercise every time you had a new project!
Perhaps you can instead prompt a model on how to operate agentically to get the context it needs. This prompt might be more reusable.
But the model may need to actually do wetwork (e.g., talk to humans) to get all of the information it needs. And remember the old saying: the more generic the advice is, the less useful it is. Specificity is king, which leads to…
Let’s say we solve continual learning. Instead of crafting the perfect prompt upfront, you could just drop the model in as an “embodied” software developer. It reads code, talks to people, does projects, and in doing so slowly develops its latent context, in the same way a human engineer does. Building context will often be bottlenecked in the same way it is for humans: you can’t get experience pushing a feature to production until you’ve actually pushed the feature to production (however long that takes).
But just like how you shouldn’t micromanage a Senior Engineer, all of these approaches involve fundamentally different expectations about what an AI coding agent should do, and so even if a model and scaffold are capable of doing these things, it is another question altogether whether it will be asked to behave in this way. So let’s not take it as a foregone conclusion that METR task times will keep following the empirical trendline: I expect a phase transition when the context an LLM needs to do a good job exceeds what scaffolding can provide on the fly.
December 20, 2025I’ve recently been doing a lot of both submitting and reviewing pull requests to PyTorch that were authored with substantial LLM assistance. This is a big difference from earlier this year, when it was clear LLMs worked well for greenfield projects but the code was too hopelessly sloppy for a production codebase. Here are my merged PRs that mention claude code in their description; Jason Ansel has also had a similar experience (Meta only link, here is the list of issues he referenced in his writeup). There has already been increasing discourse (Simon Willison, LLVM) on how code review should adapt to this new era of LLMs. My contribution to this discourse is this: within teams, code review should change to being primarily a human alignment mechanism.
Here is a simple example: it is well known that LLMs are prone to generating overly defensive code: e.g., they will be constantly sprinkling try...catch everywhere or testing if a variable is some type when system invariants imply that it should always be that type. If someone sends me a PR with these problems, I am not commenting on these problems solely because I want them to be fixed. If that’s all I cared about, I could have just fed my comments directly to claude code. The real problem is that the human who was operating the LLM didn’t agree with me that this defensive code was bad, and the point of the review is to align them with me on what is overly defensive versus not. In the most trivial cases, maybe the engineer didn’t read the LLM output, in which case the remedy is to make them actually read the code. But sometimes real human work has to happen; for example, maybe there is a global system invariant that one has to understand to know if the defensiveness is necessary or not. If we agree about the global system invariants, there’s no reason the code review has to go through me: the original code author can just instruct the LLM to fix problems and keep me out of the loop until they have aligned the LLM output to themselves–at which point we should do the more expensive human to human alignment. The ideal is that I don’t need to ever write review comments about mechanical problems, because they have already been fixed by the original author ahead of time.
Conversely, when I am putting up an LLM generated PR for human review, I am trying to transmit higher level information. How does the new code work? What do I need to know about the existing system to understand this code? This doesn’t even have to be in the PR description: if the LLM proposes a fix that I myself don’t understand, or seems difficult to understand, I will simply instruct it to try it a different way, until the resulting diff is obviously correct. Tokens are cheap: we should expect more out of the author of code, because the cost of generating these PRs has gone way down. Similarly, I am willing to throw out the code and start again; you don’t have to feel bad about wasting my time (I didn’t type it! I spent my time understanding the problem, and none of that is regretted.)
There is a lot of scaremongering about how engineers who don’t pick up AI tools will be left behind. My take on this is that there are a number of different skills that make up what it means to be a good software engineer, and LLM coding, even today, is clearly reweighting the relative importance of these skills. I care a lot more about your ability to read code, reason about the big picture, communicate clearly, and have good taste than I care about your ability to mechanically write code. There is an archetype of junior engineer who is not that good at coding but very good at the softer, higher level skills, and I think they will be very valuable in this new world order. Conversely, I think going forward I will have substantially less patience if I have to keep telling you the same things over and over, because I just don’t value raw “ability to code” as much anymore. My ideal state is like that with long-time senior teammates: I can trust that they have made good low level decisions, and I can focus on understanding the bigger picture and updating my mental model of how the system works.
Today’s LLMs have no memory: they have to rediscover everything in the system from first principles every time they are run. The purpose of the humans, of the team, is to collectively maintain a shared vision of what, platonically, the system should do. I want code review to reconfigure itself around this purpose.