ezyang’s blog

the arc of software bends towards understanding

Tensor programming for databases, with first class dimensions

Tensor libraries like PyTorch and JAX have developed compact and accelerated APIs for manipulating n-dimensional arrays. N-dimensional arrays are fairly similar to tables in a database, which raises a natural question: could you set up a Tensor-like API to run the queries on databases that you would normally write in SQL? We have two challenges:

  • Tensor computation is typically uniform and data-independent. But SQL relational queries are almost entirely about filtering and joining data in a data-dependent way.
  • JOINs in SQL can be thought of as performing outer products, which is not a very common operation in tensor computation.

However, we have a secret weapon: first class dimensions were primarily designed as a new frontend syntax that makes it easy to express einsum, batching and tensor indexing expressions. They might be good for SQL too.

Representing the database. First, how do we represent a database? A simple model, following columnar databases, is to have every column be a distinct 1D tensor, where all columns belonging to the same table share a consistent indexing scheme. For simplicity, we'll assume that we support rich dtypes for the tensors (e.g., so I can have a tensor of strings). So if we consider our classic customer database of (id, name, email), we would represent it as:

customers_id: int64[C]
customers_name: str[C]
customers_email: str[C]

Where C is the number of entries in the customer database. Our tensor type is written as dtype[DIM0, DIM1, ...], where each dimension is named after the first class dimension that will represent it. Let's suppose that the index into C does not coincide with id (which is good, because if they did coincide, you would have a very bad time if you ever wanted to delete an entry from the database!)

This gives us an opportunity for baby's first query; let's implement:

SELECT c.name, c.email FROM customers c WHERE c.id = 1000

Notice that the result of this operation is data-dependent: it may contain zero or one rows depending on whether the id is in the database. Here is a naive implementation in standard PyTorch:

mask = customers_id == 1000
return (customers_name[mask], customers_email[mask])

Here, we use boolean masking to perform the data-dependent filtering operation. This eager implementation is a bit inefficient; we materialize a full boolean mask that is then fed into the subsequent operations, whereas you would prefer a compiler to fuse the masking and indexing together. First class dimensions don't help with this example out of the box; we will need to introduce some new extensions to them. First, though, here is what we can already do:

C = dims(1)
c_id = customers_id[C]  # {C} => int64[]
c_name = customers_name[C]  # {C} => str[]
c_email = customers_email[C]  # {C} => str[]
c_mask = c_id == 1000  # {C} => bool[]

Here, a tensor with first class dimensions has a more complicated type {DIM0, DIM1, ...} => dtype[DIM2, DIM3, ...]. The first class dimensions are all reported in the curly braces to the left of the double arrow; curly braces are used to emphasize the fact that first class dimensions are unordered.

What next? We want to do something like torch.where(c_mask, c_name, ???), but we're in a bit of trouble: there is nothing we want in the false branch of where. We would like to provide something like "null" and collapse the tensor to a smaller number of elements, much like boolean masking did without first class dimensions. To express this, we'll introduce a binary version of torch.where that does exactly this, and that also returns the newly allocated FCD for the new, data-dependent dimension:

C2, c2_name = torch.where(c_mask, c_name)  # {C2} => str[]
_C2, c2_email = torch.where(c_mask, c_email)  # {C2} => str[], n.b. C2 == _C2
return c2_name, c2_email

Notice that torch.where introduces a new first-class dimension. I've chosen to have this FCD memoized with c_mask, so whenever we do more torch.where invocations with the same mask, we consistently get the same new FCD.
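
The binary torch.where proposed here doesn't exist in PyTorch today, but its effect can be emulated in plain PyTorch with the boolean masking we started with. Here is a minimal runnable sketch, using a hypothetical numeric age column since real tensors can't hold strings:

import torch

customers_id = torch.tensor([999, 1000, 1001])
customers_age = torch.tensor([31, 42, 27])  # hypothetical numeric column

c_mask = customers_id == 1000
C2 = c_mask.nonzero(as_tuple=True)[0]  # indices where the mask holds: the data-dependent "C2"
c2_id = customers_id[C2]    # tensor([1000])
c2_age = customers_age[C2]  # tensor([42])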

Having to type out all the columns can be a bit tiresome. If we assume all elements in a table have the same dtype (let's call it dyn, short for dynamic type), we can more compactly represent the table as a 2D tensor, where the first dimension is the indexing as before, and the second dimension ranges over the columns of the database. For clarity, we'll support using the string name of a column as a shorthand for its numeric index. If the tensor is contiguous, this gives a more traditional row-wise database. The new representation can be conveniently manipulated with FCDs, since we can handle all of the columns at once instead of typing them out individually:

customers:  dyn[C, C_ATTR]
C = dims(1)
c = customers[C]  # {C} => dyn[C_ATTR]
C2, c2 = torch.where(c["id"] == 1000, c)  # {C2} => dyn[C_ATTR]
return c2[["name", "email"]].order(C2)  # dyn[C2, ["name", "email"]]

We'll use this for the rest of the post, but the examples should be interconvertible.

Aggregation. What's the average age of all customers, grouped by the country they live in?

SELECT AVG(c.age) FROM customers c GROUP BY c.country;

PyTorch doesn't natively support this grouping operation, but essentially what is desired here is a conversion into a nested tensor, where the jagged dimension is the country (each of which will have a varying number of customers). Let's hallucinate a torch.groupby analogous to its Pandas equivalent:

customers: dyn[C, C_ATTR]
customers_by_country = torch.groupby(customers, "country")  # dyn[COUNTRY, JC, C_ATTR]
COUNTRY, JC = dims(2)
c = customers_by_country[COUNTRY, JC]  # {COUNTRY, JC} => dyn[C_ATTR]
return c["age"].mean(JC).order(COUNTRY)  # f32[COUNTRY]

Here, I gave the generic indexing dimension the name JC, to emphasize that it is a jagged dimension. But everything proceeds like we expect: after we've grouped the tensor and rebound its first class dimensions, we can take the field of interest and explicitly specify a reduction on the dimension we care about.
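
torch.groupby is hallucinated, but today's PyTorch can express the same aggregation with torch.unique and index_add_. A minimal sketch, assuming a hypothetical integer encoding of countries:

import torch

customer_country = torch.tensor([0, 1, 0, 2, 1, 0])  # hypothetical country codes
customer_age = torch.tensor([30., 40., 50., 20., 60., 10.])

countries, inverse = torch.unique(customer_country, return_inverse=True)
sums = torch.zeros(len(countries)).index_add_(0, inverse, customer_age)
counts = torch.zeros(len(countries)).index_add_(0, inverse, torch.ones_like(customer_age))
avg_age_by_country = sums / counts  # one average per unique country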

In SQL, aggregations have to operate over the entirety of groups specified by GROUP BY. However, because FCDs explicitly specify what dimensions we are reducing over, we can potentially decompose a reduction into a series of successive reductions on different columns, without having to specify subqueries to progressively perform the reductions we are interested in.

Joins. Given an order table, join it with the customer referenced by the customer id:

SELECT o.id, c.name, c.email FROM orders o JOIN customers c ON o.customer_id = c.id

First class dimensions are great at doing outer products (although, like with filtering, a naive implementation will expensively materialize the entire outer product!):

customers: dyn[C, C_ATTR]
orders: dyn[O, O_ATTR]
C, O = dims(2)
c = customers[C]  # {C} => dyn[C_ATTR]
o = orders[O]  # {O} => dyn[O_ATTR]
mask = o["customer_id"] == c["id"]  # {C, O} => bool[]
outer_product = torch.cat(o[["id"]], c[["name", "email"]])  # {C, O} => dyn[["id", "name", "email"]]
CO, co = torch.where(mask, outer_product)  # {CO} => dyn[["id", "name", "email"]]
return co.order(CO)  # dyn[CO, ["id", "name", "email"]]
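
For comparison, here is a plain-PyTorch sketch of the same join, which materializes the full outer comparison exactly as described above (the numeric columns are stand-ins, since string tensors aren't supported):

import torch

customers_id = torch.tensor([1, 2, 3])
customers_age = torch.tensor([31, 42, 27])     # hypothetical numeric column
orders_id = torch.tensor([10, 11, 12, 13])
orders_customer_id = torch.tensor([3, 1, 3, 2])

mask = orders_customer_id[:, None] == customers_id[None, :]  # [O, C] outer comparison
o_idx, c_idx = mask.nonzero(as_tuple=True)                   # the data-dependent "CO" dimension
joined = torch.stack([orders_id[o_idx], customers_age[c_idx]], dim=1)  # one row per match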

What's the point. There are a few reasons why we might be interested in the correspondence here. First, we might be interested in applying SQL ideas to the Tensor world: a lot of things people want to do in preprocessing are similar to what you do in traditional relational databases, and SQL can teach us what optimizations and what use cases we should think about. Second, we might be interested in applying Tensor ideas to the SQL world: in particular, I think first class dimensions are a really intuitive frontend for SQL which can be implemented entirely embedded in Python without necessitating the creation of a dedicated DSL. Also, this might be the push needed to get TensorDict into core.

  • October 14, 2024

What’s different this time? LLM edition

One of the things that I learned in grad school is that even if you've picked an important and unsolved problem, you need some reason to believe it is solvable--especially if people have tried to solve it before! In other words, "What's different this time?" This is perhaps a dreary way of shooting down otherwise promising research directions, but you can flip it around: when the world changes, you can ask, "What can I do now that I couldn't do before?"

This post is a list of problems in areas that I care about (half of this is PL flavored, since that's what I did my PhD in), where I suspect something has changed with the advent of LLMs. It's not a list of recipes; there is still hard work to figure out how exactly an LLM can be useful (for most of these, just feeding the entire problem into ChatGPT usually doesn't work). But I often talk to people who want to get started on something, anything, but have no idea where to start. Try here!

Static analysis. The chasm between academic static analysis work and real world practice is the scaling problems that come with trying to apply the technique to a full size codebase. Asymptotics strike as LOC goes up, language focused techniques flounder in polyglot codebases, and "Does anyone know how to write cmake?" But this is predicated on the idea that static analysis has to operate on a whole program. It doesn't; humans can do perfectly good static analysis on fragments of code without having to hold the entire codebase in their head, without needing access to a build system. They make assumptions about APIs and can do local reasoning. LLMs can play a key role in drafting these assumptions so that local reasoning can occur. What if the LLM gets it wrong? Well, if an LLM could get it wrong, an inattentive junior developer might get it wrong too--maybe there is a problem in the API design. LLMs already do surprisingly well if you one-shot prompt them to find bugs in code; with more traditional static analysis support, maybe they can do even better.

DSL purgatory. Consider a problem that can be solved with code in a procedural way, but only by writing lots of tedious, error prone boilerplate (some examples: drawing diagrams, writing GUIs, SQL queries, building visualizations, scripting website/mobile app interactions, end to end testing). The PL dream is to design a sweet compositional DSL that raises the level of abstraction so that you can render a Hilbert curve in seven lines of code. But history also abounds with cases where the DSL did not solve the problem, or maybe it did but only after years of grueling work, and so there are still many problems that feel like there ought to be a DSL to solve them, but there isn't. The promise of LLMs is that they are extremely good at regurgitating low level procedural actions that could conceivably be put together in a DSL. A lot of the best successes of LLMs today come from putting coding powers in the hands of domain experts who otherwise do not know how to code; could they also help put domain expertise in the hands of people who can code?

I am especially interested in these domains:

  • SQL - Its strange syntax purportedly makes it easier for non-software engineers to understand, whereas many (myself included) would often prefer a more functional syntax ala LINQ/list comprehensions. It's pretty hard to make an alternate SQL syntax take off though, because SQL is not one language, but many many dialects everywhere with no obvious leverage point. That sounds like an LLM opportunity. Or heck, just give me one of those AI editor environments but specifically fine tuned for SQL/data visualization, don't even bother with general coding.
  • End to end testing - This is https://momentic.ai/ but personally I'm not going to rely on a proprietary product for testing in my OSS projects. There's definitely an OSS opportunity here.
  • Scripting website/mobile app interactions - The website scraping version of this is https://reworkd.ai/ but I am also pretty interested in this from the browser extension angle: to some extent I can take back control of my frontend experience with browser extensions; can I go further with LLMs? And we typically don't imagine that I can do the same with a mobile app... but maybe I can??

OSS bread and butter. Why is Tesseract still the number one OSS library for OCR? Why is smooth and beautiful text to voice not ubiquitous? Why is the voice control on my Tesla so bad? Why is the wake word on my Android device so unreliable? Why isn't the screenshot parser on a fansite for my favorite mobage able to parse out icons? The future has arrived, but it is not uniformly distributed.

Improving the pipeline from ephemeral to durable stores of knowledge. Many important sources of knowledge are trapped in "ephemeral" stores, like Discord servers, private chat conversations, Reddit posts, Twitter threads, blog posts, etc. In an ideal world, there would be a pipeline of this knowledge into more durable, indexable forms for the benefit of all, but actually doing this is time consuming. Can LLMs help? Note that the dream of LLMs is that you can just feed all of this data into the model and ask it questions. I'm OK with something a little bit more manual; we don't have to solve RAG first.

  • October 4, 2024

Interactive scraping with Jupyter and Puppeteer

One of the annoying things about scraping websites is bouncing back and forth between the browser where you are using Dev Tools to work out what selectors you should be using to scrape out data, and your actual scraping script, which is usually some batch program that may have to take a few steps before the step you are debugging. A batch script is fine once your scraper is up and running, but while developing, it's really handy to pause the scraping process at some page and fiddle around with the DOM to see what to do.

This interactive style of development is exactly what Jupyter notebooks shine at; when used in conjunction with a browser-based scraping library like Puppeteer, you can have exactly this workflow. Here's the setup:

  1. Puppeteer is a JavaScript library, so you'll need a JavaScript kernel for Jupyter to run it. As an extra complication, Puppeteer is also async, so you'll need a kernel that supports async execution. Fortunately, ijavascript-await provides exactly this. Note that on recent versions of node this package does not compile; you can install this PR which makes this work: https://github.com/n-riesco/ijavascript/pull/257 Hypothetically, we should be able to use stock ijavascript when node supports top level await, but this currently does not work: https://github.com/nodejs/node/issues/40898
  2. Inside the directory where you will store your notebooks, you'll need to npm install puppeteer so that it's available for your notebooks.
  3. Launch Puppeteer with let puppeteer = require('puppeteer'); let browser = await puppeteer.launch({headless: false}); and profit!

There will be a live browser instance which you can poke at using Dev Tools, and you can type commands into the Jupyter notebook and see how they affect the browser state.

I tweeted about this and the commenters had some good suggestions about other things you could try:

  • You don't have to use Puppeteer; Selenium can also drive the browser, and it has a Python API to boot (so no faffing about with alternate Jupyter kernels necessary; a minimal sketch follows this list). I personally prefer working in JavaScript for crawlers, since the page scripting itself is also in JavaScript, but this is mostly a personal preference thing.
  • For simple interactions, where all you really want is to just do a few interactions and record them, Headless Recorder provides a nice extension for just directly recording operations in your browser and then getting them out in executable form. I haven't tried it out yet but it seems like it would be very easy to use.
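
To make the Selenium option above concrete, here is a minimal sketch using its Python API (so the stock IPython kernel works); it assumes a recent Selenium 4 install that can locate a Chrome driver:

from selenium import webdriver
from selenium.webdriver.common.by import By

# A visible browser window you can inspect with Dev Tools while you experiment.
driver = webdriver.Chrome()
driver.get("https://example.com")

# Iterate interactively in the notebook: tweak the selector, re-run the cell.
links = driver.find_elements(By.CSS_SELECTOR, "a")
print([link.text for link in links])
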
  • November 23, 2021

PyTorch Developer Podcast

I'm launching a new podcast, the PyTorch Developer Podcast. The idea is to be a place for the PyTorch dev team to do bite sized (10-20 min) topics about all sorts of internal development topics in PyTorch. For now, it's just me monologuing for fifteen minutes about whatever topic I decide. The plan is to release an episode daily, five days a week, until I run out of things to say (probably not for a while, I have SO MANY THINGS TO SAY). I don't edit the podcasts and do minimal planning, so they're a bit easier to do than blog posts. Check it out! There's two episodes out already, one about how we do Python bindings for our C++ objects and another about history and constraints of the dispatcher. If there are any topics you'd like me to cover, give a shout.

  • May 5, 2021

Rage bug reporting

At Facebook, we have an internal convention for tooling called "rage". When something goes wrong and you want to report a bug, the tool developer will typically ask you to give them a rage. For a command line tool, this can be done by running a rage subcommand, which will ask about which previous CLI invocation you'd like to report, and then giving you a bundle of logs to send to the developer.

A rage has an important property, compared to a conventional log level flag like -v: rage recording is always on. In other words, it is like traditional server application logs, but applied to client software. Logging is always turned on, and the rage subcommand makes it easy for a user to send only the relevant portion of the logs (e.g., the logs associated with the command line invocation in question).

For some reason, rage functionality is not that common in open source tools. I can imagine any number of reasons why this might be the case:

  • Adding proper logging is like flossing--annoying to do at the time even when it can save you a lot of pain later.
  • Even if you have logging, you still need to add infrastructure to save the logs somewhere and let users retrieve them afterwards.
  • It's something of an art to write logs that are useful enough that a developer can diagnose the problem simply by "reading the tea leaves", but not so detailed that they slow down normal execution of the program. And don't forget, you better not expose private information!
  • Most programs are simple, and you can just fall back on the old standby of asking the user to submit reproduction instructions in their bug report.

Still, in the same way most sysadmins view logging as an invaluable tool for debugging server issues, I think rage reporting is an invaluable tool for debugging client issues. In ghstack, it didn't take very many lines of code to implement rage reporting: ghstack.logs (for writing the logs to the rage directory) and ghstack.rage (for reading it out). But it has greatly reduced my support load for the project; given a rage, I can typically figure out the root cause of a bug without setting up a reproducer first.
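
To give a flavor of how little is needed (this is an illustrative sketch, not ghstack's actual ghstack.logs/ghstack.rage code), here is a hypothetical always-on logging setup plus a rage subcommand for a Python CLI:

import logging, os, sys, time

RAGE_DIR = os.path.expanduser("~/.mytool/rage")  # hypothetical log location

def setup_rage_logging():
    # Always on: every invocation gets its own log file, tagged with its argv.
    os.makedirs(RAGE_DIR, exist_ok=True)
    log_file = os.path.join(RAGE_DIR, time.strftime("%Y%m%d-%H%M%S") + ".log")
    logging.basicConfig(filename=log_file, level=logging.DEBUG)
    logging.info("argv: %s", sys.argv)

def rage():
    # Let the user pick a recent invocation and bundle up its log to send along.
    logs = sorted(os.listdir(RAGE_DIR))[-10:]
    for i, name in enumerate(logs):
        print(i, name)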

  • April 25, 2021

The PyTorch open source process

PyTorch is a fairly large and active open source project, and sometimes we have people come to us and ask if there are any lessons from how we run PyTorch that they could apply to their own projects. This post is an attempt to describe some of the processes as of 2021 that help PyTorch operate effectively as an open source project. I won't claim that everything we do is necessarily the best way to go about doing things, but at the very least, everything I describe here is working in practice.

Background. Not all open source projects are the same, and there are some peculiarities to PyTorch which may reduce the applicability of some of what I describe below in other contexts. Here are some defining features of PyTorch, as a project:

  • The majority of full time PyTorch developers work at Facebook. To be clear, there are many full time PyTorch developers that work at other companies: NVIDIA, Intel, Quansight, Microsoft, AMD, IBM, Preferred Networks, Google and Amazon all employ people whose job it is to work on PyTorch. But the majority of full timers are at Facebook, distinguishing PyTorch from hobbyist open source projects or projects run by a foundation of some sort.
  • PyTorch is a federation. As coined by Nadia Eghbal, PyTorch is a project with high contributor growth and user growth. In my State of PyTorch (2020) talk, I go into more details, but suffice to say, we have over nine companies contributing to PyTorch, and a long tail of other contributors (making up 40% of all of our commits). This makes managing PyTorch sometimes particularly challenging, and many of the processes I will describe below arose from growing pains scaling this level of activity.
  • PyTorch has a lot of surface area. CPU, CUDA, ROCm, ONNX, XLA, serving, distributions, quantization, etc. It's impossible for a single contributor to be well-versed in every area of the project, and so some of the challenge is just making sure the right people see the things they need to see.

Alright, so how does PyTorch deal with its scale? Here are some of the things we do.

Issue triage. PyTorch receives too many bug reports a day for any one person to keep track of all of them. Largely inspired by this apenwarr post, we set up an oncall rotation amongst Facebook contributors to serve as first line triage for all of these issues. The golden rule of issue triage is that you DO NOT fix bugs in triage; the goal of triage is to (1) route bugs to the correct people via appropriate GitHub labels, and (2) look for high priority bugs and raise awareness of these bugs. Every week, we have a meeting to review high priority bugs (and other bugs marked for triage review) and talk about them. The oncall itself rotates daily, to discourage people from letting a week's worth of issues pile up in the backlog, and we use a relatively intricate search query to make sure only relevant issues show up for the oncall to handle.

The most important consequence of issue triage is that you can unwatch the PyTorch repository as a whole. Instead, by watching various labels (using our cc bot), you can trust that you will get CC'ed to issues related to topics, even if the triager doesn't know that you're interested in the issue! The weekly meeting makes sure that all maintainers collectively have an idea about what major issues are currently affecting PyTorch, and helps socialize what we as a project think of as a "high priority" issue. Finally, the high priority label is a good way to find impactful problems to work on in the project, even if you don't know much else about the project.

Pull request triage. Similarly, we receive a decent number of drive-by pull requests from one time contributors. Those people are not in a good position to find reviewers for their contributions, so we also have a triager look through these pull requests and make sure someone is assigned to review them. If the PR is particularly simple, the triager might just go ahead and merge it themselves. There's actually some good automation for doing this (e.g., homu) but we've been too lazy to set any of it up, and by-hand reviewer assignment doesn't seem to be too much burden on top of the existing oncall.

Tree hugging oncall. PyTorch has a huge CI system covering many different system configurations, which most contributors rely on to test if their changes are safe. Sometimes people break master. Separate from the triage oncall, we have a tree hugging oncall whose job is to revert commits if they break master. This oncall involves mostly paying attention to the CI HUD and reverting commits if they result in master breakage in one of the configurations.

Importing to Facebook infrastructure. We actually run Facebook infrastructure directly off of the HEAD branch in PyTorch. The tooling that makes this possible is fbshipit, which mirrors commits between Facebook's internal monorepo and our public GitHub repository. This setup has been something of a double-edged sword for us: requiring Facebook and GitHub to be in sync means that only Facebook employees can actually land pull requests (we try to streamline the process as much as possible for external maintainers, but at the end of the day someone at Facebook has to actually push the green button), but it means we don't have to worry about doing periodic "mega-imports" into Facebook infrastructure (which we have done in the past and were quite difficult to do). We are very interested in fixing this situation and have floated some proposals on changing how we do internal releases to make it possible to let external contributors land PRs directly.

RFCs. Most feature discussion happens on GitHub issues, but sometimes, a feature is too big and complicated to adequately discuss in a GitHub issue. In those cases, they can be discussed in the rfcs repository (inspired by the Rust RFCs process). The formal process on this repository isn't too solidified yet, but generally people go there if they feel that it is too difficult to discuss the issue in GitHub issues. We don't yet have a process for shepherding unsolicited RFCs.

Conclusion. PyTorch's open source process isn't rocket science: there's an oncall, the oncall does some things. The devil is in the details: all of PyTorch's oncall responsibilities are carefully scoped so that your oncall responsibilities aren't something that will take an unbounded amount of time; they're something you can knock out in an hour or two and call it a day. You could make the argument that we rely excessively on oncalls when automation is possible, but what we have found is that oncalls require less infrastructure investment, and integrate well with existing processes and flows at Facebook. They might not be right everywhere, but at least for us they seem to be doing a good job.

  • January 6, 2021

The hidden problem(?) with basic block procedures in SSA

Years ago, Nadav Rotem related to me this story about why basic block procedures in Swift are not as good as they seem. Nelson Elhage reminded me about this on Twitter and so I thought this should be put into the public record.

Basic block procedures make certain optimizations more difficult. Consider this program:

block j3 (%y1, %y2) { ... }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }

Is this program easier or more difficult to optimize than the traditional SSA with phi-nodes formulation?

L1:
   goto L3
L2:
   goto L3
L3:
   %y1 = phi [%x1, %L1] [%x3, %L2]
   %y2 = phi [%x2, %L1] [%x4, %L2]

Suppose that the optimizer determines that y1 is unused inside j3/L3 and can be eliminated. In basic block land, y1 can be eliminated simply by deleting "y1 = phi x1 x3". However, in join point land, you have to not only eliminate y1 but also update all the call sites of j3, since you've changed the function signature. In a mutable AST, changing function signatures is a pain; in particular, the mutations you would have to do to eliminate the argument include intermediate states that are not valid ASTs (making it easy to accidentally trigger asserts.)

When I saw this example, I wondered why GHC (which has the moral equivalent of basic block procedures in the form of join points) didn't have this problem. Well, it turns out this optimization can be done as a series of local transformations. First, we do a worker/wrapper transformation, introducing an intermediate block (the worker) that drops the dead argument:

block j3 (%y1, %y2) { jump wj3(%y2) }
block j1 () { jump j3(%x1, %x2) }
block j2 () { jump j3(%x3, %x4) }
block wj3 (%y2) { ... }

Later, we inline j3, which removes the wrapper. Worker/wrapper is a very important optimization for functional programs, but it's easy to imagine why it is less preferred in mutable compiler land.

  • October 24, 2020

Idiomatic algebraic data types in Python with dataclasses and Union

Greetings from 2024! An official pattern matching PEP has been accepted https://peps.python.org/pep-0636/ and is available in Python 3.10. Class patterns are tested using isinstance, with no inheritance structure necessary, making the pattern described in this post 100% forward compatible with real pattern matching.


One of the features I miss most in non-Haskell programming languages is algebraic data types (ADT). ADTs fulfill a similar role to objects in other languages, but with more restrictions: objects are an open universe, where clients can implement new subclasses that were not known at definition time; ADTs are a closed universe, where the definition of an ADT specifies precisely all the cases that are possible. We often think of restrictions as a bad thing, but in the case of ADTs, the restriction of being a closed universe makes programs easier to understand (a fixed set of cases to understand, as opposed to a potentially infinite set of cases) and allows for new modes of expression (pattern matching). ADTs make it really easy to accurately model your data structures; they encourage you to go for precise types that make illegal states unrepresentable. Still, it is generally not a good idea to try to manually reimplement your favorite Haskell language feature in every other programming language you use, and so for years I've suffered in Python under the impression that ADTs were a no go.

Recently, however, I have noticed that a number of new features in Python 3 have made it possible to use objects in the style of ADTs, in idiomatic Python with virtually no boilerplate. The key features:

  • A structural static type checking system with mypy; in particular, the ability to declare Union types, which let you represent values that could be one of a fixed set of other types, and the ability to refine the type of a variable by performing an isinstance check on it.
  • The dataclasses library, which allows you to conveniently define (possibly immutable) structures of data without having to write boilerplate for the constructor.

The key idea: define each constructor as a dataclass, put the constructors together into an ADT using a Union type, and use isinstance tests to do pattern matching on the result. The result is just as good as an ADT (or better, perhaps; their structural nature bears more similarity to OCaml's polymorphic variants).

Here's how it works. Let's suppose that you want to define an algebraic data type with two results:

data Result
   = OK Int
   | Failure String

showResult :: Result -> String
showResult (OK result) = show result
showResult (Failure msg) = "Failure: " ++ msg

First, we define each constructor as a dataclass:

from dataclasses import dataclass

@dataclass(frozen=True)
class OK:
    result: int

@dataclass(frozen=True)
class Failure:
    msg: str

Using the automatically generated constructors from dataclasses, we can construct values of these dataclasses using OK(2) or Failure("something wrong"). Next, we define a type synonym for the union of these two classes:

from typing import Union

Result = Union[OK, Failure]

Finally, we can do pattern matching on Result by doing isinstance tests:

from typing import NoReturn

def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))

def showResult(r: Result) -> str:
    if isinstance(r, OK):
        return str(r.result)
    elif isinstance(r, Failure):
        return "Failure: " + r.msg
    else:
        assert_never(r)

assert_never is a well known trick for doing exhaustiveness checking in mypy. If we haven't covered all cases with enough isinstance checks, mypy will complain that assert_never was given a type like UnhandledCtor when it expected NoReturn (which is the uninhabited type in Python).

That's all there is to it. As an extra bonus, this style of writing unions is compatible with the structured pattern matching PEP, if it actually gets accepted. I've been using this pattern to good effect in our recent rewrite of PyTorch's code generator. If you have the opportunity to work in a statically typed Python codebase, give this style of code a try!
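
For instance, with Python 3.10's match statement, showResult can be written directly with class patterns (reusing the OK, Failure, Result and assert_never definitions above; dataclasses generate the __match_args__ needed for positional patterns):

def showResult(r: Result) -> str:
    match r:
        case OK(result):
            return str(result)
        case Failure(msg):
            return "Failure: " + msg
        case _:
            assert_never(r)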

  • October 14, 2020

Let’s talk about the PyTorch dispatcher

http://blog.ezyang.com/img/pytorch-dispatcher/slide-01.png

If this is your first time reading about PyTorch internals, you might want to check out my PyTorch internals post first. In this post, I want to talk about one particular part of PyTorch's internals: the dispatcher. At a first glance, the dispatcher is just a glorified if statement: based on some information about the tensor inputs, decide what piece of code should be called. So why should we care about the dispatcher?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-02.png

Well, in PyTorch, a lot of things go into making an operator work. There is the kernel that does the actual work, of course; but then there is support for reverse mode automatic differentiation, e.g., the bits that make loss.backward() work. Oh, and if you run your code under torch.jit.trace, you can get a trace of all the operations that were run. Did I mention that if you run these operations inside a vmap call, the batching behavior for the operators is different? There are so many different ways to interpret PyTorch operators, and if we tried to handle all of them inside a single function named add, our implementation code would quickly devolve into an unmaintainable mess. The dispatcher is not just an if statement: it is a really important abstraction for how we structure our code internally in PyTorch... and it has to do so without degrading the performance of PyTorch (too much, anyway).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-03.png

At the end of this post, our goal will be to understand how all the different parts of this picture fit together. This post will proceed in three parts.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-04.png
http://blog.ezyang.com/img/pytorch-dispatcher/slide-05.png

First, we'll talk about the dispatcher itself. What is the dispatcher, how does it decide what kernel to call? Second, we'll talk about the operator registration API, which is the interface by which we register kernels into the dispatcher. Finally, we'll talk about boxing and unboxing, which are a cross-cutting feature in the dispatcher that let you write code once, and then have it work on all kernels.

What is the dispatcher?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-06.png

OK, so what is the dispatcher? For every operator, the dispatcher maintains a table of function pointers which provide implementations for each dispatch key, which corresponds roughly to one of the cross-cutting concerns in PyTorch. In the diagram above, you can see there are dispatch entries in this table for backends (CPU, CUDA, XLA) as well as higher-level concepts like autograd and tracing. The dispatcher's job is to compute a dispatch key, based on the input tensors and some other stuff (more on this shortly), and then do an indirect jump to the function pointed to by the table.

Those of you who are familiar with C++ may observe that this table of function pointers is very similar to virtual tables in C++. In C++, virtual methods on objects are implemented by associating every object with a pointer to a virtual table that contains implementations for each virtual method on the object in question. In PyTorch, we essentially reimplemented virtual tables, but with some differences:

  • Dispatch tables are allocated per operator, whereas vtables are allocated per class. This means that we can extend the set of supported operators simply by allocating a new dispatch table, in contrast to regular objects where you can extend from a class, but you can't easily add virtual methods. Unlike normal object oriented systems, in PyTorch most of the extensibility lies in defining new operators (rather than new subclasses), so this tradeoff makes sense. Dispatch keys are not openly extensible, and we generally expect extensions that want to allocate themselves a new dispatch key to submit a patch to PyTorch core to add their dispatch key.
  • More on this in the next slide, but the computation of our dispatch key considers all arguments to the operator (multiple dispatch) as well as thread-local state (TLS). This is different from virtual tables, where only the first object (this) matters.
  • Finally, the dispatcher supports boxing and unboxing as part of the calling convention for operators. More on this in the last part of the talk!

Fun historical note: we used to use virtual methods to implement dynamic dispatch, and reimplemented them when we realized we needed more juice than virtual tables could give us.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-07.png

So how exactly do we compute the dispatch key which we use to index into the dispatch table? The basic abstraction we use for computing what dispatch key to use is a dispatch key set, which is a bitset over dispatch keys. The general concept is that we union together dispatch key sets from various sources (and in some cases mask out some dispatch keys), giving us a final dispatch key set. Then, we pick the first dispatch key in the set (dispatch keys are implicitly ordered by some priority) and that is where we should dispatch to. What are these sources?

  • Each tensor input contributes a dispatch key set of all dispatch keys that were on the tensor (intuitively, these dispatch keys will be things like CPU, telling us that the tensor in question is a CPU tensor and should be handled by the CPU handler on the dispatch table)
  • We also have a local include set, which is used for "modal" functionality, such as tracing, which isn't associated with any tensors, but instead is some sort of thread local mode that a user can turn on and off within some scope.
  • Finally, we have a global set, which are dispatch keys that are always considered. (Since the time this slide was written, Autograd has moved off the global set and onto tensor. However, the high level structure of the system hasn't changed).

There is also a local exclude set, which is used to exclude dispatch keys from dispatch. A common pattern is for some handler to handle a dispatch key, and then mask itself off via the local exclude set, so we don't try reprocessing this dispatch key later.
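
Here is a toy Python model of this computation (the key names and priorities are illustrative, not PyTorch's real ones):

# Dispatch keys in priority order: earlier entries win.
KEY_PRIORITY = ["Autograd", "Tracer", "BackendSelect", "CPU", "CUDA"]

def compute_dispatch_key(tensor_keysets, local_include, local_exclude, global_keys):
    # Union the keysets of all tensor inputs with the local include set and
    # the global set, then mask out the local exclude set.
    candidates = set().union(*tensor_keysets) | local_include | global_keys
    candidates -= local_exclude
    # Pick the highest priority dispatch key that survives.
    return next(key for key in KEY_PRIORITY if key in candidates)

# A CPU tensor with autograd metadata, no modal state: Autograd wins over CPU.
print(compute_dispatch_key([{"CPU", "Autograd"}], set(), set(), set()))  # Autograd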

Let's walk through the evolution of dispatch key through some examples.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-08.png

(Warning: This description is out-of-date for PyTorch master. Instead of Autograd being in global, it is instead on the Tensor. Everything else proceeds as before.)

The most canonical example of the dispatch machinery in operation is how it handles autograd. Read the diagram from the top to the bottom. At the very top, Autograd is in the global set, and the local exclude set is empty. When we do dispatch, we find autograd is the highest priority key (it's higher priority than CPU), and we dispatch to the autograd handler for the operator. Inside the autograd handler, we do some autograd stuff, but more importantly, we create the RAII guard AutoNonVariableTypeMode, which adds Autograd to the local exclude set, preventing autograd from being handled for all of the operations inside of this operator. When we redispatch, we now skip the autograd key (as it is excluded) and dispatch to the next dispatch key, CPU in this example. As local TLS is maintained for the rest of the call tree, all other subsequent dispatches also bypass autograd. Finally, in the end, we return from our function, and the RAII guard removes Autograd from the local exclude set so subsequent operator calls once again trigger autograd handlers.
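
The guard-then-redispatch pattern can be modeled in a few lines of Python, with a context manager standing in for the RAII guard (again a toy model, not PyTorch's code; a real guard restores the previous TLS rather than just discarding the key):

from contextlib import contextmanager

KEY_PRIORITY = ["Autograd", "CPU"]
local_exclude = set()  # stands in for thread-local state (TLS)

@contextmanager
def exclude_key(key):
    # Like AutoNonVariableTypeMode: mask out a key for the duration of the handler.
    local_exclude.add(key)
    try:
        yield
    finally:
        local_exclude.discard(key)

def dispatch(table, keys, *args):
    key = next(k for k in KEY_PRIORITY if k in keys and k not in local_exclude)
    return table[key](table, keys, *args)

def add_autograd(table, keys, x, y):
    # "Do some autograd stuff", then redispatch with Autograd excluded.
    with exclude_key("Autograd"):
        return dispatch(table, keys, x, y)

def add_cpu(table, keys, x, y):
    return x + y

add_table = {"Autograd": add_autograd, "CPU": add_cpu}
print(dispatch(add_table, {"Autograd", "CPU"}, 1, 2))  # hits Autograd, then CPU: 3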

http://blog.ezyang.com/img/pytorch-dispatcher/slide-09.png

Another similar example is tracing, which is similar to autograd where when we enter the tracing handler, we disable tracing for nested calls with ExcludeDispatchKeyGuard. However, it differs from autograd in how tracing is initially triggered: tracing is toggled by a dispatch key that is added to the local include set when you turn on tracing (with IncludeDispatchKeyGuard), as opposed to the global dispatch key from Autograd (Update: now a dispatch key on tensors).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-10.png

One final example is the BackendSelect key, which operates a little differently from normal keys. The problem backend select solves is that sometimes, the default dispatch key set calculation algorithm doesn't know how to work out what the correct dispatch key should be. One notable case of this is factory functions, which don't have any Tensor arguments (and so, naively, would not dispatch to anything). BackendSelect is in the global dispatch key set, but is only registered for a few operators (for the rest, it is a fallthrough key). The BackendSelect handler inspects the arguments and decides what the final dispatch key should be, and then does a direct dispatch to that key, bypassing dispatch key calculation.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-11.png

The slide summarizes some of the most common sequences of handlers that get processed when dispatching some operation in PyTorch. Most of the time, it's autograd, and then the backend (with a backend select in-between if you are a factory function). For XLA, there is also an XLAPreAutograd key (Update: This key is now simply AutogradXLA) which can be used to override the behavior of the Autograd key. And of course, if you turn on every feature in PyTorch all at once, you can end up stopping at a lot of handlers. Notice that the order in which these handlers are processed matters, since handlers aren't necessarily commutative.

Operator registration

So we talked a lot about how we decide what function pointers in the dispatch table to call, but how do these pointers get in the dispatch table in the first place? This is via the operator registration API. If you have never seen this API before, you should take a look at the Dispatcher in C++ tutorial, which describes how the API works at a very high level. In this section, we'll dive into more detail about how exactly the registration API maps to the dispatch table. Below, you can see the three main ways of interacting with the operator registration API: you define schemas for operators and then register implementations at dispatch keys; finally, there is a fallback method which you can use to define a handler for all operators at some dispatch key.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-12.png

To visualize the impact of these registration operators, let us imagine that the dispatch tables for all operators collectively form a grid, like this:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-13.png

On one axis, we have each operator supported in PyTorch. On the other axis, we have each dispatch key we support in our system. The act of operator registration involves filling in cells with implementations under these two axes.

When we register a kernel for a single operator at a specific dispatch key, we fill in a single cell (blue below):

http://blog.ezyang.com/img/pytorch-dispatcher/slide-14.png

When you register a kernel as a "catch-all" kernel for all dispatch keys in an operator, you fill in an entire row for the operator with one kernel (red below). By the way, if this seems like a strange thing to want to do, it is! And we're working to remove this capability in favor of more specific fills for a subset of keys.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-15.png

When you register a kernel as a fallback kernel for a single dispatch key, you fill in the column for that dispatch key (green).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-16.png

There's a precedence to these registrations: exact kernel registrations have the highest precedence, and catch-all kernels take precedence over fallbacks.
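
A small sketch of that precedence, modeling the grid with plain dictionaries (the names are made up for illustration):

# (operator, dispatch key) -> kernel, plus per-operator catch-alls and
# per-dispatch-key fallbacks.
specific = {("add", "CPU"): "add_cpu_kernel"}
catch_all = {"sub": "sub_catch_all_kernel"}
fallback = {"XLA": "xla_boxed_fallback"}

def lookup(op, key):
    # Exact registrations win, then the operator's catch-all, then the key's fallback.
    if (op, key) in specific:
        return specific[(op, key)]
    if op in catch_all:
        return catch_all[op]
    return fallback.get(key, "error: no kernel registered")

print(lookup("add", "CPU"))  # add_cpu_kernel
print(lookup("sub", "CPU"))  # sub_catch_all_kernel
print(lookup("add", "XLA"))  # xla_boxed_fallback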

http://blog.ezyang.com/img/pytorch-dispatcher/slide-17.png

Boxing and unboxing

I want to spend the last part of this post talking about the boxing and unboxing facilities in our dispatcher, which turn out to be pretty important for enabling backend fallback. When you are a programming language designer, there is a classic tradeoff you have to make in deciding whether or not you want to use a boxed or unboxed representation for data:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-18.png

A boxed or homogenous representation is a data representation where every type of object in your system has the same layout. Typically, this means you have some representation that has a header describing what the object in question is, and then some regular payload after it. Homogenous representations are easy to work with in code: because you can always assume that data has some regular layout, you can write functions that work polymorphically over any type of data (think of a function in Java that takes in an arbitrary Object, for example). Most garbage-collected languages have some boxed representation for heap objects, because the garbage collector needs to be able to work over any type of heap object.

In contrast, an unboxed or heterogenous representation allows objects to have a different layout depending on the data in question. This is more efficient than a homogenous representation, as each object can tailor its internal representation to exactly what is needed for the task at hand. However, the downside is we can no longer easily write a single function that works polymorphically over many types of objects. In C++, this problem is worked around using templates: if you need a function to work on multiple types, the C++ compiler will literally create a new copy of the function specialized to each type it is used with.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-19.png

By default, C++ uses a heterogenous layout, but we have implemented a homogenous layout in PyTorch by way of the IValue struct (short for interpreter value), which implements a boxed representation that we can use in our interpreter. An IValue is a two word structure consisting of a payload word (usually a pointer, but it could also be an integer or float directly packed into the field) and a tag word which tells us what kind of value the IValue is.

This means we have two calling conventions for functions in PyTorch: the usual, C++, unboxed convention, and a boxed convention using IValues on a stack. Calls (from end users) can come from unboxed API (direct C++ call) or boxed API (from the JIT interpreter); similarly, kernels can be implemented as direct C++ functions (unboxed convention), or can be implemented as a boxed fallback (which by necessity is boxed, as they are polymorphic over all operators).

If I call from boxed API to a boxed fallback, it's easy to see how to plug the two components together...

http://blog.ezyang.com/img/pytorch-dispatcher/slide-22.png

...but how do I get from the unboxed API to the boxed fallback?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-23.png

We need some sort of adapter to take the unboxed inputs and turn them into IValues so that they can be passed via the boxed calling convention. This is done via a boxing adapter, which is automatically generated using C++ templates working off of the unboxed C++ types in the outward facing API.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-24.png

There is also an inverse problem, which is what to do if we have inputs from a boxed API and need to call into an unboxed kernel. Similarly, we have an unboxing adapter, which performs this translation. Unlike the boxing adapter, this adapter is applied to the kernel itself, since C++ templates only work at sites where the unboxed type is statically available (at the boxed API site, these types are not known, so you literally cannot implement this.) Note that we always keep the unboxed API around, so that if a user calls in from the unboxed API, we can fastpath straight to the unboxed kernel.
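
The two adapters can be sketched in Python with plain wrappers (PyTorch generates the real ones with C++ templates over the unboxed types; this is only a model):

class IValue:
    # A trivial stand-in for IValue: a tagged box around a payload.
    def __init__(self, tag, payload):
        self.tag, self.payload = tag, payload

def box_args(*args):
    # Boxing adapter: turn unboxed arguments into a stack of IValues so they
    # can be handed to a boxed fallback.
    return [IValue(type(a).__name__, a) for a in args]

def unboxing_adapter(unboxed_kernel):
    # Unboxing adapter: wrap an unboxed kernel so it accepts the boxed calling
    # convention (a stack of IValues).
    def boxed_kernel(stack):
        return unboxed_kernel(*(iv.payload for iv in stack))
    return boxed_kernel

def add_unboxed(x, y):
    return x + y

boxed_add = unboxing_adapter(add_unboxed)
print(boxed_add(box_args(1, 2)))  # 3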

http://blog.ezyang.com/img/pytorch-dispatcher/slide-25.png

So here is what boxing and unboxing look like overall:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-26.png

Boxing and unboxing are a key feature in the implementation of boxed fallback: without them, we could not let people write single kernels which would work everywhere (and indeed, in the past, people would write code generators to generate repetitive kernels for every function). With template-based boxing and unboxing, you can write a single boxed kernel and have it work for all operators, even those defined externally to the library.

Conclusion

http://blog.ezyang.com/img/pytorch-dispatcher/slide-27.png

So that's PyTorch's dispatcher in a nutshell! The dispatcher is still being continuously worked on; for example, Ailing Zhang recently landed a rework of how autograd dispatch keys are handled, which means that we actually no longer have a single Autograd key but have split autograd keys for AutogradCPU/AutogradCUDA/... We're generally interested in improving the user experience for people who register kernels to the dispatcher. Let us know if you have any questions or comments!

  • September 10, 2020

Dynamic scoping is an effect, implicit parameters are a coeffect

For the longest time, I thought implicit parameters and dynamic scoping were basically the same thing, since they can both be used to solve similar problems (e.g., the so called "configuration problem", where you need to plumb some configuration deep into a nested body of function definitions without threading it through all of them explicitly). But implicit parameters have a reputation of being something you shouldn't use (use reflection instead), whereas dynamic scoping via the reader monad is a useful and well understood construct (except for the bit where you have to monadify everything). Why the difference?

Oleg points out that implicit parameters are not really dynamic scoping, and gives an example where Lisp and Haskell disagree. And you don't even want the Lisp behavior in Haskell: if you think about the operational notion of dynamic scoping (walk up the stack until you find a binding site of the dynamic variable), it's not very compatible with laziness, since a thunk (which accesses a dynamic variable) will be forced at some unpredictable point in program execution. You really don't want to have to reason about where exactly a thunk will be executed to know how its dynamic variables will be bound, that way lies madness. But somehow, in a strict language, no one has trouble figuring out what should happen with dynamic scoping (well, mostly--more on this shortly).

It turns out that the research community has figured out the difference is that implicit parameters are a coeffect. I believe this was first observed in Coeffects: Unified static analysis of context-dependence (a more modern presentation is in Coeffects: A calculus of context-dependent computation; and a more Haskelly presentation can be found in Embedding effect systems in Haskell). Although, Tomas was commenting on my blog in 2012 about similar ideas, so this probably had been in the works for a while. The key point is that for some coeffects (namely, implicit parameters), call-by-name reduction preserves types and coeffects, and so implicit parameters do not blow up in your face in the same way dynamic scoping (an effect) would. These necessarily behave differently! Type classes are coeffects too, and this is why modern use of implicit parameters in Haskell explicitly acknowledges this (e.g., in the reflection package).

At this year's ICFP, I was pointed at an interesting technical report about implicit values and functions in Koka, a new twist on dynamic scoping. I found myself wondering if Haskell implicit parameters could learn a thing or two from this work. Implicit values make the good choice of defining implicit values globally at the top level, so that they can participate in normal module namespacing, as opposed to an un-namespaced bag of dynamically scoped names (this is also an improvement that reflection makes over implicit parameters). But actually, it seems to me that implicit functions are taking a page from implicit parameters!

The big innovation of implicit functions is that they resolve all dynamic references in the function (not just lexically, but for all further dynamic calls) to the lexical scope (the dynamic scope at the time the function was defined), producing a function that has no dependence on implicit values (aka, no effect saying that the implicit value must be defined at the time the function is called.) This is exactly what an implicit parameter let ?x = ... binding would have done, in effect directly filling in the dictionary for the implicit function at the definition site, rather than waiting. Very contextual! (Of course, Koka implements this using algebraic effects, and gets to the right semantics with a very simple translation anyway.) The result is not exactly dynamic scoping, but as the TR says, it leads to better abstraction.

It is difficult to see how implicit values/functions could make their way back into Haskell, at least without some sequencing construct (e.g., a monad) lurking around. Though implicit functions behave much like implicit parameters, the rest of the dynamic scoping (including the binding of the implicit function itself) is just good old effectful (not coeffectful) dynamic scope. And you can't just do that in Haskell, without breaking type preservation under beta-reduction and eta-expansion. Haskell has no choice but to go all the way, and once you get beyond the obvious problems of implicit parameters (which reflection fixes), things seem to mostly work out.

  • August 27, 2020