Ways to use torch.export
Previously, I discussed the value proposition of torch.compile. While doing so, I observed a number of downsides (long compile time, complicated operational model, lack of packaging) that were intrinsic to torch.compile's API contract, which emphasized being able to work on Python code as is, with minimal intervention from users. torch.export occupies a different spot in the tradeoff space: in exchange for more upfront work making a model exportable, it allows for use of PyTorch models in environments where using torch.compile as is would be impossible.
Enable end-to-end C++ CPU/GPU Inference
Scenario: Like before, suppose you want to deploy your model for inference. However, now you have more stringent runtime requirements: perhaps you need to do inference from a CPython-less environment (because your QPS targets demand GIL-less multithreading; alternatively, CPython execution overhead is unacceptable but you cannot use CUDA graphs, e.g., due to CPU inference or dynamic shape requirements). Or perhaps your production environment requires hermetic deploy artifacts (for example, in a monorepo setup, where infrastructure code must be continually pushed but model code should be frozen). But like before, you would prefer not to have to rewrite your model; you would like the existing model to serve as the basis for your Python-less inference binary.
What to do: Use torch.export targeting AOTInductor. This will compile the model into a self-contained shared library which can then be invoked directly from a C++ runtime. This shared library contains all of the compiler-generated Triton kernels as precompiled cubins and is guaranteed not to need any runtime compilation; furthermore, it relies only on a small runtime ABI (with no CPython dependency), so the binaries can be used across versions of libtorch. AOTInductor's multithreading capability and low runtime overhead also make it a good match for CPU inference!
You don't have to go straight to C++ CPU/GPU inference: you can start with using torch.compile on your code before investing in torch.export. There are four primary extra requirements export imposes: (1) your model must compile with fullgraph=True (though you can sometimes bypass missing Dynamo functionality by using non-strict export; sometimes, it is easier to do non-strict torch.export than it is to torch.compile!), (2) your model's inputs/outputs must only be in torch.export's supported set of argument types (think Tensors in pytrees), (3) your model must never recompile--specifically, you must specify what inputs have dynamic shapes, and (4) the top-level of your model must be an nn.Module (so that export can keep track of all of the parameters your model has).
Some tips:
- Check out the torch.export programming model. The torch.export programming model is an upcoming doc which aims to help set expectations on what can and cannot be exported. It talks about things like "Tensors are the only inputs that can actually vary at runtime" and common mistakes such as module code which modifies NN modules (not supported!) or optional input types (you will end up with an export that either always takes that input or never does; there is no runtime optionality).
- Budget time for getting a model to export. With torch.compile for Python inference, you could just slap it on your model and see what happens. For torch.export, you have to actually finish exporting your entire model before you can even consider running the rest of the pipeline. For some of the more complicated models we have exported, there were often dozens of issues that had to be worked around in one way or another. And that doesn't even account for all of the post-export work you have to do, like validating the numerics of the exported model.
- Intermediate value debugging. AOTInductor has an option to add dumps of intermediate tensor values in the compiled C++ code. This is good for determining, e.g., the first time where a NaN shows up, in case you are suspecting a miscompilation.
Open source examples: Among other things, torchchat has an example end-to-end AOTInductor setup for server-side LLM inference, which you can view in run.cpp.
torch.export specific downsides:
- No built-in support for guard-based dispatch (multiple compilations). Earlier, I mentioned that an exported model must not have any recompiles. This leads to some fairly common patterns of code not being directly supported by torch.export: you can't export a single model that takes an enum as input, or has an optional Tensor argument, or accepts two distinct tensor shapes that need to be compiled individually. Now, technically, we could support this: you could imagine a package that contains multiple exported artifacts and dispatches between them depending on some conditions (e.g., the value of the enum, whether or not the optional Tensor argument was provided, the shape of the input tensor). But you're on your own: torch.compile will do this for you, but torch.export will not.
- No built-in support for models that are split into multiple graphs. Similarly, we've mentioned that an exported model must be a single graph. This is in contrast to torch.compile, which will happily insert graph breaks and compile distinct islands of code that can be glued together with Python eager code. Now, technically, you can do this with export too: you can carve out several distinct subnets of your model, export them individually, and then glue them together with some custom written code on the other end (in fact, Meta's internal recommendation systems do this), but there's no built-in support for this workflow.
- The extra requirements often don't cover important components of real world models. I've mentioned this previously as the extra restrictions export places on you, but it's worth reiterating some of the consequences of this. Take an LLM inference application: obviously, there is a core model that takes in tokens and produces logit predictions--this part of the model is exportable. But there are also important other pieces such as the tokenizer and sampling strategy which are not exportable (tokenizer because it operates on strings, not tensors; sampling because it involves complicated control flow). Arguably, it would be much better if all of these things could be directly bundled with the model itself; in practice, end-to-end applications should just expect to directly implement these in native code (e.g., as is done in torchchat). Our experience with TorchScript taught us that we don't really want to be in the business of designing a general purpose programming language that is portable across all of export's targets; better to just bet that the tokenizer doesn't change that often and eat the cost of natively integrating it by hand.
AOTInductor specific downsides:
- You still need libtorch to actually run the model. Although AOTInductor binaries bundle most of their compiled kernel implementations, they still require a minimal runtime that can offer basic necessities such as tensor allocation and access to custom operators. There is not yet an official offering of an alternative, lightweight implementation of the stable ABI that AOTInductor binaries depend on, so if you do want to deploy AOTInductor binaries you will typically have to bring libtorch along as well. This is usually not a big deal server side, but it can be problematic if you want to do client side deployments!
- No CUDA graphs support. This one is not such a big deal since you are much less likely to be CPU bound when the host side logic is all compiled C++, but there's no support for CUDA graphs in AOTInductor. (Funnily enough, this is also something you technically can orchestrate from outside of AOTInductor.)
Edge deployment
Scenario: You need to deploy your PyTorch model to edge devices (e.g., a mobile phone or a wearable device) where computational resources are limited. Your requirements are a bit different from server side: you care a lot more about minimizing binary size and startup time. Traditional PyTorch deployment with full libtorch won't work. The device you're deploying to might also have some strange extra processors, like a DSP or NPU, that you want your model to target.
What to do: Use torch.export targeting Executorch. Among other things, Executorch offers a completely separate runtime for exported PyTorch programs (i.e., it has no dependency on libtorch, except perhaps for a few headers which we share between the projects) which was specifically designed for edge deployment. (Historical note: we spent a long time trying to directly ship a stripped down version of libtorch to mobile devices, but it turns out it's really hard to write code that is portable on server and client, so it's better to only share when absolutely necessary.) Quantization is also a pretty important part of deployment to the edge, and Executorch incorporates this into the end-to-end workflow.
Open source examples: torchchat also has an Executorch integration letting you run an LLM on your Android phone.
Downsides. All of the export related downsides described previously apply here. But here's something to know specifically about Executorch:
- The edge ecosystem is fragmented. At time of writing, there are seven distinct backends Executorch can target. This is not really Executorch's fault; it comes with the territory. But I want to call it out because it stands in stark contrast to NVIDIA's server-side hegemony. Yes, AMD GPUs are a thing, and various flavors of CPU are real, but it really is a lot easier to stay focused on the server side because NVIDIA GPUs come first.
Pre-compiled kernels for eager mode
Scenario: You need a new function or self-contained module with an efficient kernel implementation. However, you would prefer not to have to write the CUDA (or even Triton) by hand; the kernel is something that torch.compile can generate from a higher-level PyTorch implementation. At the same time, however, you cannot tolerate just-in-time compilation at all (perhaps you are doing a massive training job, and any startup latency makes it more likely that one of your nodes will fail during startup and then you make no progress at all; or maybe you just find it annoying when PyTorch goes out to lunch when you cache miss).
What to do: Use torch.export targeting AOTInductor, and then load and run the AOTInductor generated binary from Python.
Downsides. So, we know this use case works, because we have internally used this to unblock people who wanted to use Triton kernels but could not tolerate Triton's just-in-time compilation. But there's not much affordance in our APIs for this use case; for example, guard-based dispatch is often quite useful for compiled functions, but you'll have to roll that by hand. More generally, when compiling a kernel, you have to make tradeoffs about how static versus dynamic the kernel should be (for example, will you force the inputs to be evenly divisible by eight? Or would you have a separate kernel for the divisible and not divisible cases?) Once again, you're on your own for making the call there.
An exchange format across systems
Scenario: In an ideal world, you would have a model, you could export it to an AOTInductor binary, and then be all done. In reality, the export may need to be a multi-stage process, where the model is processed to some degree on one machine and then finishes processing on another machine. Or perhaps you need to shift the processing over time: you want to export a model to freeze it (so it is no longer tied to its original source code), and then repeatedly run the rest of the model processing pipeline on this exported program (e.g., because you are continuously updating its weights and then reprocessing the model). Maybe you want to export the model and then train it from Python later, committing to a distributed training strategy only when you know how many nodes you are running. The ability to hermetically package a model and then process it later is one of the big value propositions of TorchScript and torch.package.
What to do: Use torch.export by itself, potentially using pre-dispatch if you need to support training use-cases. torch.export produces an ExportedProgram which has a clean intermediate representation that you can do processing on, or just serialize and then do processing on later.
Downsides:
- Custom operators are not packaged. A custom operator typically refers to some native code which was linked with PyTorch proper. There's no way to extract this kernel and embed it into the exported program so that there is no dependence; instead, you're expected to ensure the eventual runtime relinks with the same custom operator. Note that this problem doesn't apply to user-defined Triton kernels, as export can simply compile them and package the binaries directly into the exported product. (Technically, this applies to AOTInductor too, but this tends to be much more of a problem for use cases which are primarily about freezing rapidly evolving model code, as opposed to plain inference where you would simply expect people not to be changing custom operators willy-nilly.)
- Choose your own decompositions. Export produces IR that only contains operators from a canonical operator set. However, the default choice is sometimes inappropriate for some use cases (e.g., some users want aten.upsample_nearest2d.vec to be decomposed while others do not), so in practice for any given target you may have a bespoke operator set that is appropriate for that use case. Unfortunately, it can be fiddly to get your operator set just right, and while we've talked about ideas like a "build your own operator set" interactive tool, these have not been implemented yet.
- Annoyingly large FC/BC surface. Something I really like about AOTInductor is that it has a very small FC/BC surface: I only need to make sure I don't make breaking changes to the C ABI, and I'm golden. With export IR, the FC/BC surface is all of the operators produced by export. Even a decomposition is potentially BC breaking: a downstream pass could be expecting to see an operator that no longer exists because I've decomposed it into smaller pieces. Matters get worse in pre-dispatch export, since the scope of APIs used inside export IR expands to include autograd control operators (e.g., torch.no_grad) as well as tensor subclasses (since Tensor subclasses cannot be desugared if we have not yet eliminated autograd). We will not break your AOTInductor blobs. We can't as easily give the same guarantee for the IR here.
Next time: What's missing, and what we're doing about it