When training large-scale LLMs, there is a large assortment of parallelization strategies you can employ to scale your training runs across more GPUs. There are already a number of good resources for understanding how to parallelize your models: I particularly recommend How To Scale Your Model and The Ultra-Scale Playbook. The purpose of this blog post is to discuss parallelization strategies in a more schematic way by focusing only on how they affect your device mesh. The device mesh is an abstraction used by both PyTorch and JAX that takes your GPUs (however many of them you’ve got in your cluster!) and organizes them into an N-D tensor that expresses how the devices communicate with each other. When we parallelize computation, we shard a tensor along one dimension of the mesh, and then do collectives along that dimension when there are nontrivial dependencies between shards. Being able to explain why a device mesh is set up the way it is for a collection of parallelization strategies is a good check of whether you understand how those strategies work in the first place! (Credit: This post was influenced by Visualizing 6D Mesh Parallelism.)
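To make the device mesh idea concrete, here is a minimal sketch in plain Python (no PyTorch or JAX required). It assumes a hypothetical cluster of 8 GPUs laid out in row-major order as a 2 x 4 mesh, with the dimension names "dp" and "tp" chosen purely for illustration. The helpers `build_mesh` and `groups_along` are made up for this example and are not part of any library; they only show which ranks would end up talking to each other for collectives along a given mesh dimension.

```python
from itertools import product


def build_mesh(shape):
    """Assign ranks 0..N-1 to mesh coordinates in row-major order."""
    coords = product(*(range(n) for n in shape))
    return {coord: rank for rank, coord in enumerate(coords)}


def groups_along(shape, dim):
    """Return the groups of ranks that differ only in coordinate `dim`.

    Each group is the set of devices that would take part in one
    collective along that mesh dimension.
    """
    mesh = build_mesh(shape)
    other_dims = [range(n) for i, n in enumerate(shape) if i != dim]
    groups = []
    for rest in product(*other_dims):
        # Re-insert the varying coordinate at position `dim`.
        group = [mesh[rest[:dim] + (k,) + rest[dim:]] for k in range(shape[dim])]
        groups.append(group)
    return groups


# Hypothetical layout: 8 GPUs as a 2 x 4 mesh, dim 0 = "dp", dim 1 = "tp".
shape = (2, 4)
dp_groups = groups_along(shape, 0)
tp_groups = groups_along(shape, 1)

print("dp groups:", dp_groups)
print("tp groups:", tp_groups)
```

With this layout, ranks in the same "tp" group are adjacent (which is often where you would want the fastest interconnect), while each "dp" group strides across the rows of the mesh.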
CuTe is a C++ library that aims to make dealing with complicated indexing easier. A key part of how it does this is by defining a Layout type, which specifies how to map from logical coordinates to physical locations (CuTe likes to say layouts are “functions from integers to integers”). In fact, CuTe layouts are a generalization of PyTorch strides, where you always do this mapping by multiplying each coordinate by its respective stride and summing the results, e.g., i0 * s0 + i1 * s1 + .... Although NVIDIA’s docs don’t spell it out, CuTe’s generalization here is actually very natural, and in this blog post I’d like to explain how you could have invented it (on a good day).
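As a small illustration of the strided mapping i0 * s0 + i1 * s1 + ..., here is a sketch in plain Python. It covers only the ordinary PyTorch-style strided case described above, not CuTe's generalization, and the `offset` helper is a hypothetical name chosen for this example rather than an API from either library.

```python
def offset(coord, stride):
    """Map a logical coordinate to a physical offset: sum of i_k * s_k."""
    if len(coord) != len(stride):
        raise ValueError("coord and stride must have the same rank")
    return sum(i * s for i, s in zip(coord, stride))


# A 2 x 3 logical shape under two different stride choices.
rows, cols = 2, 3
row_major = (3, 1)  # consecutive elements of a row are adjacent in memory
col_major = (1, 2)  # consecutive elements of a column are adjacent in memory

row_offsets = [[offset((i, j), row_major) for j in range(cols)] for i in range(rows)]
col_offsets = [[offset((i, j), col_major) for j in range(cols)] for i in range(rows)]

print(row_offsets)
print(col_offsets)
```

The same logical coordinates land at different physical locations depending only on the strides, which is the sense in which a layout is a function from integers to integers.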
The purpose of this post is to sum up, in one place, the state of torch.compile for training as of August 2025. Nothing in here is something you couldn’t already learn from elsewhere on the Internet, but we rarely put everything together in one place. The target audience for this document is teams who are evaluating the use of torch.compile for large-scale training runs.
First, the basics. torch.compile (also known as PT2) is a compiler for PyTorch eager programs, covering both inference and training workloads. Speedups of 1.5-2x compared to eager code are typical, and torch.compile also makes it possible to do global optimizations for memory (e.g., automatic activation checkpointing) and distributed communications (e.g., async tensor parallelism).