ezyang’s blog

the arc of software bends towards understanding

Idiomatic algebraic data types in Python with dataclasses and Union

Greetings from 2024! An official pattern matching PEP has been accepted (https://peps.python.org/pep-0636/) and is available in Python 3.10. Class patterns are tested using isinstance, with no inheritance structure necessary, making the pattern described in this post 100% forward compatible with real pattern matching.


One of the features I miss most in non-Haskell programming languages is algebraic data types (ADTs). ADTs fulfill a similar role to objects in other languages, but with more restrictions: objects are an open universe, where clients can implement new subclasses that were not known at definition time; ADTs are a closed universe, where the definition of an ADT specifies precisely all the cases that are possible. We often think of restrictions as a bad thing, but in the case of ADTs, the restriction of being a closed universe makes programs easier to understand (a fixed set of cases to understand, as opposed to a potentially infinite set of cases) and allows for new modes of expression (pattern matching). ADTs make it really easy to accurately model your data structures; they encourage you to go for precise types that make illegal states unrepresentable. Still, it is generally not a good idea to try to manually reimplement your favorite Haskell language feature in every other programming language you use, and so for years I've suffered in Python under the impression that ADTs were a no-go.

Recently, however, I have noticed that a number of new features in Python 3 have made it possible to use objects in the style of ADTs, in idiomatic Python with virtually no boilerplate. The key features:

  • A structural static type checking system with mypy; in particular, the ability to declare Union types, which let you represent values that could be one of a fixed set of other types, and the ability to refine the type of a variable by performing an isinstance check on it.
  • The dataclasses library, which allows you to conveniently define (possibly immutable) structures of data without having to write boilerplate for the constructor.

The key idea: define each constructor as a dataclass, put the constructors together into an ADT using a Union type, and use isinstance tests to do pattern matching on the result. The result is just as good as an ADT (or better, perhaps: the structural nature of these unions bears more similarity to OCaml's polymorphic variants).

Here's how it works. Let's suppose that you want to define an algebraic data type with two results:

data Result
   = OK Int
   | Failure String

showResult :: Result -> String
showResult (OK result) = show result
showResult (Failure msg) = "Failure: " ++ msg

First, we define each constructor as a dataclass:

from dataclasses import dataclass

@dataclass(frozen=True)
class OK:
    result: int

@dataclass(frozen=True)
class Failure:
    msg: str

Using the automatically generated constructors from dataclasses, we can construct values of these dataclasses using OK(2) or Failure("something wrong"). Next, we define a type synonym for the union of these two classes:

from typing import Union

Result = Union[OK, Failure]

Finally, we can do pattern matching on Result by doing isinstance tests:

from typing import NoReturn

def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))

def showResult(r: Result) -> str:
    if isinstance(r, OK):
        return str(r.result)
    elif isinstance(r, Failure):
        return "Failure: " + r.msg
    else:
        assert_never(r)

assert_never is a well-known trick for doing exhaustiveness checking in mypy. If we haven't covered all cases with enough isinstance checks, mypy will complain that assert_never was given a type like UnhandledCtor when it expected NoReturn (which is the uninhabited type in Python).

That's all there is to it. As an extra bonus, this style of writing unions is compatible with the structured pattern matching PEP, if it actually gets accepted. I've been using this pattern to good effect in our recent rewrite of PyTorch's code generator. If you have the opportunity to work in a statically typed Python codebase, give this style of code a try!
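As the 2024 note at the top points out, class patterns in Python 3.10's match statement are tested with isinstance, so the same function ports over directly. A minimal sketch, assuming Python 3.10+ (where dataclasses also generate __match_args__, so positional patterns work):

def showResult(r: Result) -> str:
    match r:
        case OK(result):
            return str(result)
        case Failure(msg):
            return "Failure: " + msg
        case _:
            assert_never(r)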

  • October 14, 2020

Let’s talk about the PyTorch dispatcher

http://blog.ezyang.com/img/pytorch-dispatcher/slide-01.png

If this is your first time reading about PyTorch internals, you might want to check out my PyTorch internals post first. In this post, I want to talk about one particular part of PyTorch's internals: the dispatcher. At a first glance, the dispatcher is just a glorified if statement: based on some information about the tensor inputs, decide what piece of code should be called. So why should we care about the dispatcher?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-02.png

Well, in PyTorch, a lot of things go into making an operator work. There is the kernel that does the actual work, of course; but then there is support for reverse mode automatic differentiation, e.g., the bits that make loss.backward() work. Oh, and if you run your code under torch.jit.trace, you can get a trace of all the operations that were run. Did I mention that if you run these operations on the inside of a vmap call, the batching behavior for the operators is different? There are so many different ways to interpret PyTorch operators differently, and if we tried to handle all of them inside a single function named add, our implementation code would quickly devolve into an unmaintainable mess. The dispatcher is not just an if statement: it is a really important abstraction for how we structure our code internally in PyTorch... and it has to do so without degrading the performance of PyTorch (too much, anyway).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-03.png

At the end of this post, our goal will be to understand how all the different parts of this picture fit together. This post will proceed in three parts.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-04.png
http://blog.ezyang.com/img/pytorch-dispatcher/slide-05.png

First, we'll talk about the dispatcher itself. What is the dispatcher, and how does it decide what kernel to call? Second, we'll talk about the operator registration API, which is the interface by which we register kernels into the dispatcher. Finally, we'll talk about boxing and unboxing, which is a cross-cutting feature of the dispatcher that lets you write code once and have it work for all operators.

What is the dispatcher?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-06.png

OK, so what is the dispatcher? For every operator, the dispatcher maintains a table of function pointers which provide implementations for each dispatch key, which corresponds roughly to one of the cross-cutting concerns in PyTorch. In the diagram above, you can see there are dispatch entries in this table for backends (CPU, CUDA, XLA) as well as higher-level concepts like autograd and tracing. The dispatcher's job is to compute a dispatch key, based on the input tensors and some other stuff (more on this shortly), and then do an indirect jump to the function pointed to by the table.
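To make the shape of this structure concrete, here is a tiny Python model of per-operator dispatch tables. Everything here is illustrative (the key names, the priority order, and the kernels are made up for the example); the real thing is a C++ affair built on bitsets and function pointers:

# Dispatch keys, listed from highest to lowest priority (order is illustrative only).
DISPATCH_KEY_PRIORITY = ["Autograd", "CUDA", "CPU"]

# Each operator owns its own table from dispatch key to kernel.
dispatch_tables = {
    "add": {
        "Autograd": lambda *args: "record autograd graph, then redispatch",
        "CUDA":     lambda *args: "run the CUDA add kernel",
        "CPU":      lambda *args: "run the CPU add kernel",
    },
}

def call(op, dispatch_key_set, *args):
    # Compute the highest-priority key present, then make an indirect jump
    # to the function stored in the operator's table.
    key = next(k for k in DISPATCH_KEY_PRIORITY if k in dispatch_key_set)
    return dispatch_tables[op][key](*args)

call("add", {"CPU", "Autograd"})   # -> "record autograd graph, then redispatch"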

Those of you who are familiar with C++ may observe that this table of function pointers is very similar to virtual tables in C++. In C++, virtual methods on objects are implemented by associating every object with a pointer to a virtual table that contains implementations for each virtual method on the object in question. In PyTorch, we essentially reimplemented virtual tables, but with some differences:

  • Dispatch tables are allocated per operator, whereas vtables are allocated per class. This means that we can extend the set of supported operators simply by allocating a new dispatch table, in contrast to regular objects where you can extend from a class, but you can't easily add virtual methods. Unlike normal object oriented systems, in PyTorch most of the extensibility lies in defining new operators (rather than new subclasses), so this tradeoff makes sense. Dispatch keys are not openly extensible, and we generally expect extensions that want to allocate themselves a new dispatch key to submit a patch to PyTorch core to add their dispatch key.
  • More on this in the next slide, but the computation of our dispatch key considers all arguments to the operator (multiple dispatch) as well as thread-local state (TLS). This is different from virtual tables, where only the first object (this) matters.
  • Finally, the dispatcher supports boxing and unboxing as part of the calling convention for operators. More on this in the last part of the talk!

Fun historical note: we used to use virtual methods to implement dynamic dispatch, and reimplemented them when we realized we needed more juice than virtual tables could give us.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-07.png

So how exactly do we compute the dispatch key which we use to index into the dispatch table? The basic abstraction we use for computing what dispatch key to use is a dispatch key set, which is a bitset over dispatch keys. The general concept is that we union together dispatch key sets from various sources (and in some cases mask out some dispatch keys), giving us a final dispatch key set. Then, we pick the first dispatch key in the set (dispatch keys are implicitly ordered by some priority) and that is where we should dispatch to. What are these sources?

  • Each tensor input contributes a dispatch key set of all dispatch keys that were on the tensor (intuitively, these dispatch keys will be things like CPU, telling us that the tensor in question is a CPU tensor and should be handled by the CPU handler on the dispatch table)
  • We also have a local include set, which is used for "modal" functionality, such as tracing, which isn't associated with any tensors, but instead is some sort of thread local mode that a user can turn on and off within some scope.
  • Finally, we have a global set, which contains dispatch keys that are always considered. (Since the time this slide was written, Autograd has moved off the global set and onto the tensor. However, the high level structure of the system hasn't changed.)

There is also a local exclude set, which is used to exclude dispatch keys from dispatch. A common pattern is for some handler to handle a dispatch key, and then mask itself off via the local exclude set, so we don't try reprocessing this dispatch key later.
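Putting the sources together, the computation can be modeled in a few lines of Python. This is a sketch of the logic only; the real implementation is a C++ bitset, and the key names and priority order below are placeholders:

def compute_dispatch_key(tensor_key_sets, local_include, local_exclude, global_set, priority):
    # Union the per-tensor key sets, the thread-local include set, and the global set...
    candidate = set().union(*tensor_key_sets) | local_include | global_set
    # ...then mask out anything in the thread-local exclude set.
    candidate -= local_exclude
    # The highest-priority key that survives is where we dispatch to.
    return next(k for k in priority if k in candidate)

compute_dispatch_key(
    tensor_key_sets=[{"CPU"}],
    local_include=set(),
    local_exclude={"Autograd"},   # e.g., set by the autograd handler before it redispatches
    global_set={"Autograd"},
    priority=["Autograd", "CPU"],
)   # -> "CPU": Autograd was masked out, so we fall through to the backend kernel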

Let's walk through the evolution of dispatch key through some examples.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-08.png

(Warning: This description is out-of-date for PyTorch master. Instead of Autograd being in global, it is instead on the Tensor. Everything else proceeds as before.)

The most canonical example of the dispatch machinery in operation is how it handles autograd. Read the diagram from the top to the bottom. At the very top, Autograd is in the global set, and the local exclude set is empty. When we do dispatch, we find autograd is the highest priority key (it's higher priority than CPU), and we dispatch to the autograd handler for the operator. Inside the autograd handler, we do some autograd stuff, but more importantly, we create the RAII guard AutoNonVariableTypeMode, which adds Autograd to the local exclude set, preventing autograd from being handled for all of the operations inside of this operator. When we redispatch, we now skip the autograd key (as it is excluded) and dispatch to the next dispatch key, CPU in this example. As local TLS is maintained for the rest of the call tree, all other subsequent dispatches also bypass autograd. Finally, in the end, we return from our function, and the RAII guard removes Autograd from the local exclude set so subsequent operator calls once again trigger autograd handlers.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-09.png

Another similar example is tracing: as with autograd, when we enter the tracing handler, we disable tracing for nested calls with ExcludeDispatchKeyGuard. However, it differs from autograd in how tracing is initially triggered: tracing is toggled by a dispatch key that is added to the local include set when you turn on tracing (with IncludeDispatchKeyGuard), as opposed to the global dispatch key from Autograd (Update: now a dispatch key on tensors).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-10.png

One final example is the BackendSelect key, which operates a little differently from normal keys. The problem backend select solves is that sometimes, the default dispatch key set calculation algorithm doesn't know how to work out what the correct dispatch key should be. One notable case of this is factory functions, which don't have any Tensor arguments (and so, naively, would not dispatch to anything). BackendSelect is in the global dispatch key set, but is only registered for a few operators (for the rest, it is a fallthrough key). The BackendSelect handler inspects the arguments and decides what the final dispatch key should be, and then does a direct dispatch to that key, bypassing dispatch key calculation.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-11.png

This slide summarizes some of the most common sequences of handlers that get processed when dispatching some operation in PyTorch. Most of the time, it's autograd, and then the backend (with a backend select in between if you are a factory function). For XLA, there is also an XLAPreAutograd key (Update: This key is now simply AutogradXLA) which can be used to override the behavior of the Autograd key. And of course, if you turn on every feature in PyTorch all at once, you can end up stopping at a lot of handlers. Notice that the order in which these handlers are processed matters, since handlers aren't necessarily commutative.

Operator registration

So we talked a lot about how we decide what function pointers in the dispatch table to call, but how do these pointers get in the dispatch table in the first place? This is via the operator registration API. If you have never seen this API before, you should take a look at the Dispatcher in C++ tutorial, which describes how the API works at a very high level. In this section, we'll dive into more detail about how exactly the registration API maps to the dispatch table. Below, you can see the three main ways of interacting with the operator registration API: you define schemas for operators and then register implementations at dispatch keys; finally, there is a fallback method which you can use to define a handler for all operators at some dispatch key.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-12.png

To visualize the impact of these registration calls, let us imagine that the dispatch tables for all operators collectively form a grid, like this:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-13.png

On one axis, we have each operator supported in PyTorch. On the other axis, we have each dispatch key we support in our system. The act of operator registration involves filling in cells with implementations under these two axes.

When we register a kernel for a single operator at a specific dispatch key, we fill in a single cell (blue below):

http://blog.ezyang.com/img/pytorch-dispatcher/slide-14.png

When you register a kernel as a "catch-all" kernel for all dispatch keys of an operator, you fill in an entire row for the operator with one kernel (red below). By the way, if this seems like a strange thing to want to do, it is! And we're working to remove this capability in favor of more specific fills for a subset of keys.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-15.png

When you register a kernel as a fallback kernel for a single dispatch key, you fill in the column for that dispatch key (green).

http://blog.ezyang.com/img/pytorch-dispatcher/slide-16.png

There's a precedence to these registrations: exact kernel registrations have the highest precedence, and catch-all kernels take precedence over fallbacks.
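Continuing the toy Python model from earlier, the three kinds of registration and their precedence could be sketched like this (the function names here are invented for illustration; the real interface is the C++ operator registration API described in the tutorial linked above):

exact = {}      # (op, dispatch_key) -> kernel : fills in a single cell
catch_all = {}  # op -> kernel                 : fills in a whole row
fallback = {}   # dispatch_key -> kernel       : fills in a whole column

def register_impl(op, key, kernel):
    exact[(op, key)] = kernel

def register_catch_all(op, kernel):
    catch_all[op] = kernel

def register_fallback(key, kernel):
    fallback[key] = kernel

def lookup_kernel(op, key):
    # Precedence: exact registration beats catch-all, which beats the per-key fallback.
    if (op, key) in exact:
        return exact[(op, key)]
    if op in catch_all:
        return catch_all[op]
    return fallback[key]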

http://blog.ezyang.com/img/pytorch-dispatcher/slide-17.png

Boxing and unboxing

I want to spend the last part of this post talking about the boxing and unboxing facilities in our dispatcher, which turn out to be pretty important for enabling backend fallback. When you are a programming language designer, there is a classic tradeoff you have to make in deciding whether or not you want to use a boxed or unboxed representation for data:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-18.png

A boxed or homogeneous representation is a data representation where every type of object in your system has the same layout. Typically, this means you have some representation that has a header describing what the object in question is, and then some regular payload after it. Homogeneous representations are easy to work with in code: because you can always assume that data has some regular layout, you can write functions that work polymorphically over any type of data (think of a function in Java that takes in an arbitrary Object, for example). Most garbage-collected languages have some boxed representation for heap objects, because the garbage collector needs to be able to work over any type of heap object.

In contrast, an unboxed or heterogeneous representation allows objects to have a different layout depending on the data in question. This is more efficient than a homogeneous representation, as each object can tailor its internal representation to exactly what is needed for the task at hand. However, the downside is we can no longer easily write a single function that works polymorphically over many types of objects. In C++, this problem is worked around using templates: if you need a function to work on multiple types, the C++ compiler will literally create a new copy of the function specialized to each type it is used with.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-19.png

C++ defaults to a heterogeneous layout, but we have implemented a homogeneous layout in PyTorch by way of the IValue struct (short for interpreter value), which implements a boxed representation that we can use in our interpreter. An IValue is a two word structure consisting of a payload word (usually a pointer, but it could also be an integer or float directly packed into the field) and a tag word which tells us what kind of value the IValue is.
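A rough Python analogue of that two-word layout, purely for intuition (the real IValue is a C++ struct with a much richer set of tags and careful packing):

from dataclasses import dataclass
from typing import Any

@dataclass
class IValue:
    tag: str      # says how to interpret the payload, e.g. "Tensor", "Int", "Double"
    payload: Any  # a pointer-like reference, or a number packed directly into the slot

# Every value, whatever its type, now has the same two-field shape:
boxed_int = IValue("Int", 2)
boxed_double = IValue("Double", 3.14159)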

This means we have two calling conventions for functions in PyTorch: the usual, C++, unboxed convention, and a boxed convention using IValues on a stack. Calls (from end users) can come from the unboxed API (a direct C++ call) or the boxed API (from the JIT interpreter); similarly, kernels can be implemented as direct C++ functions (unboxed convention), or as boxed fallbacks (which by necessity are boxed, as they are polymorphic over all operators).

If I call from boxed API to a boxed fallback, it's easy to see how to plug the two components together...

http://blog.ezyang.com/img/pytorch-dispatcher/slide-22.png

...but how do I get from the unboxed API to the boxed fallback?

http://blog.ezyang.com/img/pytorch-dispatcher/slide-23.png

We need some sort of adapter to take the unboxed inputs and turn them into IValues so that they can be passed via the boxed calling convention. This is done via a boxing adapter, which is automatically generated using C++ templates working off of the unboxed C++ types in the outward facing API.

http://blog.ezyang.com/img/pytorch-dispatcher/slide-24.png

There is also an inverse problem, which is what to do if we have inputs from a boxed API and need to call into an unboxed kernel. Similarly, we have an unboxing adapter, which performs this translation. Unlike the boxing adapter, this adapter is applied to the kernel itself, since C++ templates only work at sites where the unboxed type is statically available (at the boxed API site, these types are not known, so you literally cannot implement this). Note that we always keep the unboxed API around, so that if a user calls in from the unboxed API, we can fastpath straight to the unboxed kernel.
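In Python pseudocode, the two adapters glue the conventions together roughly like this (a sketch reusing the toy IValue above; the real adapters are generated by C++ templates from the unboxed signatures):

def box_arguments(*unboxed_args):
    # Boxing adapter: wrap statically-typed arguments into IValues so that
    # a boxed fallback (which only understands the stack convention) can run.
    return [IValue(type(a).__name__, a) for a in unboxed_args]

def make_unboxing_adapter(unboxed_kernel):
    # Unboxing adapter: wrapped around a specific kernel, where the static
    # types are known, so the payloads can be pulled back out of the stack.
    def boxed_kernel(stack):
        return unboxed_kernel(*(iv.payload for iv in stack))
    return boxed_kernel

boxed_add = make_unboxing_adapter(lambda x, y: x + y)
boxed_add(box_arguments(1, 2))   # -> 3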

http://blog.ezyang.com/img/pytorch-dispatcher/slide-25.png

So here is what boxing and unboxing look like overall:

http://blog.ezyang.com/img/pytorch-dispatcher/slide-26.png

Boxing and unboxing are a key feature in the implementation of boxed fallback: without them, we could not let people write single kernels which would work everywhere (and indeed, in the past, people would write code generators to generate repetitive kernels for every function). With template-based boxing and unboxing, you can write a single boxed kernel, and then have it work for all operators, even if those operators are defined externally from the library.

Conclusion

http://blog.ezyang.com/img/pytorch-dispatcher/slide-27.png

So that's PyTorch's dispatcher in a nutshell! The dispatcher is still being continuously worked on; for example, Ailing Zhang recently landed a rework of how autograd dispatch keys are handled, which means that we actually no longer have a single Autograd key but have split autograd keys for AutogradCPU/AutogradCUDA/... We're generally interested in improving the user experience for people who register kernels to the dispatcher. Let us know if you have any questions or comments!

  • September 10, 2020

Dynamic scoping is an effect, implicit parameters are a coeffect

For the longest time, I thought of implicit parameters and dynamic scoping as basically the same thing, since they both can be used to solve similar problems (e.g., the so called "configuration problem", where you need to plumb some configuration deep into a nested body of function definitions without passing it explicitly through all of them). But implicit parameters have a reputation of being something you shouldn't use (use reflection instead), whereas dynamic scoping via the reader monad is a useful and well understood construct (except for the bit where you have to monadify everything). Why the difference?

Oleg points out that implicit parameters are not really dynamic scoping, and gives an example where Lisp and Haskell disagree. And you don't even want the Lisp behavior in Haskell: if you think about the operational notion of dynamic scoping (walk up the stack until you find a binding site of the dynamic variable), it's not very compatible with laziness, since a thunk (which accesses a dynamic variable) will be forced at some unpredictable point in program execution. You really don't want to have to reason about where exactly a thunk will be executed to know how its dynamic variables will be bound, that way lies madness. But somehow, in a strict language, no one has trouble figuring out what should happen with dynamic scoping (well, mostly--more on this shortly).

It turns out that the research community has figured out the difference: implicit parameters are a coeffect. I believe this was first observed in Coeffects: Unified static analysis of context-dependence (a more modern presentation is in Coeffects: A calculus of context-dependent computation; and a more Haskelly presentation can be found in Embedding effect systems in Haskell). Although, Tomas was commenting on my blog in 2012 about similar ideas, so this probably had been in the works for a while. The key point is that for some coeffects (namely, implicit parameters), call-by-name reduction preserves types and coeffects, and so implicit parameters do not blow up in your face in the same way dynamic scoping (an effect) would. These necessarily behave differently! Type classes are coeffects too, and this is why modern use of implicit parameters in Haskell explicitly acknowledges this (e.g., in the reflection package).

At this year's ICFP, I was pointed at an interesting technical report about implicit values and functions in Koka, a new twist on dynamic scoping. I found myself wondering if Haskell implicit parameters could learn a thing or two from this work. Implicit values make the good choice of being defined globally at the top level, so that they can participate in normal module namespacing, as opposed to an un-namespaced bag of dynamically scoped names (this is also an improvement that reflection makes over implicit parameters). But actually, it seems to me that implicit functions are taking a page from implicit parameters!

The big innovation of implicit functions is that they resolve all dynamic references in the function (not just lexically, but for all further dynamic calls) to the lexical scope (the dynamic scope at the time the function was defined), producing a function that has no dependence on implicit values (aka, has no effect saying that the implicit value must be defined at the time the function is called). This is exactly what an implicit parameter let ?x = ... binding would have done, in effect directly filling in the dictionary for the implicit function at the definition site, rather than waiting. Very contextual! (Of course, Koka implements this using algebraic effects, and gets to the right semantics with a very simple translation anyway.) The result is not exactly dynamic scoping, but as the TR says, it leads to better abstraction.

It is difficult to see how implicit values/functions could make their way back into Haskell, at least without some sequencing construct (e.g., a monad) lurking around. Though implicit functions behave much like implicit parameters, the rest of the dynamic scoping (including the binding of the implicit function itself) is just good old effectful (not coeffectful) dynamic scope. And you can't just do that in Haskell, without breaking type preservation under beta-reduction and eta-expansion. Haskell has no choice but to go all the way, and once you get beyond the obvious problems of implicit parameters (which reflection fixes), things seem to mostly work out.

  • August 27, 2020

A brief taxonomy of PyTorch operators by shape behavior

I've recently been working on a revamp of how we specify tensor shape formulas in PyTorch. As part of this process, I classified every single operator in PyTorch by its shaping behavior; yes, that's all 1364 of them (this includes each variant of an operator; e.g., inplace and out= keyword variants). During the process, I tried to come up with categories to help classify what operators did. One of the surprises from the process was discovering that shaping behaviors that I previously thought were uncommon, actually showed up a bit more often than one might have expected.

These categories are interesting in their own right and can be used to help understand how PyTorch's API fits together. Here are all the categories I devised.

TensorIterator (505, e.g., add, sum) operators are PyTorch's bread and butter; these operators do pointwise operations and reductions and support broadcasting and type promotion. The name TensorIterator refers to an internal abstraction we have in PyTorch for implementing these operations; you can read more about it on the wiki and in this blog post. TensorIterator is a real workhorse in PyTorch: the plurality (though not majority) of operators are implemented in this way! Note that this category includes some functions that used equivalent, legacy functionality (but did not exactly use TensorIterator).

Fixed (273, e.g., convolution, addbmm) operators are operators which only work on a fixed number of dimensions. This assumption makes writing efficient kernels a lot easier, as indexing math is simple with fixed dimensionality. (For example, TensorAccessor is an internal class which lets you view a tensor at a fixed dimensionality known at compile time.) Sometimes, the first dimension is treated as a batch dimension, but not always (unfortunately, I didn't distinguish these cases in my dataset). Some fixed operators actually support multiple dimensionalities, but only a fixed number of them; for example, because we only support 1-3D convolutions, convolution counts as fixed. (Compare with FeatureBatched, below!)

N-Dimensional (107, e.g., squeeze, index_add, tensordot) operators are operators which work generically on tensors of arbitrary dimensionality. These are the operations for which it is difficult to write generic shaping rules in symbolic form, as you need a language that can talk about list manipulations. An important subclass of N-dimensional operators are Identity (42, e.g., clone, contiguous; not included in the count above) operators, which work over arbitrary dimensionality but always return a tensor with the same size as their input. Another subclass are Flatten (11, e.g., take, bucketize) operators, which accept tensors of any dimensionality but always treat them as 1D tensors internally.

Composite (95, e.g., kl_div, isfinite) operators are implemented in terms of other operators, and don't themselves have shape checking (instead, they rely on the operations they call to check shapes). Note this category is probably a bit underreported: in some cases, when it was obvious what the underlying behavior of an operator was, I classified the operator under that category rather than Composite.

Batched (94, e.g., nll_loss, adaptive_avg_pool2d) operators are like fixed dimensionality operators, except they accept an arbitrary number of batch dimensions at their beginning. Many fixed operators should be batched operators; others cannot be converted into batched operators without introducing ambiguity as to where the batch dimensions end. Compare these with FeatureBatched (19, e.g., batch_norm, embedding) operators, which are like batched operators, but rather than accept batch dimensions at the beginning, they accept an arbitrary number of feature dimensions at the end.

Factory (90, e.g., empty) operators produce new tensors without having any tensor inputs.

Trivial (59, e.g., size, is_floating_point) operators aren't actual tensor operations, but ways to return non-Tensor information or access internal data structures.

Sparse (40) operators are special because their size calculations take account of both dense and sparse dimensions.

Dynamic (15, e.g., unique) operators produce outputs whose shapes depend on the data of their input tensors.

Variadic (14, e.g., cat) operators take multiple input tensors; similar to N-dimensional operators, they are difficult to capture symbolically.
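A few of these categories are easy to see directly from the Python API; the shapes in the comments below are what I would expect from these calls:

import torch

# TensorIterator: pointwise, with broadcasting and type promotion.
torch.add(torch.ones(3, 1), torch.ones(4, dtype=torch.int64)).shape        # torch.Size([3, 4])

# Fixed: conv2d only accepts a 4-D (batched) input layout.
torch.nn.functional.conv2d(torch.randn(1, 3, 8, 8), torch.randn(6, 3, 3, 3)).shape  # torch.Size([1, 6, 6, 6])

# Dynamic: the output shape of unique depends on the values in the input.
torch.unique(torch.tensor([1, 1, 2, 3])).shape                              # torch.Size([3])

# Variadic: cat takes an arbitrary number of input tensors.
torch.cat([torch.zeros(2, 3), torch.zeros(4, 3)], dim=0).shape              # torch.Size([6, 3])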

You can take a look at the full data set at https://docs.google.com/spreadsheets/d/e/2PACX-1vQQFW0T_bucT5KZn0BHYTC1KYhkL6ZMG5ZxQWc6UmAkHUDYpqkpzXnsb59uv2TB0Jgc1Q6qO63bx6WQ/pubhtml

  • May 6, 2020

vmap in Haskell

vmap is an interface popularized by JAX which offers you a vectorizing map. Semantically, a vmap is exactly equivalent to a map in Haskell; the key difference is that operations run under a vmap are vectorized. If you map a convolution and a matrix multiply, you will have one big loop which repeatedly calls convolution and matrix multiply for each entry in your batch. If you vmap a convolution and matrix multiply, you'll call the batched versions of convolution and matrix multiply once. Unless you have a fuser, on most modern deep learning frameworks, calling the batched implementations of these operations will be much faster.

JAX implements vmap in a somewhat complicated fashion; they have a "batched interpreter" which translates operations on primitives into their batched versions, and have to track metadata about what tensors are batched and in what way so that they can insert appropriate broadcasts and unsqueezes. I mentioned this to Simon Peyton Jones, and he immediately asked, couldn't Haskell's typechecker work this out automatically? The answer is, yes! All of the book-keeping JAX has to do is effectively doing runtime type inference; if you have a compiler that can do it for you at compile time, there is nearly nothing to implement.
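For reference, here is the JAX behavior this post mirrors, using the standard jax.vmap API (the nested call corresponds to example2 below; the comments state the results I would expect):

import jax
import jax.numpy as jnp

def add(a, b):
    return a + b

a = jnp.array([1., 2., 3.])
b = jnp.array([4., 6., 8.])

# A single vmap zips the two batch dimensions together, like example1:
jax.vmap(add)(a, b)                                        # [ 5.  8. 11.]

# Nesting vmap gives each argument its own batch dimension, like example2:
jax.vmap(lambda x: jax.vmap(lambda y: add(x, y))(b))(a)    # a 3x3 cartesian-product result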

To give away the punchline, we are going to implement a family of functions vmap that will run these two examples:

example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 =
  vmap0_2 (\a b -> add a b) a0 b0

example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 =
  vmap0 (\a -> vmap1 (\b -> add a b) b0) a0

When run in an interpreter, we will see:

*Test> example1 [1,2,3] [4,6,8]
[5.0,8.0,11.0]
*Test> example2 [1,2,3] [4,6,8]
[[5.0,7.0,9.0],[6.0,8.0,10.0],[7.0,9.0,11.0]]

These results are equivalent to what you would have gotten using a plain old map; however, there will be no loop in the implementation of vmap. (The fact that we can't write a single vmap that works universally is due to a limitation in Haskell; we'll discuss this more later.)


We're going to need a few language extensions, so let's get this out of the way first:

{-# LANGUAGE RankNTypes, GADTs, MultiParamTypeClasses,
             KindSignatures, TypeApplications, FunctionalDependencies,
             FlexibleContexts, FlexibleInstances, UndecidableInstances,
             IncoherentInstances #-}

Our plan of attack is that we want to write the definitions of vmap so that we infer a type for add which makes the necessary broadcasting clear. A trivial implementation of vmap would have the signature ([a] -> [b]) -> [a] -> [b] (aka the identity function), but the standard list type doesn't let us distinguish between dimensions we should broadcast together, and dimensions we shouldn't (this is the reason example1 and example2 give different results: in example2, we broadcast along each dimension separately, so that we end up with a cartesian product in the end; in example1, we broadcast the dimensions together and get the zippy behavior). Each distinct invocation of vmap should give us a new dimension, which ought not to be mixed up with other invocations of vmap. When you hear this in Haskell, your first instinct should be, "I know, let's use a rank 2 type!" vmap moves us from the non-type-branded world of vanilla lists [Float] to a type-branded world of size-indexed vectors Vec s Float, where the s variables are all skolem variables bound by our rank 2 type:

data Vec s a = Vec { unVec :: [a] }
instance Functor (Vec s) where
  fmap f (Vec xs) = Vec (map f xs)

vmap0 :: (forall s. Vec s a -> Vec s b) -> [a] -> [b]
vmap0 f = unVec . f . Vec

The implementation of vmap0 doesn't do anything: we just wrap the lists into their type-branded equivalent vectors. We can also provide a 2-ary version of vmap0, which takes two lists and assigns them the same type branding all at once:

vmap0_2 :: (forall s. Vec s a -> Vec s b -> Vec s c) -> [a] -> [b] -> [c]
vmap0_2 f a b = unVec (f (Vec a) (Vec b))

(In principle, some sort of applicative-y thing should make it possible to write just a vap (analogous to ap) and then get all of the n-ary versions for free, but in my brief investigation I didn't see a good way of doing this.)

When we nest vmap, it may be the case that the function doesn't directly return a Vec s b, but a functor containing Vec s b. vmap1 handles this case (we'll discuss this more shortly):

vmap1 :: Functor f => (forall s. Vec s a -> f (Vec s b)) -> [a] -> f [b]
vmap1 f = fmap unVec . f . Vec

With our implementations of vmap in hand, we can take a look at our examples and ask Haskell what the type of add ought to be, if we didn't have an implementation of it:

example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 =
  vmap0_2 (\a b -> _add a b) a0 b0

Gives:

• Found hole: _add :: Vec s Float -> Vec s Float -> Vec s Float
  Where: ‘s’ is a rigid type variable bound by
           a type expected by the context:
             forall s. Vec s Float -> Vec s Float -> Vec s Float

However:

example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 =
  vmap0 (\a -> vmap1 (\b -> _add a b) b0) a0

Gives:

• Found hole:
    _add :: Vec s Float -> Vec s1 Float -> Vec s (Vec s1 Float)
  Where: ‘s1’ is a rigid type variable bound by
           a type expected by the context:
             forall s1. Vec s1 Float -> Vec s (Vec s1 Float)
           at test.hs:41:20-44
         ‘s’ is a rigid type variable bound by
           a type expected by the context:
             forall s. Vec s Float -> Vec s [Float]
           at test.hs:41:7-48

Notice that the inferred types of _add are different in these two cases: in the first example, we infer that we have two tensors batched in the same way, and we want to "zip" them together. In the second example, we see that each tensor has a distinct batch dimension, and we end up with a 2-D result!

At this point, the job of vmap is done: our holes have types which we can use to determine what the necessary behavior is. You could use these types to select an appropriate kernel to perform vectorized addition. But I promised runnable code, so let's implement a simple version of add using old fashioned map.

The good old fashioned way to do type level computation in Haskell is with a type class, of course! Let's define a multi-parameter type class for the function add; unlike the definition of (+) in Num, we'll let the inputs and output all have different types:

class Add a b c | a b -> c where
  add :: a -> b -> c

We can easily implement addition on plain floating point:

instance Add Float Float Float where
  add = (+)

If I pass add two arguments whose outer-most vector agree in their type brand (aka, they came from the same vmap), I should zip them together, as I did in example1. I can write another instance to express this logic:

instance Add a b r  => Add (Vec s a) (Vec s b) (Vec s r) where
  add (Vec a) (Vec b) = Vec (zipWith add a b)

Otherwise, I should broadcast one of the dimensions and then do an addition on the inside. This choice can't easily be made locally, so I have to define these two incoherent instances:

instance Add a b r => Add (Vec s a) b (Vec s r) where
  add (Vec a) b = Vec (map (\x -> add x b) a)

instance Add a b r => Add a (Vec s b) (Vec s r) where
  add a (Vec b) = Vec (map (\x -> add a x) b)

(GHC's type class resolution engine doesn't backtrack, so I'm not actually sure how it manages to pick the correct instance to use, but in my testing, I got the right instance no matter what order I specified the arguments to add.)

That's it! Running the two examples:

example1 :: [Float] -> [Float] -> [Float]
example1 a0 b0 =
  vmap0_2 (\a b -> add a b) a0 b0

example2 :: [Float] -> [Float] -> [[Float]]
example2 a0 b0 =
  vmap0 (\a -> vmap1 (\b -> add a b) b0) a0

I get:

*Test> example1 [1,2,3] [4,6,8]
[5.0,8.0,11.0]
*Test> example2 [1,2,3] [4,6,8]
[[5.0,7.0,9.0],[6.0,8.0,10.0],[7.0,9.0,11.0]]

So there you have it! vmap in less than a dozen lines of Haskell. One unsatisfactory thing about this implementation is the necessity to define vmap0, vmap1, etc. Can't we just define a generic vmapG :: (forall s. Vec s a -> f (Vec s b)) -> [a] -> f [b] and have f unify with, well, the identity type lambda /\a. a when we need it to have the type of vmap0? Regretfully, type inference with type lambdas is undecidable (the so-called higher-order unification problem), so it seems we have to help GHC out here, even though in our particular case the unification we can do here is very restricted.

  • January 29, 2020

PyTorch internals

This post is a long form essay version of a talk about PyTorch internals, that I gave at the PyTorch NYC meetup on May 14, 2019.

http://blog.ezyang.com/img/pytorch-internals/slide-01.png

Hi everyone! Today I want to talk about the internals of PyTorch.

http://blog.ezyang.com/img/pytorch-internals/slide-02.png

This talk is for those of you who have used PyTorch, and thought to yourself, "It would be great if I could contribute to PyTorch," but were scared by PyTorch's behemoth of a C++ codebase. I'm not going to lie: the PyTorch codebase can be a bit overwhelming at times. The purpose of this talk is to put a map in your hands: to tell you about the basic conceptual structure of a "tensor library that supports automatic differentiation", and give you some tools and tricks for finding your way around the codebase. I'm going to assume that you've written some PyTorch before, but haven't necessarily delved deeper into how a machine learning library is written.

http://blog.ezyang.com/img/pytorch-internals/slide-03.png

The talk is in two parts: in the first part, I'm going to first introduce you to the conceptual universe of a tensor library. I'll start by talking about the tensor data type you know and love, and give a more detailed discussion about what exactly this data type provides, which will lead us to a better understanding of how it is actually implemented under the hood. If you're an advanced user of PyTorch, you'll be familiar with most of this material. We'll also talk about the trinity of "extension points", layout, device and dtype, which guide how we think about extensions to the tensor class. In the live talk at PyTorch NYC, I skipped the slides about autograd, but I'll talk a little bit about them in these notes as well.

The second part grapples with the actual nitty gritty details involved with actually coding in PyTorch. I'll tell you how to cut your way through swaths of autograd code, what code actually matters and what is legacy, and also all of the cool tools that PyTorch gives you for writing kernels.


http://blog.ezyang.com/img/pytorch-internals/slide-04.png
http://blog.ezyang.com/img/pytorch-internals/slide-05.png

The tensor is the central data structure in PyTorch. You probably have a pretty good idea about what a tensor intuitively represents: it's an n-dimensional data structure containing some sort of scalar type, e.g., floats, ints, et cetera. We can think of a tensor as consisting of some data, and then some metadata describing the size of the tensor, the type of the elements it contains (dtype), and what device the tensor lives on (CPU memory? CUDA memory?).

http://blog.ezyang.com/img/pytorch-internals/slide-06.png

There's also a little piece of metadata you might be less familiar with: the stride. Strides are actually one of the distinctive features of PyTorch, so it's worth discussing them a little more.

http://blog.ezyang.com/img/pytorch-internals/slide-07.png

A tensor is a mathematical concept. But to represent it on our computers, we have to define some sort of physical representation for them. The most common representation is to lay out each element of the tensor contiguously in memory (that's where the term contiguous comes from), writing out each row to memory, as you see above. In the example above, I've specified that the tensor contains 32-bit integers, so you can see that each integer lies in a physical address, each offset four bytes from each other. To remember what the actual dimensions of the tensor are, we have to also record what the sizes are as extra metadata.

So, what do strides have to do with this picture?

http://blog.ezyang.com/img/pytorch-internals/slide-08.png

Suppose that I want to access the element at position tensor[1, 0] in my logical representation. How do I translate this logical position into a location in physical memory? Strides tell me how to do this: to find out where any element for a tensor lives, I multiply each index with the respective stride for that dimension, and sum them all together. In the picture above, I've color coded the first dimension blue and the second dimension red, so you can follow the index and stride in the stride calculation. Doing this sum, I get two (zero-indexed), and indeed, the number three lives two below the beginning of the contiguous array.
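Concretely, the calculation is just a dot product between the index and the strides. A tiny Python sketch, assuming the 2x2 tensor from the slides laid out contiguously:

def physical_offset(index, strides):
    # Multiply each index by the stride of its dimension and sum the results.
    return sum(i * s for i, s in zip(index, strides))

strides = (2, 1)                    # contiguous strides for a 2x2 tensor
physical_offset((1, 0), strides)    # 2: the element two slots past the start of the storage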

(Later in the talk, I'll talk about TensorAccessor, a convenience class that handles the indexing calculation. When you use TensorAccessor, rather than raw pointers, this calculation is handled under the covers for you.)

Strides are the fundamental basis of how we provide views to PyTorch users. For example, suppose that I want to extract out a tensor that represents the second row of the tensor above:

http://blog.ezyang.com/img/pytorch-internals/slide-09.png

Using advanced indexing support, I can just write tensor[1, :] to get this row. Here's the important thing: when I do this, I don't create a new tensor; instead, I just return a tensor which is a different view on the underlying data. This means that if I, for example, edit the data in that view, it will be reflected in the original tensor. In this case, it's not too hard to see how to do this: three and four live in contiguous memory, and all we need to do is record an offset saying that the data of this (logical) tensor lives two down from the top. (Every tensor records an offset, but most of the time it's zero, and I'll omit it from my diagrams when that's the case.)

Question from the talk: If I take a view on a tensor, how do I free the memory of the underlying tensor?

Answer: You have to make a copy of the view, thus disconnecting it from the original physical memory. There's really not much else you can do. By the way, if you have written Java in the old days, taking substrings of strings has a similar problem, because by default no copy is made, so the substring retains the (possibly very large) string. Apparently, they fixed this in Java 7u6.

A more interesting case is if I want to take the first column:

http://blog.ezyang.com/img/pytorch-internals/slide-10.png

When we look at the physical memory, we see that the elements of the column are not contiguous: there's a gap of one element between each one. Here, strides come to the rescue: instead of specifying a stride of one, we specify a stride of two, saying that between one element and the next, you need to jump two slots. (By the way, this is why it's called a "stride": if we think of an index as walking across the layout, the stride says how many locations we stride forward every time we take a step.)
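You can inspect both of these views directly from Python; these are real torch calls, and the numbers in the comments assume the 2x2 example tensor from the slides:

import torch

t = torch.tensor([[1, 2], [3, 4]])
t.stride()                            # (2, 1): contiguous layout

row = t[1, :]                         # a view, not a copy
row.stride(), row.storage_offset()    # (1,), 2

col = t[:, 0]                         # first column: same storage, but a stride of two
col.stride(), col.storage_offset()    # (2,), 0

row[0] = 99                           # writing through the view...
t[1, 0]                               # ...shows up in the original tensor: tensor(99)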

The stride representation can actually let you represent all sorts of interesting views on tensors; if you want to play around with the possibilities, check out the Stride Visualizer.

Let's step back for a moment, and think about how we would actually implement this functionality (after all, this is an internals talk). If we can have views on a tensor, this means we have to decouple the notion of the tensor (the user-visible concept that you know and love) from the actual physical data that stores the data of the tensor (called storage):

http://blog.ezyang.com/img/pytorch-internals/slide-11.png

There may be multiple tensors which share the same storage. Storage defines the dtype and physical size of the tensor, while each tensor records the sizes, strides and offset, defining the logical interpretation of the physical memory.

One thing to realize is that there is always a Tensor-Storage pair, even for "simple" cases where you don't really need a storage (e.g., you just allocated a contiguous tensor with torch.zeros(2, 2)).

By the way, we're interested in making this picture not true; instead of having a separate concept of storage, just define a view to be a tensor that is backed by a base tensor. This is a little more complicated, but it has the benefit that contiguous tensors get a much more direct representation without the Storage indirection. A change like this would make PyTorch's internal representation a bit more like Numpy's.

We've talked quite a bit about the data layout of tensor (some might say, if you get the data representation right, everything else falls in place). But it's also worth briefly talking about how operations on the tensor are implemented. At the very most abstract level, when you call torch.mm, two dispatches happen:

http://blog.ezyang.com/img/pytorch-internals/slide-12.png

The first dispatch is based on the device type and layout of a tensor: e.g., whether or not it is a CPU tensor or a CUDA tensor (and also, e.g., whether or not it is a strided tensor or a sparse one). This is a dynamic dispatch: it's a virtual function call (exactly where that virtual function call occurs will be the subject of the second half of this talk). It should make sense that you need to do a dispatch here: the implementation of CPU matrix multiply is quite different from a CUDA implementation. It is a dynamic dispatch because these kernels may live in separate libraries (e.g., libcaffe2.so versus libcaffe2_gpu.so), and so you have no choice: if you want to get into a library that you don't have a direct dependency on, you have to dynamic dispatch your way there.

The second dispatch is a dispatch on the dtype in question. This dispatch is just a simple switch-statement for whatever dtypes a kernel chooses to support. Upon reflection, it should also make sense that we need to do a dispatch here: the CPU code (or CUDA code, as it may be) that implements multiplication on floats is different from the code for ints. It stands to reason you need separate kernels for each dtype.
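As a cartoon, the two levels of dispatch have roughly this shape; this is a sketch of the structure only (the function names and the toy argument objects are invented for the example, and are not PyTorch code):

def mm_cpu_float(a, b): ...   # actual kernels elided
def mm_cpu_int(a, b): ...
def mm_cuda(a, b): ...

def mm(a, b):
    # First dispatch: dynamic, on device type (and layout).
    kernel = {"cpu": mm_cpu, "cuda": mm_cuda}[a.device]
    return kernel(a, b)

def mm_cpu(a, b):
    # Second dispatch: a plain switch on dtype inside the kernel.
    if a.dtype == "float32":
        return mm_cpu_float(a, b)
    elif a.dtype == "int64":
        return mm_cpu_int(a, b)
    else:
        raise TypeError("mm_cpu not implemented for " + a.dtype)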

This is probably the most important mental picture to have in your head, if you're trying to understand the way operators in PyTorch are invoked. We'll return to this picture when it's time to look more at code.


http://blog.ezyang.com/img/pytorch-internals/slide-13.png

Since we have been talking about Tensor, I also want to take a little time to talk about the world of tensor extensions. After all, there's more to life than dense, CPU float tensors. There's all sorts of interesting extensions going on, like XLA tensors, or quantized tensors, or MKL-DNN tensors, and one of the things we have to think about, as a tensor library, is how to accommodate these extensions.

http://blog.ezyang.com/img/pytorch-internals/slide-14.png

Our current model for extensions offers four extension points on tensors. First, there is the trinity of three parameters which uniquely determine what a tensor is:

  • The device, the description of where the tensor's physical memory is actually stored, e.g., on a CPU, on an NVIDIA GPU (cuda), or perhaps on an AMD GPU (hip) or a TPU (xla). The distinguishing characteristic of a device is that it has its own allocator, that doesn't work with any other device.
  • The layout, which describes how we logically interpret this physical memory. The most common layout is a strided tensor, but sparse tensors have a different layout involving a pair of tensors, one for indices, and one for data; MKL-DNN tensors may have even more exotic layout, like blocked layout, which can't be represented using merely strides.
  • The dtype, which describes what it is that is actually stored in each element of the tensor. This could be floats or integers, or it could be, for example, quantized integers.

If you want to add an extension to PyTorch tensors (by the way, if that's what you want to do, please talk to us! None of these things can be done out-of-tree at the moment), you should think about which of these parameters you would extend. The Cartesian product of these parameters defines all of the possible tensors you can make. Now, not all of these combinations may actually have kernels (who's got kernels for sparse, quantized tensors on FPGA?) but in principle the combination could make sense, and thus we support expressing it, at the very least.
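All three parameters are visible from the Python API; for instance (real torch calls, with the results I would expect noted in the comments):

import torch

x = torch.zeros(2, 2, device="cpu", dtype=torch.float16)
x.device, x.layout, x.dtype           # (device(type='cpu'), torch.strided, torch.float16)

# A different layout: a sparse COO tensor is backed by an indices tensor and a values tensor.
i = torch.tensor([[0, 1], [1, 0]])
v = torch.tensor([3.0, 4.0])
s = torch.sparse_coo_tensor(i, v, (2, 2))
s.layout                              # torch.sparse_coo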

There's one last way you can make an "extension" to Tensor functionality, and that's to write a wrapper class around PyTorch tensors that implements your object type. This perhaps sounds obvious, but sometimes people reach for extending one of the three parameters when they should have just made a wrapper class instead. One notable merit of wrapper classes is that they can be developed entirely out of tree.

When should you write a tensor wrapper, versus extending PyTorch itself? The key test is whether or not you need to pass this tensor along during the autograd backwards pass. This test, for example, tells us that sparse tensor should be a true tensor extension, and not just a Python object that contains an indices and values tensor: when doing optimization on networks involving embeddings, we want the gradient generated by the embedding to be sparse.

http://blog.ezyang.com/img/pytorch-internals/slide-15.png

Our philosophy on extensions also has an impact on the data layout of the tensor itself. One thing we really want out of our tensor struct is for it to have a fixed layout: we don't want fundamental (and very frequently called) operations like "What's the size of a tensor?" to require virtual dispatches. So when you look at the actual layout of a Tensor (defined in the TensorImpl struct), what we see is a common prefix of all fields that we consider all "tensor"-like things to universally have, plus a few fields that are only really applicable for strided tensors, but are so important that we've kept them in the main struct, and then a suffix of custom fields that can be done on a per-tensor basis. Sparse tensors, for example, store their indices and values in this suffix.


http://blog.ezyang.com/img/pytorch-internals/slide-16.png

I told you all about tensors, but if that was the only thing PyTorch provided, we'd basically just be a Numpy clone. The distinguishing characteristic of PyTorch when it was originally released was that it provided automatic differentiation on tensors (these days, we have other cool features like TorchScript; but back then, this was it!)

What does automatic differentiation do? It's the machinery that's responsible for taking a neural network:

http://blog.ezyang.com/img/pytorch-internals/slide-17.png

...and filling in the missing code that actually computes the gradients of your network:

http://blog.ezyang.com/img/pytorch-internals/slide-18.png

Take a moment to study this diagram. There's a lot to unpack; here's what to look at:

  1. First, rest your eyes on the variables in red and blue. PyTorch implements reverse-mode automatic differentiation, which means that we effectively walk the forward computations "backward" to compute the gradients. You can see this if you look at the variable names: at the bottom of the red, we compute loss; then, the first thing we do in the blue part of the program is compute grad_loss. loss was computed from next_h2, so we compute grad_next_h2. Technically, these variables which we call grad_ are not really gradients; they're really Jacobians left-multiplied by a vector, but in PyTorch we just call them grad and mostly everyone knows what we mean.
  2. While the structure of the code stays the same, the behavior doesn't: each line from forwards is replaced with a different computation, that represents the derivative of the forward operation. For example, the tanh operation is translated into a tanh_backward operation (these two lines are connected via a grey line on the left hand side of the diagram). The inputs and outputs of the forward and backward operations are swapped: if the forward operation produced next_h2, the backward operation takes grad_next_h2 as an input.

The whole point of autograd is to do the computation that is described by this diagram, but without actually ever generating this source. PyTorch autograd doesn't do a source-to-source transformation (though PyTorch JIT does know how to do symbolic differentiation).
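From the user's point of view, all of that machinery hides behind a couple of lines; these are real torch calls, and the comment states the gradient I would expect for this toy function:

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = torch.tanh(x).sum()

loss.backward()     # walk the recorded graph "backward", populating .grad
x.grad              # d(loss)/dx = 1 - tanh(x)**2, elementwise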

http://blog.ezyang.com/img/pytorch-internals/slide-19.png

To do this, we need to store more metadata when we carry out operations on tensors. Let's adjust our picture of the tensor data structure: now instead of just a tensor which points to a storage, we now have a variable which wraps this tensor, and also stores more information (AutogradMeta), which is needed for performing autograd when a user calls loss.backward() in their PyTorch script.

This is yet another slide which will hopefully be out of date in the near future. Will Feng is working on a Variable-Tensor merge in C++, following a simple merge which happened to PyTorch's frontend interface.

We also have to update our picture about dispatch:

http://blog.ezyang.com/img/pytorch-internals/slide-20.png

Before we dispatch to CPU or CUDA implementations, there is another dispatch on variables, which is responsible for unwrapping variables, calling the underlying implementation (in green), and then rewrapping the results into variables and recording the necessary autograd metadata for backwards.

Some implementations don't unwrap; they just call into other variable implementations. So you might spend a while in the Variable universe. However, once you unwrap and go into the non-Variable Tensor universe, that's it; you never go back to Variable (except by returning from your function.)


In my NY meetup talk, I skipped the following seven slides. I'm also going to delay writeup for them; you'll have to wait for the sequel for some text.

http://blog.ezyang.com/img/pytorch-internals/slide-21.png
http://blog.ezyang.com/img/pytorch-internals/slide-22.png
http://blog.ezyang.com/img/pytorch-internals/slide-23.png
http://blog.ezyang.com/img/pytorch-internals/slide-24.png
http://blog.ezyang.com/img/pytorch-internals/slide-25.png
http://blog.ezyang.com/img/pytorch-internals/slide-26.png
http://blog.ezyang.com/img/pytorch-internals/slide-27.png

http://blog.ezyang.com/img/pytorch-internals/slide-28.png

Enough about concepts, let's look at some code.

http://blog.ezyang.com/img/pytorch-internals/slide-29.png

PyTorch has a lot of folders, and there is a very detailed description of what they are in the CONTRIBUTING document, but really, there are only four directories you need to know about:

http://blog.ezyang.com/img/pytorch-internals/slide-30.png
  • First, torch/ contains what you are most familiar with: the actual Python modules that you import and use. This stuff is Python code and easy to hack on (just make a change and see what happens). However, lurking not too deep below the surface is...
  • torch/csrc/, the C++ code that implements what you might call the frontend of PyTorch. In more descriptive terms, it implements the binding code that translates between the Python and C++ universe, and also some pretty important pieces of PyTorch, like the autograd engine and the JIT compiler. It also contains the C++ frontend code.
  • aten/, short for "A Tensor Library" (coined by Zachary DeVito), is a C++ library that implements the operations of Tensors. If you're looking for where some kernel code lives, chances are it's in ATen. ATen itself bifurcates into two neighborhoods of operators: the "native" operators, which are modern C++ implementations of operators, and the "legacy" operators (TH, THC, THNN, THCUNN), which are older C implementations. The legacy operators are the bad part of town; try not to spend too much time there if you can.
  • c10/, which is a pun on Caffe2 and A"Ten" (get it? Caffe 10) contains the core abstractions of PyTorch, including the actual implementations of the Tensor and Storage data structures.

That's a lot of places to look for code; we should probably simplify the directory structure, but that's how it is. If you're trying to work on operators, you'll spend most of your time in aten.

Let's see how this separation of code breaks down in practice:

http://blog.ezyang.com/img/pytorch-internals/slide-31.png

When you call a function like torch.add, what actually happens? If you remember the discussion we had about dispatching, you already have the basic picture in your head:

  1. We have to translate from the Python realm to the C++ realm (Python argument parsing)
  2. We handle variable dispatch (VariableType--Type, by the way, doesn't really have anything to do with programming language types; it's just a gadget for doing dispatch.)
  3. We handle device type / layout dispatch (Type)
  4. We have the actual kernel, which is either a modern native function, or a legacy TH function.

Each of these steps corresponds concretely to some code. Let's cut our way through the jungle.

http://blog.ezyang.com/img/pytorch-internals/slide-32.png

Our initial landing point in the C++ code is the C implementation of a Python function, which we've exposed to the Python side as something like torch._C.VariableFunctions.add. THPVariable_add is one such implementation.

One important thing to know about this code is that it is auto-generated. If you search in the GitHub repository, you won't find it, because you have to actually build PyTorch to see it. Another important thing is that you don't have to deeply understand what this code is doing; the idea is to skim over it and get a sense for what it is doing. Above, I've annotated some of the most important bits in blue: you can see that there is a use of a class PythonArgParser to pull C++ objects out of the Python args and kwargs; we then call a dispatch_add function (which I've inlined in red); this releases the global interpreter lock and then calls a plain old method on the C++ Tensor self. On its way back, we rewrap the returned Tensor back into a PyObject.

(At this point, there's an error in the slides: I'm supposed to tell you about the Variable dispatch code. I haven't fixed it here yet. Some magic happens, then...)

http://blog.ezyang.com/img/pytorch-internals/slide-33.png

When we call the add method on the Tensor class, no virtual dispatch happens yet. Instead, we have an inline method which calls a virtual method on a "Type" object. This method is the actual virtual method (this is why I say Type is just a "gadget" that gets you dynamic dispatch.) In the particular case of this example, this virtual call dispatches to an implementation of add on a class named TypeDefault. This happens because we have an implementation of add that is the same for every device type (both CPU and CUDA); if we had happened to have different implementations, we might have instead landed on something like CPUFloatType::add. It is this implementation of the virtual method that finally gets us to the actual kernel code.
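
Here is a self-contained toy version of that pattern, with made-up class names (the real classes are generated as part of the ATen build): the inline method on the tensor is not virtual itself; the virtual call happens on the Type gadget it holds.

#include <iostream>

struct ToyTensor;

// "Type" is just a gadget whose virtual methods give us dynamic dispatch.
struct ToyType {
  virtual void add(const ToyTensor& self, const ToyTensor& other) const = 0;
  virtual ~ToyType() = default;
};

// Stands in for TypeDefault: one implementation shared by every device type.
struct ToyTypeDefault : ToyType {
  void add(const ToyTensor& self, const ToyTensor& other) const override {
    std::cout << "running the kernel shared by every device type\n";
  }
};

struct ToyTensor {
  const ToyType* type_;
  const ToyType& type() const { return *type_; }
  // The inline method does no virtual dispatch itself...
  void add(const ToyTensor& other) const {
    type().add(*this, other);  // ...the virtual call happens here
  }
};

int main() {
  ToyTypeDefault default_type;
  ToyTensor a{&default_type}, b{&default_type};
  a.add(b);
  return 0;
}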

Hopefully, this slide will be out-of-date very soon too; Roy Li is working on replacing Type dispatch with another mechanism which will help us better support PyTorch on mobile.

It's worth reemphasizing that all of the code, until we got to the kernel, is automatically generated.

http://blog.ezyang.com/img/pytorch-internals/slide-34.png

It's a bit twisty and turny, so once you have some basic orientation about what's going on, I recommend just jumping straight to the kernels.


http://blog.ezyang.com/img/pytorch-internals/slide-35.png

PyTorch offers a lot of useful tools for prospective kernel writers. In this section, we'll walk through a few of them. But first of all, what do you need to write a kernel?

http://blog.ezyang.com/img/pytorch-internals/slide-36.png

We generally think of a kernel in PyTorch as consisting of the following parts:

  1. First, there's some metadata which we write about the kernel, which powers the code generation and lets you get all the bindings to Python, without having to write a single line of code.
  2. Once you've gotten to the kernel, you're past the device type / layout dispatch. The first thing you need to write is error checking, to make sure the input tensors have the correct dimensions. (Error checking is really important! Don't skimp on it!)
  3. Next, we generally have to allocate the result tensor which we are going to write the output into.
  4. Time for the kernel proper. At this point, you should do the second, dtype dispatch, to jump into a kernel which is specialized for each dtype it operates on. (You don't want to do this too early, because then you will be uselessly duplicating code that looks the same in any case.)
  5. Most performant kernels need some sort of parallelization, so that you can take advantage of multi-CPU systems. (CUDA kernels are "implicitly" parallelized, since their programming model is built on top of massive parallelization).
  6. Finally, you need to access the data and do the computation you wanted to do!

In the subsequent slides, we'll walk through some of the tools PyTorch has for helping you implement these steps.
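
Before looking at the individual tools, here is a minimal end-to-end sketch of a made-up pointwise CPU operator that follows the anatomy above. It is illustrative rather than an actual PyTorch kernel: in real code, step 1's metadata would live in native_functions.yaml and generate the bindings for you, and the exact helper spellings (data_ptr, AT_DISPATCH_ALL_TYPES, at::parallel_for) may differ slightly between PyTorch versions.

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

at::Tensor my_relu(const at::Tensor& self) {
  // (2) Error checking
  TORCH_CHECK(self.numel() > 0, "my_relu: expected a non-empty tensor");

  // Make the input contiguous so the raw pointer access below is valid.
  at::Tensor input = self.contiguous();

  // (3) Allocate the result tensor we are going to write the output into.
  at::Tensor result = at::empty_like(input);

  // (4) dtype dispatch: scalar_t is bound to the concrete C++ type by the macro.
  AT_DISPATCH_ALL_TYPES(input.scalar_type(), "my_relu", [&] {
    const scalar_t* in = input.data_ptr<scalar_t>();
    scalar_t* out = result.data_ptr<scalar_t>();
    int64_t n = input.numel();

    // (5) Parallelize over elements on CPU.
    at::parallel_for(0, n, /*grain_size=*/2048, [&](int64_t begin, int64_t end) {
      // (6) Access the data and do the computation.
      for (int64_t i = begin; i < end; i++) {
        out[i] = in[i] > scalar_t(0) ? in[i] : scalar_t(0);
      }
    });
  });

  return result;
}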

http://blog.ezyang.com/img/pytorch-internals/slide-37.png

To take advantage of all of the code generation which PyTorch brings, you need to write a schema for your operator. The schema gives a mypy-esque type of your function, and also controls whether or not we generate bindings for methods or functions on Tensor. You also tell the schema what implementations of your operator should be called for given device-layout combinations. Check out the README in native/ for more information about this format.

http://blog.ezyang.com/img/pytorch-internals/slide-38.png

You also may need to define a derivative for your operation in derivatives.yaml.

http://blog.ezyang.com/img/pytorch-internals/slide-39.png

Error checking can be done by way of either a low level or a high level API. The low level API is just a macro, TORCH_CHECK, which takes a boolean, and then any number of arguments to make up the error string to render if the boolean is not true. One nice thing about this macro is that you can intermix strings with non-string data; everything is formatted using their implementation of operator<<, and most important data types in PyTorch have operator<< implementations.
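
For example, a minimal sketch (the condition and the message here are made up):

#include <ATen/ATen.h>

void check_is_matrix(const at::Tensor& self) {
  // Strings and non-string data can be freely intermixed in the message.
  TORCH_CHECK(self.dim() == 2,
              "expected a 2-D tensor, but got a tensor with ", self.dim(), " dimensions");
}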

The high level API saves you from having to write up repetitive error messages over and over again. The way it works is you first wrap each Tensor into a TensorArg, which contains information about where the tensor came from (e.g., its argument name). It then provides a number of pre-canned functions for checking various properties; e.g., checkDim() tests if the tensor's dimensionality is a fixed number. If it's not, the function provides a user-friendly error message based on the TensorArg metadata.
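
A hedged sketch of what using the high level API might look like (the exact names and signatures here are from memory and should be treated as assumptions; check ATen/TensorUtils.h for the real ones):

#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

void check_inputs(const at::Tensor& weight, const at::Tensor& bias) {
  // Wrap each tensor with its argument name and position...
  at::TensorArg weight_arg{weight, "weight", 1};
  at::TensorArg bias_arg{bias, "bias", 2};
  // ...then the pre-canned checkers can produce friendly messages that
  // mention "weight" or "bias" by name.
  at::checkDim("check_inputs", weight_arg, 2);
  at::checkDim("check_inputs", bias_arg, 1);
}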

http://blog.ezyang.com/img/pytorch-internals/slide-40.png

One important thing to be aware of when writing operators in PyTorch is that you are often signing up to write three operators: abs_out, which operates on a preallocated output (this implements the out= keyword argument), abs_, which operates inplace, and abs, which is the plain old functional version of an operator.

Most of the time, abs_out is the real workhorse, and abs and abs_ are just thin wrappers around abs_out; but sometimes writing a specialized implementation for each case is warranted.
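
Here is a sketch of that usual division of labor (illustrative, not the actual ATen source; the my_abs_* names are made up): the out= variant does the real work, and the other two delegate to it.

#include <ATen/ATen.h>

at::Tensor& my_abs_out(at::Tensor& result, const at::Tensor& self) {
  result.resize_(self.sizes());
  // ... error checking, dtype dispatch, and the actual elementwise kernel go here ...
  return result;
}

at::Tensor my_abs(const at::Tensor& self) {
  // Functional version: allocate a fresh output, then delegate.
  at::Tensor result = at::empty({0}, self.options());
  return my_abs_out(result, self);
}

at::Tensor& my_abs_(at::Tensor& self) {
  // In-place version: the input is also the output.
  return my_abs_out(self, self);
}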

http://blog.ezyang.com/img/pytorch-internals/slide-41.png

To do dtype dispatch, you should use the AT_DISPATCH_ALL_TYPES macro. This takes in the dtype of the tensor you want to dispatch over, and a lambda which will be specialized for each dtype that is dispatchable from the macro. Usually, this lambda just calls a templated helper function.

This macro doesn't just "do dispatch", it also decides what dtypes your kernel will support. As such, there are actually quite a few versions of this macro, which let you pick different subsets of dtypes to generate specializations for. Most of the time, you'll just want AT_DISPATCH_ALL_TYPES, but keep an eye out for situations when you might want to dispatch to some more types. There's guidance in Dispatch.h for how to select the correct one for your use-case.
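
The usual shape looks something like this sketch (the helper and its name are made up; the macro binds scalar_t to the concrete C++ type inside the lambda):

#include <ATen/ATen.h>

// Templated helper, instantiated once per dtype the macro dispatches to.
template <typename scalar_t>
void fill_ones_impl(at::Tensor& t) {
  scalar_t* data = t.data_ptr<scalar_t>();
  for (int64_t i = 0; i < t.numel(); i++) {
    data[i] = scalar_t(1);
  }
}

void fill_ones(at::Tensor& t) {
  TORCH_CHECK(t.is_contiguous(), "fill_ones: expected a contiguous tensor");
  AT_DISPATCH_ALL_TYPES(t.scalar_type(), "fill_ones", [&] {
    fill_ones_impl<scalar_t>(t);
  });
}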

http://blog.ezyang.com/img/pytorch-internals/slide-43.png

On CPU, you frequently want to parallelize your code. In the past, this was usually done by directly sprinkling OpenMP pragmas in your code.
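
The "old style" looks something like this sketch (a made-up loop, just to show the shape of it):

#include <cstdint>

// Old style: sprinkle an OpenMP pragma directly over the loop.
void scale_in_place(float* data, int64_t n, float alpha) {
  #pragma omp parallel for
  for (int64_t i = 0; i < n; i++) {
    data[i] *= alpha;
  }
}

More recent ATen code tends to go through a helper such as at::parallel_for (used in the kernel sketch earlier), which hides the pragma behind a lambda and keeps thread count and grain size managed in one place.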

http://blog.ezyang.com/img/pytorch-internals/slide-42.png

At some point, we have to actually access the data. PyTorch offers quite a few options for doing this.

  1. If you just want to get a value at some specific location, you should use TensorAccessor. A tensor accessor is like a tensor, but it hardcodes the dimensionality and dtype of the tensor as template parameters. When you retrieve an accessor like x.accessor<float, 3>();, we do a runtime test to make sure that the tensor really is this format; but after that, every access is unchecked. Tensor accessors handle strides correctly, so you should prefer using them over raw pointer access (which, unfortunately, some legacy kernels do.) There is also a PackedTensorAccessor, which is specifically useful for sending an accessor over a CUDA launch, so that you can get accessors from inside your CUDA kernel. (One notable gotcha: TensorAccessor defaults to 64-bit indexing, which is much slower than 32-bit indexing in CUDA!) (See the sketch after this list.)
  2. If you're writing some sort of operator with very regular element access, for example, a pointwise operation, you are much better off using a higher level of abstraction, the TensorIterator. This helper class automatically handles broadcasting and type promotion for you, and is quite handy.
  3. For true speed on CPU, you may need to write your kernel using vectorized CPU instructions. We've got helpers for that too! The Vec256 class represents a vector of scalars and provides a number of methods which perform vectorized operations on them all at once. Helpers like binary_kernel_vec then let you easily run vectorized operations, and then finish everything that doesn't round nicely into vector instructions using plain old instructions. The infrastructure here also manages compiling your kernel multiple times under different instruction sets, and then testing at runtime what instructions your CPU supports, and using the best kernel in those situations.
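
Here is a small sketch of the accessor case (a made-up function): the dimensionality/dtype check happens once when the accessor is created, and the indexing afterwards is unchecked but stride-aware.

#include <ATen/ATen.h>
#include <algorithm>

float trace_2d(const at::Tensor& t) {
  TORCH_CHECK(t.dim() == 2, "trace_2d: expected a 2-D tensor, got ", t.dim(), " dimensions");
  // Runtime test: the tensor must really be 2-D and hold floats.
  auto a = t.accessor<float, 2>();
  float sum = 0;
  int64_t n = std::min(a.size(0), a.size(1));
  for (int64_t i = 0; i < n; i++) {
    sum += a[i][i];  // unchecked, but stride-aware, element access
  }
  return sum;
}
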
http://blog.ezyang.com/img/pytorch-internals/slide-44.png

A lot of kernels in PyTorch are still written in the legacy TH style. (By the way, TH stands for TorcH. It's a pretty nice acronym, but unfortunately it is a bit poisoned; if you see TH in the name, assume that it's legacy.) What do I mean by the legacy TH style?

  1. It's written in C style, no (or very little) use of C++.
  2. It's manually refcounted (with manual calls to THTensor_free to decrease refcounts when you're done using tensors), and
  3. It lives in the generic/ directory, which means that we are actually going to compile the file multiple times, but with a different #define for scalar_t each time.

This code is pretty crazy, and we hate reviewing it, so please don't add to it. One of the more useful tasks that you can do, if you like to code but don't know too much about kernel writing, is to port some of these TH functions to ATen.


http://blog.ezyang.com/img/pytorch-internals/slide-45.png
http://blog.ezyang.com/img/pytorch-internals/slide-46.png

To wrap up, I want to talk a little bit about working efficiently on PyTorch. If the largeness of PyTorch's C++ codebase is the first gatekeeper that stops people from contributing to PyTorch, the efficiency of your workflow is the second gatekeeper. If you try to work on C++ with Python habits, you will have a bad time: it will take forever to recompile PyTorch, and it will take you forever to tell if your changes worked or not.

How to work efficiently could probably be a talk in and of itself, but this slide calls out some of the most common anti-patterns I've seen when someone complains: "It's hard to work on PyTorch."

  1. If you edit a header, especially one that is included by many source files (and especially if it is included by CUDA files), expect a very long rebuild. Try to stick to editing cpp files, and edit headers sparingly!
  2. Our CI is a very wonderful, zero-setup way to test if your changes worked or not. But expect to wait an hour or two before you get back signal. If you are working on a change that will require lots of experimentation, spend the time setting up a local development environment. Similarly, if you run into a hard-to-debug problem on a specific CI configuration, set it up locally. You can download and run the Docker images locally.
  3. The CONTRIBUTING guide explains how to set up ccache; this is highly recommended, because sometimes it will help you get lucky and avoid a massive recompile when you edit a header. It also helps cover up bugs in our build system, when we recompile files when we shouldn't.
  4. At the end of the day, we have a lot of C++ code, and you will have a much more pleasant experience if you build on a beefy server with plenty of CPU and RAM. In particular, I don't recommend doing CUDA builds on a laptop; building CUDA is sloooooow and laptops tend to not have enough juice to turn around quickly enough.

http://blog.ezyang.com/img/pytorch-internals/slide-47.png

So that's it for a whirlwind tour of PyTorch's internals! Many, many things have been omitted; but hopefully the descriptions and explanations here can help you get a grip on at least a substantial portion of the codebase.

Where should you go from here? What kinds of contributions can you make? A good place to start is our issue tracker. Starting earlier this year, we have been triaging issues; an issue labeled triaged means that at least one PyTorch developer has looked at it and made an initial assessment about it. You can use these labels to find out what issues we think are high priority, look up issues specific to some module (e.g., autograd), or find issues which we think are small (word of warning: we're sometimes wrong!)

Even if you don't want to get started with coding right away, there are many other useful activities like improving documentation (I love merging documentation PRs, they are so great), helping us reproduce bug reports from other users, and also just helping us discuss RFCs on the issue tracker. PyTorch would not be where it is today without our open source contributors; we hope you can join us too!

  • May 16, 2019

A short note about functional linear maps

Some notes collected from a close read of Conal Elliott's Compiling to Categories and The Simple Essence of Automatic Differentiation.

A colleague of mine was trying to define a "tree structure" of tensors, with the hope of thereby generalizing the concept to also work with tensors that have "ragged dimensions." Let's take a look:

Suppose we have a (2, 3) matrix:

tensor([[1, 2, 3],
        [4, 5, 6]])

One way to think about this is that we have a "tree" of some sort, where the root of the tree branches to two subnodes, and then each subnode branches to three nodes:

       /- ROOT -\
  ROW 1          ROW 2
 /  |  \        /  |  \
1   2   3      4   5   6

Suppose you wanted to define this data structure in Haskell. One obvious way of going about doing this is to just say that a matrix is a bunch of nested lists, [[Float]]. This works, true, but it isn't very illuminating, and it is certainly not type safe. Type safety could be achieved with sized vectors, but we are still left wondering, "what does it mean?"

Often, inductive definitions fall out of how we compose things together, in the same way that the inductive data structure for a programming language tells us how we take smaller programs and put them together to form a larger program. With matrices, we can think of a pictorial way of composing them, by either attaching matrices together vertically or horizontally. That gives us this vocabulary for putting together matrices, which would let us (non-uniquely) represent every matrix (Compiling to Categories, Section 8):

data Matrix
  = Scalar Float
  | Horizontal Matrix Matrix
  | Vertical Matrix Matrix

But what does it mean? Well, every matrix represents a linear map (if A : (n, m) is your matrix, the linear map is the function R^m -> R^n defined by f(x) = A x; we'll call a linear map from a to b, Linear a b). So the question we ask now is, what does it mean to "paste" two matrices together? It's a way of composing two linear maps together into a new linear map:

-- A function definition does not a category make!  You have to
-- prove that the resulting functions are linear.

horizontal :: Linear a c -> Linear b c -> Linear (a, b) c
horizontal f g = \(a, b) -> f a + g b

-- In matrix form:
--
--              [ a ]
-- [ F  |  G ]  [ - ] = [ F a + G b ]
--              [ b ]

vertical :: Linear a c -> Linear a d -> Linear a (c, d)
vertical f g = \a -> (f a, g a)

-- In matrix form:
--
-- [ F ]         [ F a ]
-- [ - ] [ a ] = [  -  ]
-- [ G ]         [ G a ]

Now we're cooking! Notice that the pasting shows up in the type of the linear map: if we paste horizontally, that just means that the vectors this linear map takes in have to be pasted together (with the tuple constructor); similarly, if we paste vertically, we'll produce output vectors that are the pasted results.

Cool, so we can add some type indexes, and write Linear as a GADT to refine the indices when you apply the constructor:

data Linear a b where
  Scalar :: Float -> Linear Float Float
  Horizontal :: Linear a c -> Linear b c -> Linear (a, b) c
  Vertical :: Linear a c -> Linear a d -> Linear a (c, d)

Is this the end of the story? Not quite. There are many ways you can go about combining linear maps; for example, you could (literally) compose two linear maps together (in the same sense of function composition). It's true that you can paste together any matrix you like with the data type above; how do we decide what should and shouldn't go in our language of linear maps?

To this end, Conal Elliott calls on the language of category theory to adjudicate. A category should define identity and function composition:

identity :: Linear a a
identity a = a

-- In matrix form: the identity matrix

compose :: Linear b c -> Linear a b -> Linear a c
compose g f = \a -> g (f a)

-- In matrix form: matrix multiply

We find that Horizontal and Vertical are the elimination and introduction operations of cocartesian and cartesian categories (respectively).

But should we just slap Identity and Compose constructors onto our data type? Linear map composition is a computationally interesting operation: if we just keep it around as syntax (rather than doing what is, morally, a matrix multiply), then it will be quite expensive to do operations on the final linear map. Where do representable functors come in? I'm not exactly sure how to explain this, and I've run out of time for this post; stay tuned for a follow up.

  • May 15, 2019

Microsoft Surface Book 2

Long time readers of mine may be aware that I used a ThinkPad X61T for the past decade. After the hinge broke on my second instance of the machine, I decided it was finally time to get a new laptop. And I had my eye on one particular model, after Simon Peyton Jones showed me his new laptop at the last Haskell Implementors' Workshop: the Microsoft Surface Book 2. It fits my primary requirement for a laptop: it's a convertible laptop into tablet mode with a digitizer pen. The pen is not Wacom branded, but it has an eraser end and can magnetically attach to the laptop (no enclosure for the pen, but I think that for modern hardware that constraint is unsatisfiable.) Furthermore, there is a Linux enthusiast community around the device, which made me feel that it would be more likely I could get Linux to work. So a few weeks ago, I took the plunge, and laid down three grand for my own copy. It has worked out well, but in the classic Linux style, not without a little bit of elbow grease.

A quick review

The good:

  1. I've managed to get all of the "important" functionality to work. That's Xournal with XInput pen and hibernate (though with some caveats.)
  2. Linux support for other random features has pleasantly surprised me: I managed to get a working CUDA install and drivers (for PyTorch development), the ability to boot my Linux partition bare metal as well as from a VM in Windows, and I can even detach the screen while booted into Linux.
  3. The keyboard is nice; not as good as a classic Thinkpad keyboard, but it has real function keys (unlike the Macbook Pro I use at work.)
  4. Two standard USB ports as well as a USB-C port means I don't need dongles for most usage (unlike my Macbook Pro, which only has USB-C ports.)

The bad:

  1. (Updated on March 19, 2019) Suspend is really slow. Although jakeday's setup.sh suggests that suspend is not working, something is working, in the sense that if I close my laptop lid, the laptop goes into a low power state of some sort. But it takes quite a long time to suspend, an even longer time to restart, and you still have to click past the bootloader (which makes me seriously wonder if we are actually suspending).
  2. The laptop un-hibernates itself sometimes when I put it in my backpack. My current hypothesis is that the power button is getting pushed (unlike most laptops, the power button is unprotected on the top of the screen). Probably some fucking around with my ACPI settings might help but I haven't looked closely into it yet.
  3. It's a high DPI screen. There's nothing wrong with this per se (and you basically can't buy a non-high DPI laptop these days), but any software that doesn't understand how to do high DPI (VMWare and Xournal, I'm looking at you) looks bad. The support of Ubuntu Unity for high DPI has gotten much better since the last time I've attempted anything like it, however; if I stick to the terminal and browser, things look reasonable.
  4. The function key is hardwired to toggle fn-lock. This is somewhat irritating because you have to remember which setting it's on to decide if you should hold it to get the other toggle. I'm also feeling the loss of dedicated page-up/page-down keys.
  5. Apparently, the NVIDIA GPU downthrottles itself due to thermal sensor shenanigans (something something the fan is on the motherboard and not the GPU so the driver thinks the fan is broken and throttles? Mumble.)
  6. The speakers are... OK. Not great, just OK.
  7. It's too bad Microsoft opted for some custom charger for the Surface Book 2.

Linux setup

I did a stock install of the latest Ubuntu LTS (18.04) dual boot with Windows (1TB hard drive helps!), and then installed jakeday's custom Linux kernel and drivers. Some notes about the process:

  • I spent a while scratching my head as to why I couldn't install Linux dual-boot. Some Googling suggested that the problem was that Windows hadn't really shutdown; it had just hibernated (for quick startup). I didn't manage to disable this, so I just resized the Windows partition from inside Windows and then installed Linux on that partition.

  • Don't forget to allocate a dedicated swap partition for Linux; you won't be able to hibernate without it.

  • The Surface Book 2 has secure boot enabled. You must follow the instructions in SIGNING.md to get signed kernels.

  • One consequence of generating signed kernels is that if you have both the unsigned and signed kernels installed, update-initramfs -u will update the initrd for your unsigned kernel, meaning that you won't see your changes unless you copy the initrd over! This flummoxed me a lot about the next step...

  • If you want to use the NVIDIA drivers for your shiny NVIDIA GPU, you need to blacklist nouveau. There are plenty of instructions on the internet but I can personally vouch for remingtonlang's instructions. Make sure you are updating the correct initrd; see my bullet point above. Once this was fixed, a standard invocation of the CUDA installer got me a working nvidia-smi. Note that I manually signed the NVIDIA driver using the instructions here, since I already had generated a private key, and it seemed silly to generate another one just because NVIDIA's installer asked me to.

  • Once you install the NVIDIA drivers, you have to be careful about the opposite problem: Xorg deciding it wants to do all its rendering on the NVIDIA card! The usual symptom when this occurs is that your mouse input to Linux is very laggy. If you have working nvidia-smi, you can also tell because Xorg will be a running process on your GPU. In any case, this is bad: you do NOT want to use the dGPU for plain old desktop rendering; you want the integrated one. I found that uncommenting the sample Intel config in /etc/X11/xorg.conf.d fixes the problem:

    Section "Device"
        Identifier  "Intel Graphics"
        Driver      "intel"
    EndSection
    

    But this doesn't play too nicely with VMWare; more on this below.

  • Sound did not work (it was too soft, or the right speaker wasn't working) until I upgraded to Linux 5.0.1.

  • After enabling XInput on my fork of Xournal, it did not start working until I restarted Xournal. Eraser worked right out of the box.

  • Don't forget to make a swap partition (Ubuntu default installer didn't prompt me to make one, probably because I was installing as dual-boot); otherwise, hibernate will not work.

  • Sometimes, when waking up from hibernate, networking doesn't work. Mercifully, this can be fixed by manually reloading the WiFi kernel module: modprobe mwifiex_pcie and systemctl restart NetworkManager.service. More discussion on this issue.

  • Sometimes, when waking up from hibernate/suspend, I get a big thermometer icon. When I reboot again it goes away but I have lost my hibernate/suspend. Perplexing! I don't know why this happens.

Boot via VM

The sad truth of life is that the Windows tablet experience is much better than the Linux experience--to the point where many would just install Windows and then boot Linux from a virtual machine (or Windows Subsystem for Linux). This was a non-starter for me: a bare metal boot of Linux was necessary to get the best pen input experience. However, why not also make it possible to boot the Linux partition from VMWare running on Windows? This setup is explicitly supported by VMWare, but it took a few days of fiddling to get it to actually work.

  • First, you need VMWare Workstation Pro to actually configure a VM that accesses raw disk (although the resulting VM image can be run from the free VMWare Player). You can sign up for the thirty-day trial to get it configured, and then use Player from then on, if you like. VMWare will offer the raw disk as an option when setting up disk; pick that and select the Linux partitions on your machine.
  • The primary challenge of setting up this system is that a standard install of Linux on the Surface Book 2 doesn't have a traditional Linux boot partition; instead, it has an EFI partition. Most notably, this partition is permanently mounted by Windows on boot up, so you can't remount it for your VM. Your regular partition doesn't have a bootloader, which is why when you turn on your VM, you get kicked into network boot via PXE. The workaround I ended up applying is to make a new, fake disk (vmdk-backed) and install the boot partition onto that (you don't actually need any of the kernels or initrds, since they live on your root filesystem; only /boot/efi is mounted from the EFI partition). Of course, you have to actually setup this boot partition; the way I did it was to chroot into my partition on a rescue CD and then run grub-install /dev/sda1. In the course of fiddling, I also accidentally ran update-grub which blew away my Windows boot option, but re-running this command when booted into Linux bare-metal fixed the problem (because the real /boot/efi will mount and thus Grub will find the Windows boot option.)
  • Some documentation about dual-boot is specific to VMWare Fusion. This is OS X specific, so not relevant to the Microsoft Surface Book 2.
  • Get yourself a bootable Linux CD (I personally use SystemRescueCd) to help debug problems in the installation process.
  • Make sure all of your /etc/fstab entries correspond to real disks, or your Ubuntu startup process will spend a while waiting for a disk that is never going to show up. I had this problem with the /boot/efi mount, because the mount was UUID based; I "fixed" it by changing the mount to be LABEL based and labeling my vmdk accordingly (I suppose it might also have been possible to change the UUID of my vmdk, but I couldn't find any reasonable instructions for doing so on Windows). Note that the volume doesn't actually have to successfully mount (mine doesn't, because I forgot to format it vfat); it just has to exist so the system doesn't wait to see if it connects at some later point in time.
  • I don't really understand how Unity decides to provide scaling options, but although it offers magnification on a bare metal boot, those options are not available when run under a VM. I get something tolerably sized (with only slight blurriness) by setting the resolution to 1680 x 1050; play around a bit with it. I have "Stretch Mode" enabled in VMWare.
  • Whether or not you can log into your account depends on your X11 configuration; if you're like me and uncommented the Intel configuration, I found this bricks my login (and you can unbrick it by commenting it out again.) How do you make both work? Don't ask me; I'm still figuring it out.

Window manager

I haven't gotten around to setting up xmonad; this is in no small part due to the fact that Unity appears to support a very rudimentary form of tiling: Windows-left and Windows-right will move windows to fill the left/right half of the display, while Windows-up will full-screen a window. I might still try to get xmonad set up on 18.04, but for now it is nice not having to fight with trayer to get the standard icons.

What's next

My two top priorities for improving the state of Linux on the Surface Book 2:

  1. Rewrite Xournal with support for hDPI (how hard could it be lol)
  2. Figure out how to make suspend/hibernate work more reliably

Otherwise, I am very happy with this new laptop. One thing in particular is how much faster my mail client (still sup) runs; previously, scanning for new mail would be a crawl, but on this laptop they stream in like a flash. Just goes to show how much an upgrade going from a 1.6GHz processor to a 4.2GHz processor is :3

  • March 17, 2019

HIW’18: Let’s Go Mainstream with Eta!

My name is Rahul Muttineni, CTO of TypeLead, working on building services around a language named Eta. To get started, I'll give an overview of how the project started, and where it is now.

It started as a HSOC project. It was called GHCVM; back then we had plans of making it both on JVM and CLR... we don't think about CLR anymore. I was mentored by Edward Kmett. We got pretty good response on this, so Jo and I decided to take the risk and work on this full time.

Big thanks to the GHC team, really good work. We've worked with the codebase for two years, and the more and more we work with it, we see how much awesome stuff there is. I've learned a lot by working with the code.

What is Eta? Eta is a fork of GHC. During the GSOC project, it started off as a Haskell program that used the GHC API. Midway in the program, I found that there were certain things that I wanted to do that I couldn't do, and I spent 3-4 days setting up a fork. I'll talk about what those limitations are. Like Haskell, it's a ... language, but the key difference is that it runs on the JVM. That comes with its own set of challenges, primarily with respect to tail calls. The nice thing about Eta is that it runs on the JVM, and it can run a good chunk of projects just like that. lens... recently, in the last month, we got Yesod working... it's in good shape. The next really great feature of Eta is the strongly typed FFI. That works really well with the subtyping in the JVM. A good chunk of the talk is about how we got that working. One of the main focuses of Eta is industrial use. GHC is focused on industrial use and research, both. There's a tension between the two... the nice thing we have for Eta is we don't have to face that tension; it's easy to make decisions on whether to add new features: will it help companies? If yes, we add it; otherwise we don't. (SPJ: That can be a hard question to answer!)

Haskell: Avoid success at all costs. We're not going to sacrifice core principles of language for benefit. Pursue success, at minimum cost. We want to make it successful as much as possible, but we want to make as little sacrifice as possible. That will be a little tricky...

What is Eta? What language features does it support? It started off as a fork of GHC 7.10.3. All extensions that work there work with Eta as well. The only thing was that TemplateHaskell and QuasiQuotes didn't work for a long time; we got them working 3-4 months ago. The biggest change is the JavaFFI. It's GHC 7.10.3 MINUS the C FFI. We could have supported it: Java has JNI, but we tried to avoid it because we didn't want to make platform specific bindings to all the libraries.

Joe backported a bunch of GHC 8 features: StrictData, ApplicativeDo, OverloadedLabels. Backpack was added recently. There's a very particular reason we had to do it: it has to do with the fact that we don't have green threads by default, and we wanted to give the user a choice of threaded runtime versus blocking runtime.

The compiler? It's a fork of GHC, so all the compiler passes are the same. We just chopped off everything after STG; e.g., C-- is gone. We generate bytecode from STG. We don't do any optimizations right now, and won't need to for some time. We don't have to because on the JVM it's JIT compiled, so we don't have to optimize as much, since the JVM will remove a lot of the code that's not used anyway. And the driver: GHC generates object files... we decided to use JAR files. They're just zip files that package up a bunch of class files that store Java bytecodes. We also added one more mode for Uberjars. These are JAR files that are packaged up into one giant package.

I'll talk a little bit about how we implemented the REPL and Template Haskell. It works through the external-interpreter architecture. In GHC that's called iserv: what the process does is handle running the code. So the compiler will still do the typechecking and everything, but once it's done with all that stuff, GHC will generate a specific bytecode set for interpreting Haskell efficiently. Because we already generate JVM bytecodes, we didn't need that custom bytecode set; we just compile with optimizations off; that gives us JVM bytecodes, then we send them to the external process, load them up, and execute them. Implementing the REPL was pretty easy once we got all this working together. The JVM has a mechanism called classloading, which is very flexible. You can download bytecodes from the network and get code at runtime. Once you load the class, it's statically compiled code, it's optimized the same, etc.

The build tool we use is Etlas. We didn't want to move too far off of GHC, so we stuck with Cabal. At the point we started using it, we forked off of Cabal 2.0. The main difference is that it lets you manage Eta versions. Etlas is almost like Stack, but it's much, much closer to Cabal. We took the nice features of Stack and added them to Cabal. The other thing is that it does patch management. What we've been finding as we add more features and backport is that Eta is not exactly GHC 7.10, nor is it GHC 8.0; it's a weird intermediate state, so certain packages won't compile without small changes, and we needed some system to apply those changes before we actually run the build. So we set up a GitHub repository that stores all the patch files. What etlas will do is get you the most recent set of patches. Then if you install a package, lens or something, it will download lens, apply the patch, and then it will build. Until recently, we were using base 4.8, and we just upgraded to base 4.11. But we couldn't update to the new Generics infrastructure, because it slowed down compile times. So there were a bunch of packages that would check if they were on GHC 8... and then use the new generics. So we had to do a bunch of patching for that. But that's the kind of stuff we have to deal with.

The title of this talk is let's go mainstream with Eta. I want to take a moment and ask, what does that mean? "The ideas, attitudes, or activities that are shared by most people and regarded as normal or conventional." So at what point does a programming language become considered normal or conventional? It has to be used by a big company, solve a big real world problem, and people have to believe it works. That's a very complicated, multifaceted question; one part of the answer is, it should make it easier to solve real world problems than the status quo. Take for example PHP. PHP came out when there was nothing better to program dynamic web applications. It had just the minimum features required to make it useful to build these. Now everyone here is asking the question: Haskell clearly solves a lot of problems better than the status quo. So why isn't it moving forward? That's a big question, and I'm going to talk about how we're approaching it.

The strategy we're using internally is to put on a "Big Company Hat": we pretend we're a big company with a lot of employees and millions or billions of lines of code, and try to figure out what problems they'll face. Some problems are crazy long build times when trying to build huge software; dynamics where you have to make sure junior developers get up to speed... etc. That's a couple of examples to get this conversation started.

After thinking about this a long time, we boiled it down to three basic principles, how we will develop Eta.

1. User Experience
2. Performance
3. Safety

User Experience is mainly an emotional thing: how you feel when you use Eta technology, how you interact with it, what you feel when you get an error, psychologically. When something has good UX, we feel good. That's a very subjective thing; it can vary between different people, so we have to figure out a way to standardize it / make it universal. Something we forget as software and tool developers is that the person developing the software is human. If they get errors persistently over time, they'll get frustrated. Machines will do what you tell them over and over again.

So what have we done in Eta to address this? We've done something very recently; it's not in master yet. Jo and I spent a week refactoring the error reporting part of the typechecker. It stores a list of strings; internally in GHC, there's a pretty-printed document data type, and a list of those. The problem is we can't do postprocessing on that. So, what Jo did was make a giant data type with three hundred data constructors, one for every error message in GHC. That refactor took a week (SPJ: only a week?!). How it is now, it's decoupled: instead of storing strings in the typechecking monad, you store a data type that stores the relevant data to print out that error message. And then at the final point, you can traverse the data type; based on the presence of other errors, you can decide what to do. Now it's pattern matching on certain error patterns and reporting them nicely. This is one example. We talked about simple errors: refactoring, adding an argument, changing the type; that's one of the most common errors you'll get working with Haskell. So we focused on those first. This shows an example of a type error... 'checker', it's an IO action.

GHC would tell you it couldn't match Int -> IO () with IO (). The problem is, for people who don't know how the typechecker works, they won't be able to understand what the typechecker is doing: going argument by argument. Because of the refactor we've done, it was easy to pattern match on this particular case and say, hey, if the user forgot to put an argument, you can print out an error message of this form. You print that an argument is missing, and you highlight it. (SM: You might have been missing the first argument, in this case!) That's true. It's tricky; sometimes the suggestion you give might not be the right one. We don't tell people exactly what they did wrong, because we don't know. This is not a perfect thing, but we try to give the best suggestion that we can. An important point: for most of how we decided on this layout, we studied languages like Elm and PureScript, which have done good work in this area. What PureScript and Elm both do is, for a certain type of error where you're not sure what to do... e.g., our info is not complete, they can go to a particular link and see other things that could have happened. So we don't have to flood the user with every suggestion; we just have to show the user what is probably the cause, and if it's a tricky case that's not what we showed, the link will have that case as well.

(BG: There's other information that might be relevant; expanding type synonyms, etc. Do you have this info?) We're still thinking about that. Probably we'll have extra flags and stuff. Eventually, we'll have a mode that prints out JSON for IDEs, then it's easier to parse on the IDE side. (BG: Incidentally, there's a ticket, a student working with Richard, trying to figure out something similar).

Another aspect of UX is that we added the REPL. We tried to simplify the entry point, to make it easy. You want types, kinds, and where to find out more information. This is a statically typed language: you always have to be thinking about types. So we :set +t: always print out the types when you print things. One more thing: one of the former Scala engineers has been learning Haskell, and he made a critique of one aspect of the REPL experience. f is a function of two arguments. In a second statement in the REPL, I applied it to 1. It tries to find a Show instance for a -> a. He said that... instead of "no Show instance found", just say that this is a function, and you can't print it. That's a change we made. This was very easy for us to do.

Performance: it can mean many things. We're talking about a fast developer feedback loop: compile time and develop time, reducing that feedback loop. Some work we've done in this direction is reproducible builds. As of now, we have bit-for-bit reproducibility in Eta. That amounted to... nikita already did lots of work on reproducibility; he made Haskell interface files reproducible, but the last mile of bit-for-bit is hard, there's many places. For our code generator, it was a lot simpler; we didn't have to do as much. It was 20 lines of code to make it deterministic. The main source of nondeterminism in GHC is the Unique data type, which changes between different runs depending on the environment. What we did was add a counter. We used to print the uniques in the Java class name; that would make it nondeterministic. So we made a counter: the order in which the bindings make it to STG is the same.

GHCi is known to take up lots of memory, esp. with IDE. Simon Marlow has a bunch of fixes to that; we also backported those.

Another aspect of performance is the actual runtime performance. We're on the JVM, and that puts us at a huge disadvantage. We don't have control over many things. The runtime system... this is Java. It's OO, so the runtime system is implemented in Java. We set up a hierarchy for values that are defined in Eta. We have Closure: it's a class, the parent class of all values, thunks, and WHNF values. The Closure class has two methods: evaluate, which evaluates to WHNF, and enter, which will actually enter... it's similar to the GHC runtime system. The initial version was modeled exactly after GHC, except for tail calls. The terminology is similar. It's primarily used when you do the body of a function. The main subclasses of Closure are Thunk and Value. Value will be the parent class of things like functions, partially applied functions, and data constructors. Thunk will be the superclass of things like CAFs, single entry thunks, and updatable thunks. CAFs don't have free variables, so there's a special case for that, and you create a blackholing entry every time, to avoid two threads evaluating the same thunk. UpdatableThunk pushes an update frame, and when it's finished evaluating, it will update the thunk to point to the newly computed value. And SingleEntryThunks are evaluated only once, so you can just evaluate them directly without pushing an update frame. This terminology is borrowed from GHC as well.

Values: DataCon, Function, and PAPs. In the early days, and even now, every function call that was a tail call is just a method call. This is the only way to make it remotely efficient. (More on the stack soon.) Static tail recursive calls, singly recursive or mutually recursive, get compiled to loops. In most cases, they get a nice tight loop. In the mutual case, what will happen is, we collect all of the SCCs, and we make one giant method that goes into a loop. Let's say you're in the even/odd example: what will happen is, when even calls odd, there's a variable called target, an integer. Even will be assigned 0, odd is assigned 1, so then you set 1 and restart. (BG: Do you always have unfoldings available for functions you compiled?) This is mutually recursive functions defined in the same module. (SPJ: They might have very different type arguments.) We concatenate all the arguments into one. The main problem with this approach is parsers generated with Happy and Alex: we hit limits. (BG: Crash?) Not stack blowup. The JVM has a method size limit, so you can only have 65000 bytecodes. That's Eta compiled with itself. That's the only thing that's preventing us from using Eta with Eta. But all you need to do is split the method into smaller chunks.

So how do we handle tail calls? That covers the case when we know a call is tail recursive; but let's say you don't know. Let's say you're using CPS. It's so common in Haskell; any fast parser uses CPS. In the early days, Aeson would just blow the stack; it was pretty bad. So, we explored trampolining by default, and it was just awful: it was slow, super slow. What we did is turn it off, and let the stack blow up. We found a better solution. The JVM has... the only way to unwind the stack is throwing an exception, or returning, and keep on returning until you return all the way down. It turns out, with exceptions, you can turn off the feature that captures the stack trace: that's the most expensive part of an exception. So we have a general exception. So this trampoline mechanism is optional. What we do is have a function 'trampoline :: a -> a', a runtime primitive; what it does is activate a boolean in the context which says, I'm going to trampoline now, and it activates a codepath that bumps a counter, and once you reach a certain number, which is configurable, it will unwind the stack, and then continue where it needed to go. Our limit is 400, and then we unwind. It used to be in the 1000s, but with Happy and Alex, we needed a smaller number. (BG: Inside that context, how much does it cost? But observably, it's faster. A couple months ago, we got PureScript to work in Eta, and it wasn't bad by default?) (SPJ: So you could turn it on by default: all you're doing is counting.) The counting is how we know how big the stack is. In your main function, you could call trampolineIO, and trampoline your whole program. (SPJ: Maybe it's low overhead, and you can do it all the time.) If it's low, we will do it. (How do you resume? Once you raise the exception, what do you store?) The counter happens at the entry point, and it's guarded by the boolean. So, if the limit is exceeded, it will call another function that takes the context. We store all the arguments in a context variable that gets passed to every Eta function. We stash all the arguments into a function that has the state, then when it unwinds, marked by this function, it will call that, with that function and those arguments.

As I mentioned, it's guarded by a boolean. JVM has an optimization, where it observes the boolean is true for a lot of times, it won't even compile that branch in the native code. So if you don't use trampolining, it doesn't affect you at all; the code for the counter will just not be there.

One nice thing I like about Eta is that you actually get stack traces for exceptions. This is because, to get good perf for Eta, you have to implement most primitives on the JVM stack. This is a sample stack. You have a schedule loop, and you are evaluating some IO action. applyV/applyN, these are partial applications. Execute an IO action. And another nice part: we've tried to encode it close to the original name. So you can tell this function call happened in statistics.Regression, rnfAll. If you see, you notice there are line numbers. This is not perfect, and we can definitely make it better later... GHC gives you a lot of debugging info at STG time, but because the JVM doesn't have much flexibility, we can only attach one line number to code, so we have to discard all that info. This will get better; we'll stash that debug information in the classfile itself, and then access it and render a better stack trace. (BG: This is Source Notes?) Yeah.

Concurrency: One nice part is, it's nice or not. If you're evaluating a long chain of thunks, you're going to blow the stack. This happily coincides with GHC also having a space leak. Neil Mitchell wrote a blog post about how to detect space leaks: restrict stack size and then figure out which thunk was being evaluated. If you see a stack trace like this, and you see a huge chain of evaluates, in a long chain, you probably have a space leak.

How do I do interop? The way we did interop was to make a thing called the Java monad. It's supposed to give you the experience of programming Java. The basic implementation is inspired by the IO monad. Object# c is "this", the object that is being threaded through. Because of this encoding, you get the Java experience: you can call dot on the Java object. It's almost like working with Java inside. The argument is called... that's the type constructor that forced us to fork, instead of using the API. You can't declare primitive types in the API. And we had to introduce a new low level representation. Declare wrapper types, wrapping the Iterable interface in Java. We've stolen better syntax, which were type applications... resolve it somehow. I'm declaring an Eta type that wraps a Java type, @java.lang.Iterable.

You use the java function to run the Java monad. All of these have to be imported. newArrayList, newInteger, but we brought some combinators that let you call methods. It worked out with the monad. This is sample code that does the same thing as Java code; it just uses standard monadic combinators. If it's a fixed c, it's an instance.

You can use Eta as a better Java, with referential transparency! Unlike Kotlin or Scala.

How do we handle subtyping? We define built-in type families. We have a typeclass named Extends. Any time you declare a function that takes a given class and any subtype of that class, instead of actually subtyping, we do it with constraints. Extends' takes the info from Inherits and figures it out. You can use the dot operator on anything that is a subclass of Iterator. We had to extend the typechecker just a little bit: a lot of times the type gets stuck in the form Extends' (List JString) (List a) where a is unconstrained.

Imports are tiresome, so we're setting up direct Java interop; we actually use Java reflection to get info from class files, and generate imports. "import java java.lang.Math" works, but doesn't scale. The biggest priority for the rest of the year is Java interop, really good IDE support, documentation, and language extensions: UnboxedSums, TypeApplications, DerivingVia, QuantifiedConstraints. We have some new language extensions in mind, AnonymousRecords, RowTypePolymorphism... We'll see how that goes.

I was thinking about ways... we work on the same codebase, how do we collaborate? We're interested in compile performance, support for unboxed sums. Worker wrapper has some glitch, and no one got around to fixing it. At some point, maybe not any time soon, that and mutable fields. Pretty important for us. (BG: Do unboxed sums get a lot of usage? Why unboxed sums? Does Eta code make a lot of use?) No. But a lot of people on the JVM are annoyed that Maybe is boxed all the time. But if you have unboxed sums, you can represent it as null. (SPJ: Or you can say, just box it, and you won't notice it. If it's fast enough all the time, focus on what's going to make a difference.)

Q: Did you consider using Graal (it's a new virtual machine that supports partial evaluation and partial escape analysis, good for functional languages)?

A: We have looked into it; it's not completely there yet to use, and we're not sure if it's something we can invest time in. We're keeping up with it. (BG: But you lose the JVM!) That's what's preventing us from going there. Maybe if it gets integrated into a mainline VM we might look at it. (Mainline Java is planning to integrate Graal)

Q: (SM) Are you keeping the fork up to date with master GHC?

A: One thing that is out of bounds for us, and for a long time, is all the dependent Haskell work. Everything else, we keep up. If there's any nice bugfixes... (SM: So you're selectively backporting).

Q: (BG) Have you considered unforking?

A: Not yet, no.

  • September 23, 2018

A year into Backpack

It's been a year since I got my hood and gown and joined Facebook (where I've been working on PyTorch), but while I've been at Facebook Backpack hasn't been sleeping; in fact, there's been plenty of activity, more than I could have ever hoped for. I wanted to summarize some of the goings on in this blog post.

Libraries using Backpack

There's been some really interesting work going on in the libraries using Backpack space. Here are the two biggest ones I've seen from the last few months:

unpacked-containers. The prolific Edward Kmett wrote the unpacked-containers package, which uses the fact that you can unpack through Backpack signatures to give you generic container implementations with hypercharged performance (15-45%), way better than you could get with the usual boxed representation. A lot of discussion happened in this Reddit thread.

hasktorch. hasktorch, by Austin Huang and Sam Stites, is a tensor and neural network library for Haskell. It binds to the TH library (which also powers PyTorch), but it uses Backpack, giving the post Backpack for deep learning from Kaixi Ruan new legs. This is quite possibly one of the biggest instances of Backpack that I've seen thus far.

Backpack in the Ecosystem

Eta supports Backpack. Eta, a JVM fork of GHC, backported Backpack support into their fork, which means that you can use Backpack in your Eta projects now. It was announced in this Twitter post and there was some more discussion about it at this post.

GSOC on multiple public libraries. Francesco Gazzetta, as part of Google Summer of Code, is working on adding support for multiple public libraries in Cabal. Multiple public libraries will make many use-cases of Backpack much easier to write, since you will no longer have to split your Backpack units into separate packages, writing distinct Cabal files for each of them.

Backpack in GHC and Cabal

By and large, we haven't changed any of the user facing syntax or semantics of Backpack since its initial release. However, there have been some prominent bugfixes (perhaps fewer than one might expect), both merged and coming down the pipe:

  • #13955: Backpack now supports non-* kinds, so you can do levity polymorphism with Backpack.
  • #14525: Backpack now works with the CPP extension
  • #15138: Backpack will soon support data T : Nat signatures, which can be instantiated with type T = 5. Thank you Piyush Kurur for diagnosing the bug and writing a patch to fix this.
  • A fix for Cabal issue #4754: Backpack now works with profiling

Things that could use help

Stack support for Backpack. In Stack issue #2540 I volunteered to implement Backpack support for Stack. However, over the past year, it has become abundantly clear that I don't actually have enough spare time to implement this myself. Looking for brave souls to delve into this; and I am happy to advise about the Backpack aspects.

Pattern synonym support for Backpack. You should be able to fill a signature data T = MkT Int with an appropriate bidirectional pattern synonym, and vice versa! This is GHC issue #14478. We don't think it should be too difficult; we have to get the matchers induced by constructors and check they match, but it requires some time to work out exactly how to do it.

  • July 14, 2018