Famously, PyTorch and JAX don’t agree on how shardings should be represented: PyTorch takes a mesh-dim oriented view, where for each dimension in your device mesh, you specify what sharding should be applied; JAX takes a tensor-dim oriented view, where for each dimension on your tensor, you say which mesh dimensions (potentially multiple!) shard it. Among my Twitter followers, it is generally agreed that the JAX formulation is more intuitive from a user perspective. OK, fine; if you prefer one representation over another, it’s easy enough to translate between the two representations (in easy situations, at least!) In this post, I want to talk more about the framework implementation side: what is the better internal representation of sharding? I don’t claim to have all the answers, but my motivation for writing this post is to help explain where I currently stand and how I evaluate proposals for evolving DTensor and sharding in PyTorch.
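To make the two views concrete, here is a minimal sketch of translating between them in the easy case. It uses hypothetical toy encodings rather than the real DTensor placement classes or `jax.sharding.PartitionSpec`. A mesh-dim view is a list with one `("shard", tensor_dim)` or `("replicate",)` entry per mesh dimension. A tensor-dim view is a tuple with one tuple of mesh axis names per tensor dimension. The sketch ignores PyTorch's `Partial` placement. It also assumes that when several mesh dims shard the same tensor dim, they do so in mesh order.

```python
from typing import List, Sequence, Tuple

# Hypothetical toy encodings, not the real DTensor or jax.sharding classes:
#   mesh-dim view:   one entry per mesh dim, ("shard", tensor_dim) or ("replicate",)
#   tensor-dim view: one tuple of mesh axis names per tensor dim; () means replicated


def to_tensor_view(placements: Sequence[tuple], mesh_axes: Sequence[str],
                   ndim: int) -> Tuple[Tuple[str, ...], ...]:
    """Mesh-dim view (PyTorch-style) -> tensor-dim view (JAX-style)."""
    spec: List[List[str]] = [[] for _ in range(ndim)]
    # Visit mesh dims in order, so multiple mesh dims sharding the same
    # tensor dim are recorded in mesh order (the easy case assumed here).
    for axis, placement in zip(mesh_axes, placements):
        if placement[0] == "shard":
            spec[placement[1]].append(axis)
    return tuple(tuple(axes) for axes in spec)


def to_mesh_view(spec: Sequence[Sequence[str]],
                 mesh_axes: Sequence[str]) -> List[tuple]:
    """Tensor-dim view (JAX-style) -> mesh-dim view (PyTorch-style).

    Note: this drops the order of axes within one entry of `spec`.
    """
    placements: List[tuple] = [("replicate",)] * len(mesh_axes)
    for tensor_dim, axes in enumerate(spec):
        for axis in axes:
            placements[list(mesh_axes).index(axis)] = ("shard", tensor_dim)
    return placements


# 2D mesh ("dp", "tp"): shard tensor dim 0 over dp and tensor dim 1 over tp.
print(to_tensor_view([("shard", 0), ("shard", 1)], ["dp", "tp"], ndim=2))
# (('dp',), ('tp',))
```

The mesh-order assumption does real work. In JAX, the order of mesh axes within one entry is significant, but `to_mesh_view` throws that order away. This is one reason the translation is only easy in easy situations.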
I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors where it is important to make it clear how to identify each of the dimensions in the representation. Common strategies I’ve seen people use in this situation include printing a giant list of 2D slices (what the default PyTorch printer will do) or flattening the Tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy I like that makes it easy to identify all the axes of the higher-dimensional tensor: draw it as a matrix of matrices.
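Here is a minimal pure-Python sketch of the matrix-of-matrices idea for a 4D tensor stored as nested lists. The layout convention is an assumption of the sketch. Dims 0 and 1 index the outer matrix (rows and columns), and dims 2 and 3 index each inner matrix. Outer columns are separated by ` | `, and outer rows by a blank line.

```python
from typing import List

Tensor4D = List[List[List[List[int]]]]


def arange_4d(s0: int, s1: int, s2: int, s3: int) -> Tensor4D:
    """Nested-list 4D tensor whose entries are their own row-major flat index."""
    return [[[[((i0 * s1 + i1) * s2 + i2) * s3 + i3 for i3 in range(s3)]
              for i2 in range(s2)]
             for i1 in range(s1)]
            for i0 in range(s0)]


def draw_matrix_of_matrices(t: Tensor4D) -> str:
    """Render t[i0][i1][i2][i3]: (i0, i1) picks an inner matrix in the outer grid,
    (i2, i3) are the row and column within that inner matrix."""
    width = max(len(str(x)) for plane in t for mat in plane for row in mat for x in row)
    lines = []
    for i0, plane in enumerate(t):           # outer rows: dim 0
        for i2 in range(len(plane[0])):      # inner rows: dim 2
            cells = []
            for mat in plane:                # outer columns: dim 1
                # inner columns: dim 3, right-aligned to a common width
                cells.append(" ".join(str(x).rjust(width) for x in mat[i2]))
            lines.append(" | ".join(cells))
        if i0 != len(t) - 1:
            lines.append("")                 # blank line between outer rows
    return "\n".join(lines)


print(draw_matrix_of_matrices(arange_4d(2, 2, 2, 2)))
```

For a 2x2x2x2 tensor this prints the following:

```
 0  1 |  4  5
 2  3 |  6  7

 8  9 | 12 13
10 11 | 14 15
```

Each axis can be read off by direction. Stepping right within an inner matrix adds 1 (dim 3), and stepping down within it adds 2 (dim 2). Stepping across outer columns adds 4 (dim 1), and stepping across outer rows adds 8 (dim 0).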
With contributions from Richard Zou.
FX graphs, PT2’s dominant internal representation, do not directly support control flow (if statements, while loops): they only represent straight-line basic blocks. Most of our graph capture mechanisms are tracing based (fx.symbolic_trace, make_fx, Dynamo), which means that we expect to be able to linearize all conditionals we encounter into a straight-line program. Sometimes, you want to work with code that has control flow while working with the compiler stack. There is no silver bullet; instead, there are a lot of different options with different tradeoffs.
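To see why tracing and control flow don't mix, here is a toy tracer in pure Python. The `Proxy`, `trace` and `where` helpers are hypothetical, not the FX API. It behaves like `fx.symbolic_trace`, which errors out when a symbolic value is used in control flow: it refuses to turn a traced value into a Python `bool`. One way to get a straight-line program anyway is to rewrite the conditional as a select that evaluates both branches, in the style of `torch.where`.

```python
from typing import Callable, List


class Proxy:
    """Toy symbolic value: each operation appends a line to a shared graph."""

    def __init__(self, name: str, graph: List[str]) -> None:
        self.name = name
        self.graph = graph

    def _op(self, op: str, other: object) -> "Proxy":
        out = Proxy(f"v{len(self.graph)}", self.graph)
        arg = other.name if isinstance(other, Proxy) else repr(other)
        self.graph.append(f"{out.name} = {op}({self.name}, {arg})")
        return out

    def __add__(self, other: object) -> "Proxy":
        return self._op("add", other)

    def __mul__(self, other: object) -> "Proxy":
        return self._op("mul", other)

    def __gt__(self, other: object) -> "Proxy":
        return self._op("gt", other)

    def __bool__(self) -> bool:
        # A Python `if` needs a concrete bool, but a traced value has none.
        raise RuntimeError("cannot linearize control flow that depends on a traced value")


def where(cond: Proxy, a: Proxy, b: Proxy) -> Proxy:
    """Select between two already-traced branches as one straight-line op."""
    out = Proxy(f"v{len(cond.graph)}", cond.graph)
    cond.graph.append(f"{out.name} = where({cond.name}, {a.name}, {b.name})")
    return out


def trace(fn: Callable[[Proxy], Proxy]) -> List[str]:
    """Run fn on a symbolic input and return the recorded straight-line graph."""
    graph: List[str] = []
    fn(Proxy("x", graph))
    return graph


def f_branchy(x):
    if x > 0:          # calls Proxy.__bool__, which raises RuntimeError
        return x * 2
    return x + 1


def f_linear(x):
    return where(x > 0, x * 2, x + 1)   # both branches get traced


print("\n".join(trace(f_linear)))
```

The tradeoff of the select rewrite is that both branches always execute. That is fine for cheap elementwise branches, but wasteful (or wrong, if a branch has side effects) otherwise.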
When training large scale LLMs, there is a large assortment of parallelization strategies which you can employ to scale your training runs to work on more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you’ve got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check for seeing if you understand how the parallelization strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)
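Here is a minimal sketch of the bookkeeping behind "do collectives along that dimension". Given a mesh shape, it lists the groups of ranks that communicate along one mesh dimension. It assumes ranks 0..N-1 are laid out row-major over the mesh. My understanding is that this matches what `torch.distributed.device_mesh.init_device_mesh` does, but treat that as an assumption. `mesh_groups` is a hypothetical helper, not a PyTorch API.

```python
from itertools import product
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Strides for laying out ranks row-major (last mesh dim varies fastest)."""
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return strides[::-1]


def mesh_groups(shape: Sequence[int], dim: int) -> List[List[int]]:
    """Groups of ranks that differ only in their coordinate along mesh dim `dim`."""
    strides = row_major_strides(shape)
    # Enumerate every mesh coordinate with the `dim` entry pinned to 0...
    others = [range(1) if i == dim else range(size) for i, size in enumerate(shape)]
    groups = []
    for coord in product(*others):
        base = sum(c * s for c, s in zip(coord, strides))
        # ...then walk along `dim` to collect the ranks in that group.
        groups.append([base + k * strides[dim] for k in range(shape[dim])])
    return groups


# 8 GPUs as a (dp, tp) = (2, 4) mesh.
print(mesh_groups((2, 4), dim=1))  # tp groups: [[0, 1, 2, 3], [4, 5, 6, 7]]
print(mesh_groups((2, 4), dim=0))  # dp groups: [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With this layout, the groups along the innermost mesh dimension are runs of consecutive ranks. This is why a common choice is to put the most communication-hungry dimension last, where its ranks are likely to share a node.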
CuTe is a C++ library that aims to make dealing with complicated indexing easier. A key part of how it does this is by defining a Layout type, which specifies how to map from logical coordinates to physical locations (CuTe likes to say layouts are “functions from integers to integers.”) In fact, CuTe layouts are a generalization of PyTorch strides, which say you always do this mapping by multiplying each coordinate with its respective stride and summing them together, e.g., i0 * s0 + i1 * s1 + .... Although NVIDIA’s docs don’t spell it out, CuTe’s generalization here is actually very natural, and in this blog post I’d like to explain how you could have invented it (on a good day).
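Here is a toy Python model of one ingredient of that generalization: hierarchical (nested) shapes and strides. The helper names are mine, not CuTe's API. The sketch assumes CuTe's colexicographic convention, in which an integer coordinate within a nested mode is split so that the leftmost sub-mode varies fastest. With flat integer modes it reduces exactly to the PyTorch stride formula.

```python
from typing import List, Sequence, Tuple, Union

IntTuple = Union[int, Tuple["IntTuple", ...]]


def flatten(t: IntTuple) -> List[int]:
    """Flatten a nested tuple of ints, left to right."""
    if isinstance(t, int):
        return [t]
    return [x for e in t for x in flatten(e)]


def mode_offset(coord: int, shape: IntTuple, stride: IntTuple) -> int:
    """Offset for an integer coordinate within one (possibly nested) mode.
    The coordinate is split colexicographically: leftmost sub-mode fastest."""
    offset = 0
    for size, step in zip(flatten(shape), flatten(stride)):
        offset += (coord % size) * step
        coord //= size
    return offset


def layout_offset(coords: Sequence[int], shape: Tuple[IntTuple, ...],
                  stride: Tuple[IntTuple, ...]) -> int:
    """Sum the per-mode offsets, one mode per logical dimension."""
    return sum(mode_offset(c, s, d) for c, s, d in zip(coords, shape, stride))


def pytorch_offset(coords: Sequence[int], strides: Sequence[int]) -> int:
    """PyTorch-style mapping: i0 * s0 + i1 * s1 + ..."""
    return sum(i * s for i, s in zip(coords, strides))


# Flat modes reduce to ordinary strides: a row-major 4x4 matrix.
assert layout_offset((2, 3), (4, 4), (4, 1)) == pytorch_offset((2, 3), (4, 1)) == 11

# A 4x4 matrix stored as a 2x2 grid of contiguous, row-major 2x2 tiles.
# Row mode: in-tile row stride 2, tile-row stride 8.
# Column mode: in-tile column stride 1, tile-column stride 4.
TILED_SHAPE = ((2, 2), (2, 2))
TILED_STRIDE = ((2, 8), (1, 4))
print([layout_offset((r, 0), TILED_SHAPE, TILED_STRIDE) for r in range(4)])  # [0, 2, 8, 10]
```

Walking down column 0 of the tiled matrix gives offsets 0, 2, 8, 10. No single PyTorch stride can produce that sequence, but one nested row mode with shape `(2, 2)` and stride `(2, 8)` can.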
The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. There is nothing in here that you couldn’t already learn about from elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large scale training runs.
First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs, for both inference and training workloads. Speedups of 1.5-2x over eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).
In my previous two posts “`Ways to use torch.compile <http://blog.ezyang.com/2024/11/ways-to-use-torch-compile/>`_” and “`Ways to use torch.export <http://blog.ezyang.com/2024/12/ways-to-use-torch-export/>`_”, I often said that PyTorch would be good for a use case, but there might be some downsides. Some of the downsides are foundational and difficult to remove. But some… just seem like a little something is missing from PyTorch. In this post, I want to share some things I hope we will end up shipping in 2025!
Previously, I discussed the value proposition of torch.compile. While doing so, I observed a number of downsides (long compile time, complicated operational model, lack of packaging) that were intrinsic to torch.compile’s API contract, which emphasized being able to work on Python code as is, with minimal intervention from users. torch.export occupies a different spot in the tradeoff space: in exchange for more upfront work making a model exportable, it allows for use of PyTorch models in environments where using torch.compile as is would be impossible.
On the surface, the value proposition of torch.compile is simple: compile your PyTorch model and it runs X% faster. But after having spent a lot of time helping users from all walks of life use torch.compile, I have found that actually understanding how this value proposition applies to your situation can be quite subtle! In this post, I want to walk through the ways to use torch.compile, and within these use cases, what works and what doesn’t. By the way, some of these gaps are addressed either by export or by missing features we are actively working on; those will be the subject of other posts!
Tensor libraries like PyTorch and JAX have developed compact and accelerated APIs for manipulating n-dimensional arrays. N-dimensional arrays are kind of similar to tables in a database, which raises a logical question: could you set up a Tensor-like API for the database queries that would normally be written in SQL? We have two challenges:
- Tensor computation is typically uniform and data-independent. But SQL relational queries are almost entirely about filtering and joining data in a data-dependent way.
- JOINs in SQL can be thought of as performing outer products, which are not a very common operation in tensor computation.
However, we have a secret weapon: first class dimensions were primarily designed as a new frontend syntax that made it easy to express einsum, batching and tensor indexing expressions. They might be good for SQL too.
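Here is a minimal pure-Python sketch that shows both challenges in one place. It uses plain lists, not first class dimensions, and the tables and the `inner_join` helper are made up for illustration. The equality test over every pair of rows is the uniform, outer-product-shaped part of a join. Turning the resulting mask into output rows is the data-dependent part, since the number of rows that come out depends on the values.

```python
from typing import Dict, List

Table = Dict[str, list]  # column name -> column values, stored column-wise like tensors

# Made-up example tables.
users: Table = {"user_id": [1, 2, 3], "name": ["ann", "bob", "cat"]}
orders: Table = {"order_id": [10, 11, 12, 13], "user_id": [2, 1, 2, 4]}


def outer_eq(a: list, b: list) -> List[List[bool]]:
    """mask[i][j] = a[i] == b[j]; its shape depends only on the input sizes."""
    return [[x == y for y in b] for x in a]


def inner_join(left: Table, right: Table, key: str) -> Table:
    """Inner join on `key` via an outer equality mask followed by a filter."""
    mask = outer_eq(left[key], right[key])
    # Data-dependent step: how many rows come out depends on the values.
    pairs = [(i, j) for i, row in enumerate(mask) for j, hit in enumerate(row) if hit]
    out: Table = {col: [vals[i] for i, _ in pairs] for col, vals in left.items()}
    for col, vals in right.items():
        if col != key:
            out[col] = [vals[j] for _, j in pairs]
    return out


print(inner_join(users, orders, "user_id"))
# {'user_id': [1, 2, 2], 'name': ['ann', 'bob', 'bob'], 'order_id': [11, 10, 12]}
```

Materializing the full mask costs memory proportional to the product of the table sizes. This is why databases typically use hash or sort-merge joins instead.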