ezyang’s blog

the arc of software bends towards understanding

New Years resolutions for PyTorch in 2025

In my previous two posts "Ways to use torch.compile" and "Ways to use torch.export", I often said that PyTorch would be good for a use case, but there might be some downsides. Some of the downsides are foundational and difficult to remove. But some... just seem like a little something is missing from PyTorch. In this post, here are some things I hope we will end up shipping in 2025!

Improving torch.compile

A programming model for PT2. A programming model is an abstract description of the system that is both simple (so anyone can understand it and keep it in their head all at once) and can be used to predict the system's behavior. The torch.export programming model is an example of such a description. Beyond export, we would like to help users understand why all aspects of PT2 behave the way they do (e.g., via improved error messages), and give simple, predictable tools for working around problems when they arise. The programming model helps us clearly define the intrinsic complexity of our compiler, which we must educate users about. This is a big effort involving many folks on the PyTorch team and I hope we can share more about this effort soon.

Pre-compilation: beyond single graph export. Whenever someone realizes that torch.compile compilation is taking a substantial amount of time on expensive cluster machines, the first thing they ask is, "Why don't we just compile it in advance?" Supporting precompilation with the torch.compile API exactly as is turns out to be not so easy; unlike a traditional compiler, which gets the source program directly as input, users of torch.compile must actually run their Python program to hit the regions of code that are intended to be compiled. Nor can these regions be trivially enumerated and then compiled: not only must you know all the metadata of the input tensors flowing into a region, a user might not even know what the compiled graphs are if a model has graph breaks.

OK, but why not just run the model, dump all the compiled products, and then reuse them later? This works! Here is a POC from Nikita Shulga where a special decorator aot_compile_sticky_cache swaps between exporting a graph and running the exported product. Zhengxu Chen used a similar idea to export Whisper as a few distinct graphs, which he then manually stitched together in C++ to get a Python-free version of Whisper. If you want training to work, you can more directly integrate AOTInductor as an Inductor backend, e.g., as seen in this POC. We are a stone's throw away from working precompilation, which can guarantee no compilation at runtime; we just need to put the pieces together!
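
To make the idea concrete, here is a simplified, hypothetical sketch of the "export once, reuse later" pattern (not the actual aot_compile_sticky_cache POC); it assumes the torch._export.aot_compile and torch._export.aot_load entry points, whose exact names have shifted across releases:

import os
import shutil
import torch

def load_or_compile(model: torch.nn.Module, example_inputs, path: str):
    # pay for compilation once, on whatever machine you like
    if not os.path.exists(path):
        so_path = torch._export.aot_compile(model, example_inputs)
        shutil.copy(so_path, path)
    # subsequent runs never compile: they just load the shared library
    return torch._export.aot_load(path, device="cuda")

# usage (hypothetical): compiled = load_or_compile(model, (x,), "/tmp/model.so"); compiled(x)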

Improving caching further. There are some gaps with caching which we hope to address in the near future: (1) loading Triton cache artifacts takes a long time because we still re-parse the Triton code before doing a cache lookup (James Wu is on this), (2) if you have a lot of small graphs, the remote cache ends up having to do lots of small network requests, instead of one batched network request at the beginning (Oguz Ulgen recently landed this), (3) AOTAutograd cache is not fully rolled out yet (James Wu again). These collectively should be worth a 2x speedup or even more on warm cache time.

Fix multithreading. We should just make sure multithreading works, doing the testing and fiddly thread safety auditing needed to make it work. Here's a list of multithreading related issues.

Improving torch.export

Draft mode export. Export requires a lot of upfront work to even get an exported artifact in the first place. Draft mode export capitalizes on the idea that it's OK to generate an unsound "draft" graph early in the export process, because even an incorrect graph is useful for kicking the tires on the downstream processing that happens after export. A draft export gives you a graph, and it also gives you a report describing what potential problems need to be fixed to get some guarantees about the correctness of the export. You can then chip away at the problems in the report until everything is green. One of the biggest innovations of draft-mode export is pervasive use of real tensor propagation when doing export: you run the export with actual tensors, so you can always trace through code, even if it is doing spicy things like data-dependent control flow.

Libtorch-free AOTInductor. AOTInductor generated binaries have a relatively small ABI surface that needs to be implemented. This hack from the most recent CUDA Mode meetup shows that you can just create an alternate implementation of the ABI that has no dependence on libtorch. This makes your deployed binary size much smaller!

Support for bundling CUDA kernels into AOTInductor. AOTInductor already supports directly bundling Triton kernels into the generated binary, but traditional CUDA kernels cannot be bundled in this way. There's no reason this has to be the case though: all we're doing is bundling cubins in both cases. If we have the ability to bundle traditional CUDA kernels into AOTInductor, this means you could potentially directly embed custom operators into AOTInductor binaries, which is nice because then those operators no longer have to be offered by the runtime (especially if you're commonly iterating on these kernels!)

Export multigraphs. Export's standard model is to give you a single graph that you call unconditionally. But it's easy to imagine a level of indirection on top of these graphs, where we can dispatch between multiple graphs depending on some arguments to the model. For example, if you have a model that optionally takes an extra Tensor argument, you can simply have two graphs, one for when the Tensor is absent, and one for when it is present.
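
For instance, here is a hypothetical sketch of what rolling that indirection yourself could look like today (torch.export does not provide this dispatch for you):

import torch

class Model(torch.nn.Module):
    def forward(self, x, extra=None):
        out = x * 2
        return out if extra is None else out + extra

m = Model()
x = torch.randn(4)
ep_without = torch.export.export(m, (x,))               # traces the extra-is-None branch
ep_with = torch.export.export(m, (x, torch.randn(4)))   # traces the extra-provided branch

def dispatch(x, extra=None):
    # hand-rolled indirection: pick the graph based on whether the optional Tensor was passed
    if extra is None:
        return ep_without.module()(x)
    return ep_with.module()(x, extra)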

ABI stable PyTorch extensions. It's hard work being a third-party PyTorch extension with native code, because whenever there's a new release of Python or PyTorch you have to rebuild all of your wheels. If there was a limited ABI that you could build your extension against that didn't expose CPython and only relied on a small, stable ABI of PyTorch functions, your binary packaging situation would be much simpler! And if an extension relied on a small ABI, it could even be bundled with an AOTInductor binary, letting these export products be truly package agnostic (one of the lessons we learned with torch.package is that picking the split between "what is packaged" and "what is not" is very difficult, and people would much rather just have everything be packaged.) Jane Xu is investigating how to do this, and separately, Scott Wolchok has been refactoring headers in libtorch so that a small set of headers can be used independently of the rest of libtorch.

  • January 9, 2025

Ways to use torch.export

Previously, I discussed the value proposition of torch.compile. While doing so, I observed a number of downsides (long compile time, complicated operational model, lack of packaging) that were intrinsic to torch.compile's API contract, which emphasized being able to work on Python code as is, with minimal intervention from users. torch.export occupies a different spot in the tradeoff space: in exchange for more upfront work making a model exportable, it allows for use of PyTorch models in environments where using torch.compile as is would be impossible.

Enable end-to-end C++ CPU/GPU Inference

Scenario: Like before, suppose you want to deploy your model for inference. However, now you have more stringent runtime requirements: perhaps you need to do inference from a CPython-less environment (because your QPS requirements demand GIL-less multithreading; alternately, CPython execution overhead is unacceptable but you cannot use CUDA graphs, e.g., due to CPU inference or dynamic shapes requirements). Or perhaps your production environment requires hermetic deploy artifacts (for example, in a monorepo setup, where infrastructure code must be continually pushed but model code should be frozen). But like before, you would prefer not to have to rewrite your model; you would like the existing model to serve as the basis for your Python-less inference binary.

What to do: Use torch.export targeting AOTInductor. This will compile the model into a self-contained shared library which can then be directly invoked from a C++ runtime. This shared library contains all of the compiler generated Triton kernels as precompiled cubins and is guaranteed not to need any runtime compilation; furthermore, it relies only on a small runtime ABI (with no CPython dependency), so the binaries can be used across versions of libtorch. AOTInductor's multithreading capability and low runtime overhead also make it a good match for CPU inference!

You don't have to go straight to C++ CPU/GPU inference: you can start with using torch.compile on your code before investing in torch.export. There are four primary extra requirements export imposes: (1) your model must compile with fullgraph=True (though you can sometimes bypass missing Dynamo functionality by using non-strict export; sometimes, it is easier to do non-strict torch.export than it is to torch.compile!), (2) your model's inputs/outputs must only be in torch.export's supported set of argument types (think Tensors in pytrees), (3) your model must never recompile--specifically, you must specify what inputs have dynamic shapes, and (4) the top-level of your model must be an nn.Module (so that export can keep track of all of the parameters your model has).
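
Putting the requirements together, here is a minimal sketch of the flow under those constraints; the torch._inductor.aot_compile entry point is what recent releases expose, but treat the exact API as an assumption, since it has been evolving:

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) * 2

model = Model().cuda().eval()
example = (torch.randn(8, 16, device="cuda"),)

# declare which input dimensions are dynamic so the exported graph never recompiles
batch = torch.export.Dim("batch")
ep = torch.export.export(model, example, dynamic_shapes={"x": {0: batch}})

# compile the exported program into a self-contained shared library for a C++ runtime
so_path = torch._inductor.aot_compile(ep.module(), example)
print(so_path)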

Some tips:

  • Check out the torch.export programming model. The torch.export programming model is an upcoming doc which aims to help set expectations on what can and cannot be exported. It talks about things like "Tensors are the only inputs that can actually vary at runtime" and common mistakes such as module code which modifies NN modules (not supported!) or optional input types (you will end up with an export that takes in that input or not; there is no runtime optionality).
  • Budget time for getting a model to export. With torch.compile for Python inference, you could just slap it on your model and see what happens. For torch.export, you have to actually finish exporting your entire model before you can even consider running the rest of the pipeline. For some of the more complicated models we have exported, there were often dozens of issues that had to be worked around in one way or another. And that doesn't even account for all of the post-export work you have to do, like validating the numerics of the exported model.
  • Intermediate value debugging. AOTInductor has an option to add dumps of intermediate tensor values in the compiled C++ code. This is good for determining, e.g., the first time where a NaN shows up, in case you are suspecting a miscompilation.

Open source examples: Among other things, torchchat has an example end-to-end AOTInductor setup for server-side LLM inference, which you can view in run.cpp.

torch.export specific downsides:

  • No built-in support for guard-based dispatch (multiple compilations). Earlier, I mentioned that an exported model must not have any recompiles. This leads to some fairly common patterns of code not being directly supported by torch.export: you can't export a single model that takes an enum as input, or has an optional Tensor argument, or accepts two distinct tensor shapes that need to be compiled individually. Now, technically, we could support this: you could imagine a package that contains multiple exported artifacts and dispatches between them depending on some conditions (e.g., the value of the enum, whether or not the optional Tensor argument was provided, the shape of the input tensor). But you're on your own: torch.compile will do this for you, but torch.export will not.
  • No built-in support for models that are split into multiple graphs. Similarly, we've mentioned that an exported model must be a single graph. This is in contrast to torch.compile, which will happily insert graph breaks and compile distinct islands of code that can be glued together with Python eager code. Now, technically, you can do this with export too: you can carve out several distinct subnets of your model, export them individually, and then glue them together with some custom written code on the other end (in fact, Meta's internal recommendation systems do this), but there's no built-in support for this workflow.
  • The extra requirements often don't cover important components of real world models. I've mentioned this previously as the extra restrictions export places on you, but it's worth reiterating some of the consequences of this. Take an LLM inference application: obviously, there is a core model that takes in tokens and produces logit predictions--this part of the model is exportable. But there are also important other pieces such as the tokenizer and sampling strategy which are not exportable (tokenizer because it operates on strings, not tensors; sampling because it involves complicated control flow). Arguably, it would be much better if all of these things could be directly bundled with the model itself; in practice, end-to-end applications should just expect to directly implement these in native code (e.g., as is done in torchchat). Our experience with TorchScript taught us that we don't really want to be in the business of designing a general purpose programming language that is portable across all of export's targets; better to just bet that the tokenizer doesn't change that often and eat the cost of natively integrating it by hand.

AOTInductor specific downsides:

  • You still need libtorch to actually run the model. Although AOTInductor binaries bundle most of their compiled kernel implementation, they still require a minimal runtime that can offer basic necessities such as tensor allocation and access to custom operators. There is not yet an official offering of an alternative, lightweight implementation of the stable ABI that AOTInductor binaries depend on, so if you do want to deploy AOTInductor binaries you will typically have to also bring libtorch along. This is usually not a big deal server side, but it can be problematic if you want to do client side deployments!
  • No CUDA graphs support. This one is not such a big deal since you are much less likely to be CPU bound when the host side logic is all compiled C++, but there's no support for CUDA graphs in AOTInductor. (Funnily enough, this is also something you technically can orchestrate from outside of AOTInductor.)

Edge deployment

Scenario: You need to deploy your PyTorch model to edge devices (e.g., a mobile phone or a wearable device) where computational resources are limited. You have requirements that are a bit different from server side: you care a lot more about minimizing binary size and startup time. Traditional PyTorch deployment with full libtorch won't work. The device you're deploying to might also have some strange extra processors, like a DSP or NPU, that you want your model to target.

What to do: Use torch.export targeting Executorch. Among other things, Executorch offers a completely separate runtime for exported PyTorch programs (i.e., it has no dependency on libtorch, except perhaps there are a few headers which we share between the projects) which was specifically designed for edge deployment. (Historical note: we spent a long time trying to directly ship a stripped down version of libtorch to mobile devices, but it turns out it's really hard to write code that is portable on server and client, so it's better to only share when absolutely necessary.) Quantization is also a pretty important part of deployment to Edge, and Executorch incorporates this into the end-to-end workflow.

Open source examples: torchchat also has an Executorch integration letting you run an LLM on your Android phone.

Downsides. All of the export related downsides described previously apply here. But here's something to know specifically about Executorch:

  • The edge ecosystem is fragmented. At time of writing, there are seven distinct backends Executorch can target. This is not really Executorch's fault, it comes with the territory--but I want to call it out because it stands in stark contrast to NVIDIA's server-side hegemony. Yes, AMD GPUs are a thing, and various flavors of CPU are real, but it really is a lot easier to be focused on server side because NVIDIA GPUs come first.

Pre-compiled kernels for eager mode

Scenario: You need a new function or self-contained module with an efficient kernel implementation. However, you would prefer not to have to write the CUDA (or even Triton) by hand; the kernel is something that torch.compile can generate from a higher level PyTorch implementation. At the same time, however, you cannot tolerate just-in-time compilation at all (perhaps you are doing a massive training job, and any startup latency makes it more likely that one of your nodes will fail during startup and then you make no progress at all; or maybe you just find it annoying when PyTorch goes out to lunch when you cache miss).

What to do: Use torch.export targeting AOTInductor, and then load and run the AOTInductor generated binary from Python.
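
A minimal sketch of what that looks like, assuming the shared library was produced as in the earlier section and that your build exposes torch._export.aot_load (the Python loading entry point has moved around between releases):

import torch

# load the AOTInductor-generated shared library; no compilation happens at this point
compiled = torch._export.aot_load("/path/to/model.so", device="cuda")

x = torch.randn(8, 16, device="cuda")
with torch.no_grad():
    out = compiled(x)   # runs the precompiled kernels directly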

Downsides. So, we know this use case works, because we have internally used this to unblock people who wanted to use Triton kernels but could not tolerate Triton's just-in-time compilation. But there's not much affordance in our APIs for this use case; for example, guard-based dispatch is often quite useful for compiled functions, but you'll have to roll that by hand. More generally, when compiling a kernel, you have to make tradeoffs about how static versus dynamic the kernel should be (for example, will you force the inputs to be evenly divisible by eight? Or would you have a separate kernel for the divisible and not divisible cases?) Once again, you're on your own for making the call there.

An exchange format across systems

Scenario: In an ideal world, you would have a model, you could export it to an AOTInductor binary, and then be all done. In reality, maybe this export process needs to be a multi-stage process, where it has to be processed to some degree on one machine, and then finish processing on another machine. Or perhaps you need to shift the processing over time: you want to export a model to freeze it (so it is no longer tied to its original source code), and then repeatedly run the rest of the model processing pipeline on this exported program (e.g., because you are continuously updating its weights and then reprocessing the model). Maybe you want to export the model and then train it from Python later, committing to a distributed training strategy only when you know how many nodes you are running. The ability to hermetically package a model and then process it later is one of the big value propositions of TorchScript and torch.package.

What to do: Use torch.export by itself, potentially using pre-dispatch if you need to support training use-cases. torch.export produces an ExportedProgram which has a clean intermediate representation that you can do processing on, or just serialize and then do processing on later.
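
Here is a minimal sketch of that workflow using torch.export.save and torch.export.load, assuming a toy module:

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1

# export now, freezing the program away from its original source code
ep = torch.export.export(Model(), (torch.randn(8),))
torch.export.save(ep, "model.pt2")

# ...later, possibly on another machine: load, inspect or transform the IR, then run
ep2 = torch.export.load("model.pt2")
print(ep2.graph_module.code)      # clean IR you can write passes over
out = ep2.module()(torch.randn(8))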

Downsides:

  • Custom operators are not packaged. A custom operator typically refers to some native code which was linked with PyTorch proper. There's no way to extract out this kernel and embed it into the exported program so that there is no dependence; instead, you're expected to ensure the eventual runtime relinks with the same custom operator. Note that this problem doesn't apply to user defined Triton kernels, as export can simply compile it and package the binary directly into the exported product. (Technically, this applies to AOTInductor too, but this tends to be much more of a problem for use cases which are primarily about freezing rapidly evolving model code, as opposed to plain inference where you would simply just expect people to not be changing custom operators willy nilly.)
  • Choose your own decompositions. Export produces IR that only contains operators from a canonical operator set. However, the default choice is sometimes inappropriate for use cases (e.g., some users want aten.upsample_nearest2d.vec to be decomposed while others do not), so in practice for any given target you may have a bespoke operator set that is appropriate for that use case. Unfortunately, it can be fiddly getting your operator set quite right, and while we've talked about ideas like a "build your own operator set interactive tool" these have not been implemented yet.
  • Annoyingly large FC/BC surface. Something I really like about AOTInductor is that it has a very small FC/BC surface: I only need to make sure I don't make breaking changes to the C ABI, and I'm golden. With export IR, the FC/BC surface is all of the operators produced by export. Even a decomposition is potentially BC breaking: a downstream pass could be expecting to see an operator that no longer exists because I've decomposed it into smaller pieces. Matters get worse in pre-dispatch export, since the scope of APIs used inside export IR expands to include autograd control operators (e.g., torch.no_grad) as well as tensor subclasses (since Tensor subclasses cannot be desugared if we have not yet eliminated autograd). We will not break your AOTInductor blobs. We can't as easily give the same guarantee for the IR here.

Next time: What's missing, and what we're doing about it

  • December 23, 2024

Ways to use torch.compile

On the surface, the value proposition of torch.compile is simple: compile your PyTorch model and it runs X% faster. But after having spent a lot of time helping users from all walks of life use torch.compile, I have found that actually understanding how this value proposition applies to your situation can be quite subtle! In this post, I want to walk through the ways to use torch.compile, and within these use cases, what works and what doesn't. By the way, some of these gaps are either served by export or by missing features we are actively working on; those will be the subject of other posts!

Improve training efficiency on a small-medium scale

Scenario: You have a model in PyTorch that you want to train at a small-medium scale (e.g., below 1K GPUs--at the 1K point there is a phase change in behavior that deserves its own section). You would like it to train faster. Locally, it's nice to get a trained model faster than you would have otherwise. But globally, the faster everyone's models train, the fewer GPU hours they use, which means you can run more jobs in a given time window with a fixed cluster. If your supply of GPUs is inelastic (lol), efficiency improvement means you can support more teams and use cases for the same amount of available GPUs. At a capacity planning level, this can be a pretty big deal even if you are GPU rich.

What to do: In some sense, this is the reason we built torch.compile. (When we were initially planning torch.compile, we were trying to assess if we were going after inference; but inference compilers are a much more crowded space than training compilers, and we reasoned that if we did a good job building a training compiler, inference would work too--which it did!) The dream which we sold with torch.compile is that you could slap it on the top of your model and get a speed up. This turns out to... not quite be true? But the fact remains that if you're willing to put in some work, there is almost always performance waiting at the end of the road for you. Some tips:

  • Compile only the modules you need. You don't have to compile the entire model; there might be specific modules which are easy to compile which will give you most of the benefit. For example, in recommendation systems, there is not much compute improvement to be had from optimizing the embedding lookups, and their model parallelism is often quite hard to handle in the compiler, so torch.compiler.disable them (see the sketch after this list). NB: This doesn't apply if you want to do some global graph optimization which needs the whole model: in that case, pass fullgraph=True to torch.compile and ganbatte!
  • Read the missing manual. The missing manual is full of guidance on working with the compiler, with a particular emphasis on working on training.
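
Here is a small sketch of that pattern on a made-up recommendation-style model: the embedding lookup is kept out of the compiled region, and the dense tower can be compiled on its own:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.EmbeddingBag(10_000, 64, mode="sum")
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(64, 256), torch.nn.ReLU(), torch.nn.Linear(256, 1)
        )

    @torch.compiler.disable  # keep the embedding lookup out of the compiled region
    def lookup(self, ids, offsets):
        return self.embedding(ids, offsets)

    def forward(self, ids, offsets):
        return self.dense(self.lookup(ids, offsets))

model = Model()

# option 1: compile the whole model; the disabled lookup is skipped by the compiler
compiled_model = torch.compile(model)

# option 2: compile only the compute-heavy submodule
model.dense = torch.compile(model.dense)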

Open source examples: torchtune and torchtitan are two first party libraries which are intended to showcase modern PyTorch using torch.compile in a training context. There's also some training in torchao.

Downsides:

  • The compiler is complicated. One of the things we've slowly been coming to terms with is that, uh, maybe promising you could just slap torch.compile on a model and have it run faster was overselling the feature a teensy bit? There seems to be some irreducible complexity with compilers that any user bringing their own model to torch.compile has to grapple with. So yes, you are going to spend some of your complexity budget on torch.compile, in hopes that the payoff is worth it (we think it is!) One ameliorating factor is that the design of torch.compile (graph breaks) means it is very easy to incrementally introduce torch.compile into a codebase, without having to do a ton of upfront investment.
  • Compile time can be long. The compiler is not a straightforward unconditional win. Even if the compiler doesn't slow down your code (which it can, in pathological cases), you have to spend some amount of time compiling your model (investment), which you then have to make back by training the model more quickly (return). For very small experimentation jobs, or jobs that are simply crashing, the time spent compiling is just dead weight, increasing the overall time your job takes to run. (teaser: async compilation aims to solve this.) To make matters worse, if you are scheduling your job on systems that have preemption, you might end up repeatedly compiling over and over again every time your job gets rescheduled (teaser: caching aims to solve this.) But even when you do spend some time training, it is not obvious without an A/B test whether or not you are actually getting a good ROI. In an ideal world, everyone using torch.compile would actually verify this ROI calculation, but it doesn't happen automatically (teaser: automatic ROI calculation) and in large organizations we see people running training runs without even realizing torch.compile is enabled.
  • Numerics divergence from eager. Unfortunately, the compiler does not guarantee exact bitwise equivalence with eager code; we reserve the right to do things like select different matrix multiply algorithms with different numerics or eliminate unnecessary downcast/upcasts when fusing half precision compute together. The compiler is also complicated and can have bugs that can cause loss not to converge. Expect to also have to evaluate whether or not application of torch.compile affects accuracy. Fortunately, for most uses of compiler for training efficiency, the baseline is the eager model, so you can just run an ablation to figure out who is actually causing the accuracy problem. (This won't be true in a later use case when the compiler is load bearing, see below!)

Improve Python inference efficiency

Scenario: You've finished training your model and you want to deploy it for inference. Here, you want to improve the efficiency of inference to improve response latency or reduce the overall resource requirements of the system, so you can use less GPUs to serve the traffic you are receiving. Admittedly, it is fairly common to just use some other, more inference friendly systems (which I will decline to name by name lol) to serve the model. But let's say you can't rewrite the model in a more serving friendly language (e.g., because the model authors are researchers and they keep changing the model, or there's a firehose of models and you don't have the money to keep continuously porting each of them, or you depend on an ecosystem of libraries that are only available in CPython).

What to do: If Python can keep up with the CPU-side QPS requirements, a way of getting good performance without very much work is taking the Python model, applying torch.compile on it in the same way as you did in training and directly using this as your inference solution. Some tips that go beyond training:

  • Autotuning makes the most sense for inference. In training runs, you have a limited window (the lifetime of the training job) to get return on the investment you spent optimizing the model. In the serving regime, you can amortize over the entire lifetime of your model in inference, which is typically much longer. Therefore, expensive optimization modes like mode="max-autotune" are more likely to pay off!
  • Warmup inference processes before serving traffic to them. Because torch.compile is a just-in-time compiler, you will spend quite a bit of time compiling (even if you cache hit) at startup. If you have latency requirements, you will want to warm up a fresh process with a representative set of inputs so that you can make sure you trigger all of the compilation paths you need to hit (see the sketch after this list). Caching will reduce compile time but not eliminate it.
  • Try skip_guard_eval_unsafe to reduce guard overhead. Dynamo guard overhead can be material in the inference case. If this is a problem, get a nightly and try skip_guard_eval_unsafe.
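
Here is a minimal sketch of an inference setup along these lines; the toy model stands in for whatever you actually serve:

import torch

class ToyModel(torch.nn.Module):                    # stand-in for your real model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 10)

    def forward(self, x):
        return torch.softmax(self.linear(x), dim=-1)

model = ToyModel().eval().cuda()
compiled = torch.compile(model, mode="max-autotune")

# warm up before serving traffic so compilation (or cache loading) happens up front
with torch.inference_mode():
    for _ in range(3):
        compiled(torch.randn(32, 128, device="cuda"))

# now the process is ready to attach to your serving loop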

Open source examples: LLM serving on torch.compile is quite popular: vllm, sglang, tensorrt-llm, gpt-fast (this is technically not an E2E serving solution, but one of its primary reasons for existing is to serve as a starting point so you can build your own torch.compile based LLM inference stack on top of it). Stable diffusion models are also notable beneficiaries of torch.compile, e.g., diffusers.

Downsides:

  • Just in time compilation is a more complicated operational model. It would be better if you didn't have to warmup inference processes before serving traffic to them. Here, torch.compile has traded operational simplicity for ease of getting started. If you wanted to guarantee that compilation had already happened ahead of time, you have to instead commit to some sort of export-based flow (e.g., C++ GPU/CPU inference) below.
  • Model and dependency packaging in Python is unaddressed. You need to somehow package and deploy the actual Python code (and all its dependencies) which constitute the model; torch.compile doesn't address this problem at all (while torch.export does). If you are running a monorepo and do continuous pushes of your infra code, it can be organizationally complicated to ensure people don't accidentally break model code that is being shipped to production--it's very common to be asked if there's a way to "freeze" your model code so that the monorepo can move on. But with Python inference you have to solve this problem yourself, whether the solution is torch.package, Docker images, or something else.
  • Caches are not guaranteed to hit. Do you have to recompile the model every time you restart the inference process? Well, no, we have an Inductor and Triton (and an in-progress AOTAutograd) cache which in principle can cache all of the cubins that are generated by torch.compile. Most of the time, you can rely on this to reduce startup cost to Dynamo tracing the model only. However, the caches are not guaranteed to hit: there are rarer cases where we don't know how to compute the cache key for some feature a model is using, or the compiler is nondeterministic in a way that means the cache doesn't hit. You should file bugs for all of these issues as we are interested in fixing them, but we don't give a categorical guarantee that after you've compiled your inference program once, you won't have to compile it again. (And indeed, under torch.compile's user model, we can't, because the user code might be the root cause of the nondeterminism--imagine a model that is randomly sampling to decide what version of a model to run.)
  • Multithreading is currently buggy. It should, in principle, be possible to run torch.compile'd code from multiple threads in Python and get a speedup, especially when CUDA graphs or CPP wrapper is used. (Aside: Inductor's default compile target is "Python wrapper", where Inductor's individually generated Triton kernels are called from Python. In this regime, you may get in trouble due to the GIL; CUDA graphs and CPP wrapper, however, can release the GIL when the expensive work is being done.) However, it doesn't work. Track the issue at https://github.com/pytorch/pytorch/issues/136833

Like above, but the compiler is load bearing

Scenario: In both the cases above, we assumed that we had a preexisting eager model that worked, and we just wanted to make it faster. But you can also use the compiler in a load bearing way, where the model does not work without the compiler. Here are two common cases where this can occur:

  1. Performance: A compiler optimization that results in an asymptotic or large constant factor improvement in performance can make a naive eager implementation that would have otherwise been hopelessly slow have good performance. For example, SimpleFSDP chooses to apply no optimizations to the distributed collectives it issues, instead relying on the compiler to bucket and prefetch them for acceptable performance.
  2. Memory: A compiler optimization that reduces the memory usage of a model can allow you to fit a model or batch size that would otherwise OOM. Although we don't publicly expose APIs for doing so, you can potentially use the compiler to do things like force a certain memory budget when doing activation checkpointing, without requiring the user to manually specify what needs to be checkpointed.

What to do: Unlike in the previous cases where you took a preexisting model and slapped torch.compile on it, this sort of use of the compiler is more likely to arise from a codevelopment approach, where you use torch.compile while you build your model, and are constantly checking what the compiler does to the code you write. Some tips:

  • Don't be afraid to write your own optimization pass. Inductor supports custom FX optimization passes. torch.compile has done the work of getting your model into an optimizable form; you can take advantage of this to apply domain specific optimizations that Inductor may not support natively. (A sketch of wiring in such a pass follows this list.)
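
A minimal sketch of wiring in a custom pass; the post_grad_custom_post_pass config knob is what recent Inductor versions expose, but treat the exact hook name and signature as an assumption:

import torch
import torch._inductor.config as inductor_config

def count_matmuls(graph: torch.fx.Graph) -> None:
    # a trivial "pass": walk the post-grad FX graph; a real pass would rewrite nodes here
    n = sum(1 for node in graph.nodes if node.target == torch.ops.aten.mm.default)
    print(f"post-grad graph contains {n} matmuls")

inductor_config.post_grad_custom_post_pass = count_matmuls

@torch.compile
def f(a, b):
    return torch.relu(a @ b)

f(torch.randn(64, 64, device="cuda"), torch.randn(64, 64, device="cuda"))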

Open source examples. SimpleFSDP as mentioned above. VLLM uses torch.compile to apply custom optimization passes. Although its implementation is considerably more involved than what you might reasonably expect a third party to implement, FlexAttention is a good example of a non-compiler feature that relies on the compiler in a load-bearing way for performance.

Downsides: Beyond the ones mentioned above:

  • You can no longer (easily) use eager as a baseline. This is not always true; for example, FlexAttention has an eager mode that runs everything unfused which can still be fast enough for small experiments. But if you have an accuracy problem, it may be hard to compare against an eager baseline if you OOM in that case! It turns out that it's really, really useful to have access to an eager implementation, so it's worth working harder to make sure that the eager implementation works, even if it is slow. (It's less clear how to do that with, e.g., a fancy global optimization based activation checkpointing strategy.)

Next time: Ways to use torch.export

  • November 5, 2024

Tensor programming for databases, with first class dimensions

Tensor libraries like PyTorch and JAX have developed compact and accelerated APIs for manipulating n-dimensional arrays. N-dimensional arrays are kind of similar to tables in a database, which raises a logical question: could you set up a Tensor-like API to do queries on databases that would normally be done with SQL? We have two challenges:

  • Tensor computation is typically uniform and data-independent. But SQL relational queries are almost entirely about filtering and joining data in a data-dependent way.
  • JOINs in SQL can be thought of as performing outer joins, which is not a very common operation in tensor computation.

However, we have a secret weapon: first class dimensions were primarily designed as a new frontend syntax that made it easy to express einsum, batching and tensor indexing expressions. They might be good for SQL too.

Representing the database. First, how do we represent a database? A simple model, following columnar databases, is to have every column be a distinct 1D tensor, where all columns that are part of the same table share a consistent indexing scheme. For simplicity, we'll assume that we support rich dtypes for the tensors (e.g., so I can have a tensor of strings). So if we consider our classic customer database of (id, name, email), we would represent this as:

customers_id: int64[C]
customers_name: str[C]
customers_email: str[C]

Where C is the number of entries in the customer database. Our tensor type is written as dtype[DIM0, DIM1, ...], where I reuse the name that I will use for the first class dimension that represents it. Let's suppose that the index into C does not coincide with id (which is good, because if they did coincide, you would have a very bad time if you ever wanted to delete an entry from the database!)

This gives us an opportunity for baby's first query: let's implement this query:

SELECT c.name, c.email FROM customers c WHERE c.id = 1000

Notice that the result of this operation is data-dependent: it may contain zero or one rows, depending on whether the id is in the database. Here is a naive implementation in standard PyTorch:

mask = customers_id == 1000
return (customers_name[mask], customers_email[mask])

Here, we use boolean masking to perform the data-dependent filtering operation. This implementation in eager is a bit inefficient; we materialize a full boolean mask that is then fed into the subsequent operations, whereas you would prefer for a compiler to fuse the masking and indexing together. First class dimensions don't really help with this example on their own; we will need to introduce some new extensions. First, here is what we can already do:

C = dims(1)
c_id = customers_id[C]  # {C} => int64[]
c_name = customers_name[C]  # {C} => str[]
c_email = customers_email[C]  # {C} => str[]
c_mask = c_id == 1000  # {C} => bool[]

Here, a tensor with first class dimensions has a more complicated type {DIM0, DIM1, ...} => dtype[DIM2, DIM3, ...]. The first class dimensions are all reported in the curly braces to the left of the double arrow; curly braces are used to emphasize the fact that first class dimensions are unordered.

What next? The problem is that now we want to do something like torch.where(c_mask, c_name, ???) but we are now in a bit of trouble, because we don't want anything in the false branch of where: we want to provide something like "null" and collapse the tensor to a smaller number of elements, much like how boolean masking did it without first class dimensions. To express this, we'll introduce a binary version of torch.where that does exactly this, as well as returning the newly allocated FCD for the new, data-dependent dimension:

C2, c2_name = torch.where(c_mask, c_name)  # {C2} => str[]
_C2, c2_email = torch.where(c_mask, c_email)  # {C2} => str[], n.b. C2 == _C2
return c2_name, c2_email

Notice that torch.where introduces a new first-class dimension. I've chosen that this FCD gets memoized with c_mask, so whenever we do more torch.where invocations we still get consistently the same new FCD.

Having to type out all the columns can be a bit tiresome. If we assume all elements in a table have the same dtype (let's call it dyn, short for dynamic type), we can more compactly represent the table as a 2D tensor, where the first dimension is the indexing as before, and the second dimension is the columns of the database. For clarity, we'll support using the string name of the column as a shorthand for the numeric index of the column. If the tensor is contiguous, this gives a more traditional row-wise database. The new database can be conveniently manipulated with FCDs, as we can handle all of the columns at once instead of typing them out individually:

customers:  dyn[C, C_ATTR]
C = dims(1)
c = customers[C]  # {C} => dyn[C_ATTR]
C2, c2 = torch.where(c["id"] == 1000, c)  # {C2} => dyn[C_ATTR]
return c2[["name", "email"]].order(C2)  # dyn[C2, ["name", "email"]]

We'll use this for the rest of the post, but the examples should be interconvertible.

Aggregation. What's the average age of all customers, grouped by the country they live in?

SELECT AVG(c.age) FROM customers c GROUP BY c.country;

PyTorch doesn't natively support this grouping operation, but essentially what is desired here is a conversion into a nested tensor, where the jagged dimension is the country (each country will have a varying number of customers). Let's hallucinate a torch.groupby analogous to its Pandas equivalent:

customers: dyn[C, C_ATTR]
customers_by_country = torch.groupby(customers, "country")  # dyn[COUNTRY, JC, C_ATTR]
COUNTRY, JC = dims(2)
c = customers_by_country[COUNTRY, JC]  # {COUNTRY, JC} => dyn[C_ATTR]
return c["age"].mean(JC).order(COUNTRY)  # f32[COUNTRY]

Here, I gave the generic indexing dimension the name JC, to emphasize that it is a jagged dimension. But everything proceeds like we expect: after we've grouped the tensor and rebound its first class dimensions, we can take the field of interest and explicitly specify a reduction on the dimension we care about.

In SQL, aggregations have to operate over the entirety of groups specified by GROUP BY. However, because FCDs explicitly specify what dimensions we are reducing over, we can potentially decompose a reduction into a series of successive reductions on different columns, without having to specify subqueries to progressively perform the reductions we are interested in.

Joins. Given an order table, join it with the customer referenced by the customer id:

SELECT o.id, c.name, c.email FROM orders o JOIN customers c ON o.customer_id = c.id

First class dimensions are great at doing outer products (although, like with filtering, it will expensively materialize the entire outer product naively!)

customers: dyn[C, C_ATTR]
orders: dyn[O, O_ATTR]
C, O = dims(2)
c = customers[C]  # {C} => dyn[C_ATTR]
o = orders[O]  # {O} => dyn[O_ATTR]
mask = o["customer_id"] == c["id"]  # {C, O} => bool[]
outer_product = torch.cat([o[["id"]], c[["name", "email"]]])  # {C, O} => dyn[["id", "name", "email"]]
CO, co = torch.where(mask, outer_product)  # {CO} => dyn[["id", "name", "email"]]
return co.order(CO)  # dyn[CO, ["id", "name", "email"]]

What's the point. There are a few reasons why we might be interested in the correspondence here. First, we might be interested in applying SQL ideas to the Tensor world: a lot of things people want to do in preprocessing are similar to what you do in traditional relational databases, and SQL can teach us what optimizations and what use cases we should think about. Second, we might be interested in applying Tensor ideas to the SQL world: in particular, I think first class dimensions are a really intuitive frontend for SQL which can be implemented entirely embedded in Python without necessitating the creation of a dedicated DSL. Also, this might be the push needed to get TensorDict into core.

  • October 14, 2024

What’s different this time? LLM edition

One of the things that I learned in grad school is that even if you've picked an important and unsolved problem, you need some reason to believe it is solvable--especially if people have tried to solve it before! In other words, "What's different this time?" This is perhaps a dreary way of shooting down otherwise promising research directions, but you can flip it around: when the world changes, you can ask, "What can I do now that I couldn't do before?"

This post is a list of problems in areas that I care about (half of this is PL flavor, since that's what I did my PhD in), where I suspect something has changed with the advent of LLMs. It's not a list of recipes; there is still hard work to figure out how exactly an LLM can be useful (for most of these, just feeding the entire problem into ChatGPT usually doesn't work). But I often talk to people who want to get started on something, anything, but have no idea where to start. Try here!

Static analysis. The chasm between academic static analysis work and real world practice is the scaling problems that come with trying to apply the technique to a full size codebase. Asymptotics strike as LOC goes up, language focused techniques flounder in polyglot codebases, and "Does anyone know how to write cmake?" But this is predicated on the idea that static analysis has to operate on a whole program. It doesn't; humans can do perfectly good static analysis on fragments of code without having to hold the entire codebase in their head, without needing access to a build system. They make assumptions about APIs and can do local reasoning. LLMs can play a key role in drafting these assumptions so that local reasoning can occur. What if the LLM gets it wrong? Well, if an LLM could get it wrong, an inattentive junior developer might get it wrong too--maybe there is a problem in the API design. LLMs already do surprisingly well if you one-shot prompt them to find bugs in code; with more traditional static analysis support, maybe they can do even better.

DSL purgatory. Consider a problem that can be solved with code in a procedural way, but only by writing lots of tedious, error prone boilerplate (some examples: drawing diagrams, writing GUIs, SQL queries, building visualizations, scripting website/mobile app interactions, end to end testing). The PL dream is to design a sweet compositional DSL that raises the level of abstraction so that you can render a Hilbert curve in seven lines of code. But history also abounds with cases where the DSL did not solve the problem, or maybe it did solve the problem but only after years of grueling work, and so there are still many problems that feel like there ought to be a DSL to solve them but there isn't one. The promise of LLMs is that they are extremely good at regurgitating low level procedural actions that could conceivably be put together in a DSL. Many of the best successes of LLMs today involve putting coding powers in the hands of domain experts who otherwise do not know how to code; could they also help put domain expertise in the hands of people who can code?

I am especially interested in these domains:

  • SQL - Its strange syntax purportedly makes it easier for non-software engineers to understand, whereas many (myself included) would often prefer a more functional syntax ala LINQ/list comprehensions. It's pretty hard to make an alternate SQL syntax take off though, because SQL is not one language, but many many dialects everywhere with no obvious leverage point. That sounds like an LLM opportunity. Or heck, just give me one of those AI editor environments but specifically fine tuned for SQL/data visualization, don't even bother with general coding.
  • End to end testing - This is https://momentic.ai/ but personally I'm not going to rely on a proprietary product for testing in my OSS projects. There's definitely an OSS opportunity here.
  • Scripting website/mobile app interactions - The website scraping version of this is https://reworkd.ai/ but I am also pretty interested in this from the browser extension angle: to some extent I can take back control of my frontend experience with browser extensions; can I go further with LLMs? And we typically don't imagine that I can do the same with a mobile app... but maybe I can??

OSS bread and butter. Why is Tesseract still the number one OSS library for OCR? Why is smooth and beautiful text to voice not ubiquitous? Why is the voice control on my Tesla so bad? Why is the wake word on my Android device so unreliable? Why isn't the screenshot parser on a fansite for my favorite mobage able to parse out icons? The future has arrived, but it is not uniformly distributed.

Improving the pipeline from ephemeral to durable stores of knowledge. Many important sources of knowledge are trapped in "ephemeral" stores, like Discord servers, private chat conversations, Reddit posts, Twitter threads, blog posts, etc. In an ideal world, there would be a pipeline of this knowledge into more durable, indexable forms for the benefit of all, but actually doing this is time consuming. Can LLMs help? Note that the dream of LLMs is you can just feed all of this data into the model and just ask questions to it. I'm OK with something a little bit more manual, we don't have to solve RAG first.

  • October 4, 2024

Interactive scraping with Jupyter and Puppeteer

One of the annoying things about scraping websites is bouncing back and forth between the browser where you are using Dev Tools to work out what selectors you should be using to scrape out data, and your actual scraping script, which is usually some batch program that may have to take a few steps before the step you are debugging. A batch script is fine once your scraper is up and running, but while developing, it's really handy to pause the scraping process at some page and fiddle around with the DOM to see what to do.

This interactive-style development is exactly what Jupyter notebooks shine at; when used in conjunction with a browser-based scraping library like Puppeteer, you can have exactly this workflow. Here's the setup:

  1. Puppeteer is a JavaScript library, so you'll need a JavaScript kernel for Jupyter to run it. As an extra complication, Puppeteer is also async, so you'll need a kernel that supports async execution. Fortunately, ijavascript-await provides exactly this. Note that on recent versions of node this package does not compile; you can install this PR which makes this work: https://github.com/n-riesco/ijavascript/pull/257 Hypothetically, we should be able to use stock ijavascript when node supports top level await, but this currently does not work: https://github.com/nodejs/node/issues/40898
  2. Inside the directory where you will store your notebooks, you'll need to npm install puppeteer so that it's available for your notebooks.
  3. Launch Puppeteer with let puppeteer = require('puppeteer'); let browser = await puppeteer.launch({headless: false}); and profit!

There will be a live browser instance which you can poke at using Dev Tools, and you type commands into the Jupyter notebook and see how they affect the browser state.

I tweeted about this and the commenters had some good suggestions about other things you could try:

  • You don't have to use Puppeteer; Selenium can also drive the browser, and it has a Python API to boot (so no faffing about with alternate Jupyter kernels necessary). I personally prefer working in JavaScript for crawlers, since the page scripting itself is also in JavaScript, but this is mostly a personal preference thing.
  • For simple interactions, where all you really want is to just do a few interactions and record them, Headless Recorder provides a nice extension for just directly recording operations in your browser and then getting them out in executable form. I haven't tried it out yet but it seems like it would be very easy to use.
  • November 23, 2021

PyTorch Developer Podcast

I'm launching a new podcast, the PyTorch Developer Podcast. The idea is to be a place for the PyTorch dev team to do bite sized (10-20 min) topics about all sorts of internal development topics in PyTorch. For now, it's just me monologuing for fifteen minutes about whatever topic I decide. The plan is to release an episode daily, five days a week, until I run out of things to say (probably not for a while, I have SO MANY THINGS TO SAY). I don't edit the podcasts and do minimal planning, so they're a bit easier to do than blog posts. Check it out! There's two episodes out already, one about how we do Python bindings for our C++ objects and another about history and constraints of the dispatcher. If there are any topics you'd like me to cover, give a shout.

  • May 5, 2021

Rage bug reporting

At Facebook, we have an internal convention for tooling called "rage". When something goes wrong and you want to report a bug, the tool developer will typically ask you to give them a rage. For a command line tool, this can be done by running a rage subcommand, which will ask about which previous CLI invocation you'd like to report, and then giving you a bundle of logs to send to the developer.

A rage has an important property, compared to a conventional log level flag like -v: rage recording is always on. In other words, it is like traditional server application logs, but applied to client software. Logging is always turned on, and the rage subcommand makes it easy for a user to send only the relevant portion of logs (e.g., the logs associated with the command line invocation in question).
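
To make the pattern concrete, here is a hypothetical sketch (not ghstack's actual implementation): every invocation logs unconditionally to its own directory, and a rage subcommand lets the user pick which invocation's logs to send; the tool name and file layout are made up for illustration:

import logging, os, sys, time

LOG_ROOT = os.path.expanduser("~/.mytool/logs")     # made-up location for this sketch

def setup_logging():
    # always on: every CLI invocation gets its own log directory, no -v required
    run_dir = os.path.join(LOG_ROOT, time.strftime("%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)
    logging.basicConfig(filename=os.path.join(run_dir, "debug.log"), level=logging.DEBUG)
    logging.info("argv: %s", sys.argv)

def rage():
    # let the user pick a recent invocation and hand them the bundle to send
    runs = sorted(os.listdir(LOG_ROOT))[-10:]
    for i, run in enumerate(runs):
        print(f"[{i}] {run}")
    choice = runs[int(input("which invocation? "))]
    print("send this file:", os.path.join(LOG_ROOT, choice, "debug.log"))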

For some reason, rage functionality is not that common in open source tools. I can imagine any number of reasons why this might be the case:

  • Adding proper logging is like flossing--annoying to do at the time even when it can save you a lot of pain later.
  • Even if you have logging, you still need to add infrastructure to save the logs somewhere and let users retrieve them afterwards.
  • It's something of an art to write logs that are useful enough so that developer can diagnose the problem simply by "reading the tea leaves", but not so detailed that they slow down normal execution of the program. And don't forget, you better not expose private information!
  • Most programs are simple, and you can just fall back on the old standby of asking the user to submit reproduction instructions in their bug report.

Still, in the same way most sysadmins view logging as an invaluable tool for debugging server issues, I think rage reporting is an invaluable tool for debugging client issues. In ghstack, it didn't take very many lines of code to implement rage reporting: ghstack.logs (for writing the logs to the rage directory) and ghstack.rage (for reading it out). But it has greatly reduced my support load for the project; given a rage, I can typically figure out the root cause of a bug without setting up a reproducer first.

  • April 25, 2021

The PyTorch open source process

PyTorch is a fairly large and active open source project, and sometimes we have people come to us and ask if there are any lessons from how we run PyTorch that they could apply to their own projects. This post is an attempt to describe some of the processes as of 2021 that help PyTorch operate effectively as an open source project. I won't claim that everything we do is necessarily the best way to go about doing things, but at the very least, everything I describe here is working in practice.

Background. Not all open source projects are the same, and there are some peculiarities to PyTorch which may reduce the applicability of some of what I describe below in other contexts. Here are some defining features of PyTorch, as a project:

  • The majority of full time PyTorch developers work at Facebook. To be clear, there are many full time PyTorch developers that work at other companies: NVIDIA, Intel, Quansight, Microsoft, AMD, IBM, Preferred Networks, Google and Amazon all employ people whose job it is to work on PyTorch. But the majority of full timers are at Facebook, distinguishing PyTorch from hobbyist open source projects or projects run by a foundation of some sort.
  • PyTorch is a federation. As coined by Nadia Eghbal, PyTorch is a project with high contributor growth and user growth. In my State of PyTorch (2020) talk, I go into more details, but suffice to say, we have over nine companies contributing to PyTorch, and a long tail of other contributors (making up 40% of all of our commits). This makes managing PyTorch sometimes particularly challenging, and many of the processes I will describe below arose from growing pains scaling this level of activity.
  • PyTorch has a lot of surface area. CPU, CUDA, ROCm, ONNX, XLA, serving, distributions, quantization, etc. It's impossible for a single contributor to be well-versed in every area of the project, and so some of the challenge is just making sure the right people see the things they need to see.

Alright, so how does PyTorch deal with its scale? Here are some of the things we do.

Issue triage. PyTorch receives too many bug reports a day for any one person to keep track of all of them. Largely inspired by this apenwarr post, we setup an oncall rotation amongst Facebook contributors to serve as first line triage for all of these issues. The golden rule of issue triage is that you DO NOT fix bugs in triage; the goal of triage is to (1) route bugs to the correct people via appropriate GitHub labels, and (2) look for high priority bugs and raise awareness of these bugs. Every week, we have a meeting to review high priority bugs (and other bugs marked for triage review) and talk about them. The oncall itself rotates daily, to discourage people from letting a week's worth of issues pile up in the backlog, and we use a relatively intricate search query to make sure only relevant issues show up for the oncall to handle.

The most important consequence of issue triage is that you can unwatch the PyTorch repository as a whole. Instead, by watching various labels (using our cc bot), you can trust that you will get CC'ed on issues related to those topics, even if the triager doesn't know that you're interested in the issue! The weekly meeting makes sure that all maintainers collectively have an idea about what major issues are currently affecting PyTorch, and helps socialize what we as a project think of as a "high priority" issue. Finally, the high priority label is a good way to find impactful problems to work on in the project, even if you don't know much else about the project.

Pull request triage. Similarly, we receive a decent number of drive by pull requests from one time contributors. Those people are not in a good position to find reviewers for their contributions, so we also have a triager look through these pull requests and make sure someone is assigned to review them. If the PR is particularly simple, the triager might just go ahead and merge it themselves. There's actually some good automation for doing this (e.g., homu) but we've been too lazy to set any of it up, and by hand reviewer assignment doesn't seem to be too much burden on top of the existing oncall.

Tree hugging oncall. PyTorch has a huge CI system covering many different system configurations which most contributors rely on to test if their changes are safe. Sometimes people break master. Separate from the triage oncall, we have a tree hugging oncall whose job it is to revert jobs if they break master. This oncall involves mostly paying attention to the CI HUD and reverting commits if they result in master breakage in one of the configurations.

Importing to Facebook infrastructure. We actually run Facebook infrastructure directly off of the HEAD branch in PyTorch. The tooling that makes this possible is fbshipit, which mirrors commits between Facebook's internal monorepo and our public GitHub repository. This setup has been something of a double-edged sword for us: requiring Facebook and GitHub to be in sync means that only Facebook employees can actually land pull requests (we try to streamline the process as much as possible for external maintainers, but at the end of the day someone at Facebook has to actually push the green button), but it means we don't have to worry about doing periodic "mega-imports" into Facebook infrastructure (which we have done in the past and were quite difficult to do). We are very interested in fixing this situation and have floated some proposals on changing how we do internal releases to make it possible to let external contributors land PRs directly.

RFCs. Most feature discussion happens on GitHub issues, but sometimes, a feature is too big and complicated to adequately discuss in a GitHub issue. In those cases, they can be discussed in the rfcs repository (inspired by the Rust RFCs process). The formal process on this repository isn't too solidified yet, but generally people go there if they feel that it is too difficult to discuss the issue in GitHub issues. We don't yet have a process for shepherding unsolicited RFCs.

Conclusion. PyTorch's open source process isn't rocket science: there's an oncall, the oncall does some things. The devil is in the details: all of PyTorch's oncall responsibilities are carefully scoped so that your oncall responsibilities aren't something that will take an unbounded amount of time; they're something you can knock out in an hour or two and call it a day. You could make the argument that we rely excessively on oncalls when automation is possible, but what we have found is that oncalls require less infrastructure investment, and integrate well with existing processes and flows at Facebook. They might not be right everywhere, but at least for us they seem to be doing a good job.

  • January 6, 2021

The hidden problem(?) with basic block procedures in SSA

Years ago, Nadav Rotem related to me this story about why basic block procedures in Swift are not as good as they seem. Nelson Elhage reminded me about this on Twitter and so I thought this should be put into the public record.

Basic block procedures make certain optimizations more difficult. Consider this program:

block j3 (%y1, %y2) { ... }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }

Is this program easier or more difficult to optimize than the traditional SSA with phi-nodes formulation?

L1:
   goto L3
L2:
   goto L3
L3:
   %y1 = phi [%x1, %L1] [%x3, %L2]
   %y2 = phi [%x2, %L1] [%x4, %L2]

Suppose that the optimizer determines that y1 is unused inside j3/L3 and can be eliminated. In traditional SSA land, y1 can be eliminated simply by deleting the phi node "%y1 = phi [%x1, %L1] [%x3, %L2]". However, in join point land, you have to not only eliminate y1 but also update all the call sites of j3, since you've changed the function signature. In a mutable AST, changing function signatures is a pain; in particular, the mutations you would have to do to eliminate the argument include intermediate states that are not valid ASTs (making it easy to accidentally trigger asserts.)

When I saw this example, I wondered why GHC (which has the moral equivalent of basic block procedures in the form of join points) didn't have this problem. Well, it turns out this optimization can be done as a series of local transformations. First, we do a worker/wrapper transformation, introducing an intermediate block (the worker) that drops the dead argument:

block j3 (%y1, %y2) { jump wj3(%y2) }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }
block wj3 (%y2) { ... }

Later, we inline j3, which removes the wrapper. Worker/wrapper is a very important optimization for functional programs, but it's easy to imagine why it is less preferred in mutable compiler land.

  • October 24, 2020