State of torch.compile for training (August 2025)
August 13, 2025

The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. Nothing here is something you couldn’t already find elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large-scale training runs.
First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs for both inference and training workloads. Speedups from 1.5-2x compared to eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).
What is torch.compile’s functionality?
The headline functionality of torch.compile is a decorator you can attach to a function to compile it:
@torch.compile()
def f(x, y):
    ...
Here are some non-functional properties of compile which are important to know:
- Just-in-time compilation. We don’t actually compile the function until it is called for the first time, and execution blocks until compilation completes. There is both local and remote caching to skip compilation cost when you rerun the model. (Ahead-of-time compilation is possible for inference with AOTInductor, and is being worked on for training.)
- Compositional with Eager. PyTorch’s original success comes from the extreme hackability of eager mode, and torch.compile seeks to preserve this. The compiled function can be as big or as small a part of your training loop as you like; compiled functions compose with autograd, DDP, FSDP and other PyTorch subsystems. (This composition is sometimes imperfect: double backwards is not supported, tensor subclasses require specific support from the subclass, and differentiating with respect to intermediates returned from a compiled region does not work.) If compilation doesn’t work on a region, you can disable it entirely with torch.compiler.disable() and fall back to eager.
- Gradient updates are delayed to the end of compiled regions. This arises because PyTorch eager autograd does not support streaming gradients incrementally from a large backward node. (This can be solved by using compiled autograd, but this requires that the entirety of your backwards be compilable.)
- Graphs may be recompiled. We aggressively specialize on all non-Tensor arguments/globals used in the function to ensure we always generate straight-line computation graphs with no control flow. If those arguments/globals change, we will recompile the graph. (Recompilations can be banned with torch._dynamo.config.error_on_recompile = True.)
- Static by default, recompile to dynamic shapes. We aggressively specialize all sizes to static. However, if we discover that a size varies over time, on the first recompile we will attempt to generate a single compiled region that handles dynamic shapes. We are not guaranteed to be able to compile a model with dynamic shapes. (You can use mark_dynamic to force an input shape to be dynamic, and you can use mark_unbacked to error if we specialize; a short sketch follows this list.)
- Graph breaks transparently bypass non-capturable code. By default, if the compiler encounters a line of code that it is not able to handle, it will trigger a graph break, disabling compilation for that line of code, but still attempting to compile regions before and after it. (This behavior can be banned with fullgraph=True.)
- Function calls are inlined and loops are unrolled by default. If you have many copies of a Transformer block in your model, your compile time will scale with the number of Transformer blocks. (You can reduce compile time by doing regional compilation, where you only compile the Transformer block instead of compiling the entire model; see https://docs.pytorch.org/tutorials/recipes/regional_compilation.html and the second sketch after this list.)
- NOT bitwise equivalent with eager PyTorch. The biggest divergence from eager PyTorch is that when float16/bfloat16 operations are fused together, we do not insert redundant down/up-conversions. (This can be disabled with torch._inductor.config.emulate_precision_casts = True; you can also rewrite eager code to perform operations in higher precision with the understanding that torch.compile will optimize it. XLA has a similar config, xla_allow_excess_precision, which JAX enables by default.) However, we may also make decisions to swap out, e.g., matmul implementations, and there may be slight divergences that arise from differences in reduction ordering that are unavoidable when compilation occurs. We support ablating the graph capture frontend separately from the compiler backend to help diagnose these kinds of problems.
- Distributed collectives and DTensor can be compiled, but are unoptimized by default. We are able to capture c10d collectives and also programs that handle DTensors, but we don’t apply optimizations to collectives by default. (There are experimental optimizations that can be enabled, but this is active work in progress.) We generally do not expect to be able to trace through highly optimized distributed framework code.
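To make the recompilation and dynamic-shape behavior above concrete, here is a minimal sketch; the function and shapes are invented for illustration:

import torch

# Fail loudly on recompilation instead of silently paying the compile cost again.
torch._dynamo.config.error_on_recompile = True

@torch.compile()
def f(x):
    return (x * 2).sum()

x = torch.randn(8, 16)
# Mark dim 0 as dynamic up front, so the first compile already handles varying sizes.
torch._dynamo.mark_dynamic(x, 0)
f(x)
f(torch.randn(32, 16))  # reuses the dynamic graph; without mark_dynamic this would recompile (and here, error)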
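And here is a rough sketch of regional compilation; TransformerBlock is a stand-in for whatever block your model actually uses (see the tutorial linked above for the full recipe):

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):  # stand-in for a real Transformer block
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out
        return x + self.mlp(x)

class Model(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.blocks = nn.ModuleList(TransformerBlock(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = Model(dim=256, depth=12)
# Compile each block rather than the whole model: since the blocks share the same
# code, the compiled artifact is reused and compile time does not scale with depth.
for block in model.blocks:
    block.compile()

out = model(torch.randn(2, 128, 256))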
State of advanced parallelism
For large scale training runs, torch.compile faces stiff competition from (1) PyTorch-native distributed frameworks which embrace eager mode and implement all optimizations by hand (e.g., Megatron), (2) custom “compiler” stacks which reuse our tracing mechanisms (e.g., symbolic_trace and make_fx) but implement their desired passes by hand, and (3) JAX, which has always been XLA-first and is years ahead in compiler-driven parallelism techniques.
Here is where we currently are for advanced parallelism (with an emphasis on comparing with JAX):
DTensor, a “global tensor” abstraction for representing sharded tensors. DTensor is a tensor subclass which allows us to represent tensors which are sharded over an SPMD device mesh. The shape of a DTensor reflects the global shape of the original full tensor, but it only stores locally a shard of the data according to the placement. Here are some important details:
- Shard placements. Unlike JAX placements, DTensor placements are “device mesh” oriented; that is to say, you conventionally specify a list of placements, one per device mesh dimension, and Shard(i) indicates that the ith dimension of a tensor is sharded. This is the opposite of JAX, which is “tensor” oriented. For example, given a 2-D mesh [“dp”, “tp”], a tensor with the DTensor placement [Replicate, Shard(0)] (or {“dp”: Replicate, “tp”: Shard(0)} with named device mesh axes) would correspond to a JAX placement of P(“tp”, None). The reason for this is that DTensor supports a Partial placement, which indicates that an axis on the device mesh has a pending reduction. Partial shows up ubiquitously from matrix multiplies, and it isn’t associated with any particular tensor axis, making it more convenient to represent in a device-mesh oriented formulation. The tradeoff is that device-mesh oriented placements don’t natively support specifying sharding order: e.g., suppose I want to shard a 1-D tensor on tp and then dp; in JAX I’d represent this as P((“tp”, “dp”),), but this order cannot be disambiguated from [Shard(0), Shard(0)], and in fact DTensor always forces left-to-right sharding. There is currently a proposal to extend our sharding specification to support ordering and bring us to parity with JAX expressiveness, but it is not yet implemented. (A short DTensor sketch follows this list.)
- Autograd. DTensor is directly differentiable; we run autograd on programs that have DTensors (as opposed to desugaring a DTensor program into one with regular Tensors and differentiating that). This allows the sharding strategy of a primal and its corresponding tangent to diverge. This is at parity with JAX.
- Python subclass of Tensor. Unlike JAX, DTensor is a separate subclass from Tensor. However, Tensor and DTensor interoperate fine; a Tensor can simply be thought of as a DTensor that is replicated on all dimensions. DTensor is implemented in Python, which makes it easy to modify and debug but imposes quite a bit of overhead (for example, FSDP2 does not directly accumulate gradients into DTensor, because with thousands of parameters, performing detach and add operations on DTensor is a bottleneck). Still, despite this overhead, DTensor was designed for good eager performance, and it extensively caches the results of sharding propagation so that on the fast path it only needs to look up which redistribute it should perform and then directly dispatch to the local eager operation. However, this caching strategy means that overhead can be quite high for workloads with dynamic shapes, as the cache requires exact matches of all input shapes.
- Compilation. DTensor is compilable by torch.compile, and doing so will desugar it into its underlying collectives and eliminate any eager mode DTensor overhead (even if you do not perform any other optimizations). However, DTensor with dynamic shapes in compile is not well supported; see http://github.com/pytorch/pytorch/issues/159635 (we don’t think this is currently on the critical path for any critical use cases, so a relatively junior engineer has been chipping away at it).
- Greedy propagation. Because DTensor must work in eager mode, it only implements greedy shard propagation, where for every eager operation we greedily pick whatever output shard minimizes the collective costs of an operation. It is work in progress to support backward propagation of sharding with the assistance of a compiler-like framework.
- Operator coverage. DTensor requires a sharding propagation rule for each operator. If a sharding propagation rule is not implemented, DTensor will fail rather than trigger an inefficient allgather to run the operator under replication. We don’t currently have full coverage of all operators, but the important operators for transformer models like llama3 are all covered (sharding rules are defined here). You can write custom shardings for user defined operators.
- Jagged sharding. We do not support a “jagged sharding” concept which would be necessary for expert parallelism with imbalanced routing. However, we believe that our existing sharding rules could largely be reused to support such an idea. As dynamism would only be exposed in the local tensor for the jagged shard, jagged shards don’t suffer from the dynamic shapes problems mentioned in the compilation section.
- Ecosystem. We are committed to DTensor as the standard representation for sharded tensors, and DTensor is integrated with checkpointing, FSDP2, SimpleFSDP, AutoParallel, and torchtitan, among others.
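As a concrete (if toy) illustration of the placement story above, here is a minimal DTensor sketch. It assumes a recent PyTorch where torch.distributed.tensor is the public module path, and it is meant to be launched with torchrun on 4 ranks; the mesh shape and tensor sizes are arbitrary:

# torchrun --nproc-per-node 4 dtensor_demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

dist.init_process_group("gloo")  # use "nccl" on GPU nodes
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))  # 2-D mesh: dp x tp

torch.manual_seed(0)          # the global tensor should be identical on every rank
full = torch.randn(8, 16)     # global (logical) shape

# Device-mesh-oriented placement: replicate over dp, shard dim 0 over tp
# (roughly P("tp", None) in JAX terms, as described above).
dt = distribute_tensor(full, mesh, placements=[Replicate(), Shard(0)])

print(dt.shape)             # global shape: torch.Size([8, 16])
print(dt.to_local().shape)  # local shard:  torch.Size([4, 16])

# Ordinary ops work; sharding propagation greedily picks output placements,
# inserting redistributes (collectives) where needed.
out = dt @ dt.transpose(0, 1)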
Functional collectives. If you don’t like DTensor, we also support “functional collectives”, which are non-mutating versions of collective operations that can be used to manually implement SPMD operations in a compiler-friendly way without needing DTensor. (In fact, if you use traditional collective APIs and compile them, we will silently translate them into functional collectives for compiler passes.) When compiled, functional collectives don’t necessarily force allocation of the output buffer as they can be re-inplaced. Importantly, functional collectives currently do NOT support autograd; see https://discuss.pytorch.org/t/supporting-autograd-for-collectives/219430
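A minimal sketch of what that looks like follows. Note that torch.distributed._functional_collectives is a private module whose exact path and signatures may shift between releases, so treat this as illustrative rather than authoritative:

# torchrun --nproc-per-node 2 funcol_demo.py
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol  # private API; may change

dist.init_process_group("gloo")

@torch.compile()
def step(x):
    y = x * 2
    # Non-mutating all_reduce: returns a new tensor rather than writing in place,
    # which is what makes it friendly to compiler passes (and re-inplacing).
    return funcol.all_reduce(y, "sum", group=dist.group.WORLD)

out = step(torch.ones(4))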
Graph capture. There are two particularly popular graph capture mechanisms which people have used to perform distributed optimizations separately from model code. All graph capture mechanisms produce FX graphs, a simple Python basic-block IR with no control flow that is entirely unopinionated about what operator set can occur in the graph.
- Symbolic_trace. This was the original graph capture mechanism and is quite popular, despite its limitations. It is implemented entirely with Python operator overloading and will give you exactly whatever operations are overloadable in the graph. We consider this largely a legacy pipeline as you are unable to trace code involving conditionals on shapes and you end up with a graph that has no useful metadata about the shapes/dtypes of intermediate values. For example, PiPPY, a legacy stack for performing pipeline parallelism, was built on top of symbolic_trace graph capture.
- make_fx/torch.export. This graph capture mechanism works by actually sending (fake) tensors through your program and recording ATen operators. There are a number of different variants: e.g., whether it is a Python tracing approach à la JAX jit, or whether it uses sophisticated bytecode analysis à la Dynamo; similarly, there are various levels of IR you can extract (pre-dispatch, post-dispatch; also, operators can be decomposed or kept as single units). Our compiler parallelism efforts are built on top of this capture mechanism, but there is nothing stopping you per se from writing your own graph pass on top of this IR. In practice, this can be difficult without PyTorch expertise, because (1) integrating a traced graph into PyTorch’s autograd system so it can interoperate with other code is quite complicated to do in full generality, and (2) the exact operator sets you get at various phases of compilation are undocumented and in practice very tied to the Inductor lowering stack, and it is poorly documented how to prevent operators from getting decomposed before your pass gets to them.
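For a sense of what this capture looks like in practice, here is a minimal make_fx sketch; the traced function is invented for illustration:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.relu(x).sum()

# Send fake tensors through f and record the ATen ops that get called.
gm = make_fx(f, tracing_mode="fake")(torch.randn(8, 8))
gm.graph.print_tabular()  # FX graph of ATen ops, with shape/dtype metadata on each node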
Not an SPMD compiler by default. torch.compile does not assume the program being compiled is SPMD by default, which means it will not do things like drop unused collectives (you can change this behavior with a config flag). Additionally, the default mode of use for torch.compile is to compile in parallel on all nodes, which means care has to be taken to ensure that every instance of the compiler compiles identically (only one rank recompiling, or compilers making different decisions, can lead to a NCCL timeout). We ultimately think that we should compile a program once and send it to all nodes, but as this is not currently implemented, the general approach people have taken to solve this problem is to either (1) eliminate all sources of divergent behavior between ranks, e.g., don’t allow the compiler to look at the actual size of dynamic inputs when making compiler decisions, or (2) introduce extra collectives into the compiler to communicate decisions that must be made consistently across all ranks.
Our vision for the future of advanced parallelism, spearheaded by the in-progress SimpleFSDP and AutoParallel, is that users should write single-node programs that express mathematically what they want to do. These are then transformed into efficient distributed programs in two steps: (1) first, collectives are inserted into the graph in a naive way (i.e., simply to express what the sharding of all intermediates should be), and (2) the collectives are optimized to handle scheduling concerns such as pre-fetching and bucketing. AutoParallel sets a GSPMD-style goal of automatically determining a good enough sharding for a program (it should be able to rediscover data parallel, tensor parallel, even expert parallel!), while SimpleFSDP sets a smaller goal of just inserting collectives in the pattern that FSDP would mandate, and then writing FSDP-specific optimization passes to recover FSDP2’s performance. It is very common to write domain-specific optimizations; for example, async tensor parallelism is implemented as a pass that detects TP patterns and rewrites them into async TP operations. Unlike JAX, which started with a very generic solver and has needed to add more manual escape hatches over time, PyTorch started with writing all of the distributed patterns exactly by hand, and we are only recently adding more automatic mechanisms as an alternative to doing everything by hand.
State of optimization
torch.compile performs many optimizations, but here are some particularly important ones to know about:
- Inductor. Inductor is our backend for torch.compile that generates Triton kernels for PyTorch programs. It has very good coverage of PyTorch’s operator set and can fuse pointwise and reduction operations, including in the patterns that typically occur in backwards. It is also able to fuse pointwise operations into matmuls and autotune different matmul backends (including cuBLAS, CUTLASS and Triton) to select the best one for any given size. When people talk about torch.compile speeding up their programs, they are conventionally talking about Inductor; however, you don’t have to use torch.compile with Inductor; for example, you could run with AOTAutograd only and skip Inductor compilation (see the first sketch after this list).
- CUDA graphs. Inductor builds in support for CUDA graphing models. Compared to applying CUDA graphs manually, we can give better soundness guarantees (e.g., around forgetting to copy in all input buffers, or CPU compute inside the CUDA graph region). torch.compile CUDA graphs is typically used with Inductor, but we also offer an eager-only cudagraphs integration (which is less well exercised).
- Automatic activation checkpointing. With torch.compile, we can globally optimize the memory-compute tradeoff, much better than the activation checkpointing APIs that eager PyTorch supports (which require the user to manually specify what they want checkpointed or not). However, some folks have reported that it can be quite miserable tuning the hyperparameter for AC; we have also found bugs in it.
- FP8 optimizations. One big success story for traditional compilation was adding support for a custom FP8 flavor: with torch.compile, no manual kernels had to be written for the variant. This has since been upstreamed to torchao.
- Flex attention. Flex attention usage continues to grow, with 632 downstream repo users in OSS (vs. 125 in Jan ’25). It has been used to enable chunked attention, document masking and context parallelism in llama family models. It is a really good research tool, although sometimes people complain about slight numerical differences. (A flex attention sketch follows this list.)
- Helion. Helion is an actively developed project, aiming to go beta in October of this year, which offers a higher-level interface for programming Triton kernels that looks just like writing PyTorch eager code. It relies heavily on autotuning to explore the space of possible structural choices for a kernel and find the best one. It is not production ready, but it is worth knowing that it is coming soon.
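To make the Inductor and CUDA graphs points above concrete, here is a small sketch of the common torch.compile configurations; the model is a placeholder:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# Default: Dynamo capture + AOTAutograd + Inductor codegen.
compiled = torch.compile(model)

# Inductor with CUDA graphs layered on top ("reduce-overhead" mode).
cudagraphed = torch.compile(model, mode="reduce-overhead")

# Capture through AOTAutograd but skip Inductor codegen entirely; handy for
# ablating the graph capture frontend from the backend when chasing numerics.
aot_only = torch.compile(model, backend="aot_eager")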
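And here is a rough flex attention sketch with a causal mask_mod; the shapes are arbitrary and a CUDA device is assumed:

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # Mask mod: keep only positions where the query may attend to the key.
    return q_idx >= kv_idx

B, H, S, D = 2, 4, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))

# The block mask records which (query, key) tiles are needed at all.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)

# flex_attention is typically wrapped in torch.compile to get a fused kernel.
attn = torch.compile(flex_attention)
out = attn(q, k, v, block_mask=block_mask)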
State of compile time
torch.compile is a just-in-time compiler and as such, in its default configuration, compilation will occur on your GPU cluster (preventing you from using the GPUs to do other useful work!). In general, most pathological compile times arise from repeated recompilation (often due to dynamic shapes, but sometimes not). In Transformer models, compile time can also be improved by only compiling the Transformer block, which then only needs to be compiled once, instead of being compiled N times, once for each Transformer block in the model.
We don’t think caching is an ideal long-term solution for large scale training runs, and we have been working on precompile to close the gap here. Precompile simply means making compilation an ahead-of-time process which produces a binary that you can directly run from your training script to get the compiled model. The compilation products are built on top of our ABI-stable interface (developed for AOTInductor), which allows the same binaries to target multiple PyTorch versions, even though PyTorch the library does not offer ABI compatibility from version to version.
How do I get started?
The most typical pattern we see for people who want to make use of torch.compile for large-scale training is to fork torchtitan and use this codebase as the basis for your training stack. torchtitan showcases PyTorch-native functionality, including torch.compile; in effect, it shows you how to use features in PyTorch together in a way that lets you do large-scale training. From there, swap out the components you are opinionated about and keep the things you don’t care about.