vmap is an interface, popularized by JAX, that offers a vectorizing map. Semantically, a vmap is exactly equivalent to a map in Haskell; the key difference is that operations run under a vmap are vectorized. If you map a convolution and a matrix multiply, you get one big loop that calls convolution and matrix multiply once for each entry in your batch. If you vmap a convolution and a matrix multiply, you call the batched versions of convolution and matrix multiply once. On most modern deep learning frameworks, unless you have a fuser, calling the batched implementations of these operations will be much faster.
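To make the map/vmap contrast concrete, here is a minimal sketch in JAX. It uses a matrix-vector product as a stand-in for the convolution and matrix multiply mentioned above, and the names `affine`, `w` and `xs` are illustrative only. The looped version calls `affine` once per batch entry, while `jax.vmap` (with `in_axes=(None, 0)`, meaning `w` is shared and `xs` is mapped over its leading axis) traces `affine` once and produces a single batched computation. Both give the same result, which is the "semantically equivalent to map" property.

```python
import jax
import jax.numpy as jnp


def affine(w, x):
    # Per-example computation: a single matrix-vector product.
    return w @ x


w = jnp.ones((3, 4))                   # shared weights
xs = jnp.arange(20.0).reshape(5, 4)    # a batch of 5 input vectors

# "map": a Python loop that calls affine once for each batch entry.
looped = jnp.stack([affine(w, x) for x in xs])

# "vmap": one call to the batched version of the same computation.
batched = jax.vmap(affine, in_axes=(None, 0))(w, xs)

print(batched.shape)                   # (5, 3)
print(jnp.allclose(looped, batched))   # True
```

Printing `jax.make_jaxpr(jax.vmap(affine, in_axes=(None, 0)))(w, xs)` should show a single batched `dot_general` rather than five separate products, which is where the speedup over a plain map comes from.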
This post is a long-form essay version of a talk about PyTorch internals that I gave at the PyTorch NYC meetup on May 14, 2019.

Hi everyone! Today I want to talk about the internals of PyTorch.

This talk is for those of you who have used PyTorch and thought to yourself, “It would be great if I could contribute to PyTorch,” but were scared by PyTorch’s behemoth of a C++ codebase. I’m not going to lie: the PyTorch codebase can be a bit overwhelming at times. The purpose of this talk is to put a map in your hands: to tell you about the basic conceptual structure of a “tensor library that supports automatic differentiation,” and to give you some tools and tricks for finding your way around the codebase. I’m going to assume that you’ve written some PyTorch before, but haven’t necessarily delved into how a machine learning library is written.