## Category theory for loop optimizations

Christopher de Sa and I have been working on a category theoretic approach to optimizing MapReduce-like pipelines. Actually, we didn’t start with any category theory—we were originally trying to impose some structure on some of the existing loop optimizations that the Delite compiler performed, and along the way, we rediscovered the rich relationship between category theory and loop optimization.

On the one hand, I think the approach is pretty cool; but on the other hand, there’s a lot of prior work in the area, and it’s tough to figure out where one stands on the research landscape. As John Mitchell remarked to me when I was discussing the idea with him, “Loop optimization, can’t you just solve it using a table lookup?” We draw a lot of inspiration from existing work, especially the *program calculation* literature pioneered by Bird, Meertens, Malcolm, Meijer and others in the early 90s. The purpose of this blog post is to air out some of the ideas we’ve worked out and get some feedback from you, gentle reader.

There are a few ways to think about what we are trying to do:

- We would like to *implement* a calculational-based optimizer, targeting a real project (Delite) where the application of loop optimizations can have drastic impacts on the performance of a task (other systems which have had similar goals include Yicho, HYLO).
- We want to venture where theorists do not normally tread. For example, there are many “boring” functors (e.g. arrays) which have important performance properties. While they may be isomorphic to an appropriately defined algebraic data type, we argue that in a calculational optimizer, we want to *distinguish* between these different representations. Similarly, many functions which are not natural transformations *per se* can be made to be natural transformations by way of partial application. For example, `filter p xs` is a natural transformation when `map p xs` is incorporated as part of the definition of the function (the resulting function can be applied on any list, not just the original `xs`). The resulting natural transformation is *ugly* but *useful*.
- For stock optimizers (e.g. Haskell), some calculational optimizations can be supported by the use of *rewrite rules*. While rewrite rules are a very powerful mechanism, they can only describe “always on” optimizations; e.g. for deforestation, one always wants to eliminate as many intermediate data structures as possible. In many of the applications we want to optimize, the best performance can only be achieved by *adding* intermediate data structures: now we have a space of possible programs and rewrite rules are woefully inadequate for specifying *which* program is the best. What we’d like to do is use category theory to give an account for rewrite rules *with structure*, and use domain specific knowledge to pick the best programs.
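
To make the “always on” point concrete, here is a minimal sketch of a GHC rewrite rule. The names `mapL` and `twoPasses` are hypothetical and chosen only for illustration; the `NOINLINE` pragma is an assumption made so that the rule has a chance to match before inlining.

```haskell
-- A hypothetical list map (illustrative name, not from any library).
mapL :: (a -> b) -> [a] -> [b]
mapL _ []       = []
mapL f (x : xs) = f x : mapL f xs
{-# NOINLINE mapL #-}

-- An "always on" rewrite: wherever two adjacent maps appear, the optimizer
-- may fuse them, which removes the intermediate list.
{-# RULES
"mapL/mapL" forall f g xs. mapL f (mapL g xs) = mapL (f . g) xs
  #-}

-- Written as two passes; a candidate for the rule above.
twoPasses :: [Int] -> [Int]
twoPasses xs = mapL (+ 1) (mapL (* 2) xs)
```

The rule carries no cost model: it either matches or it does not. Nothing in it can express “introduce an intermediate structure here because it pays off later”, which is the kind of choice the k-means example below depends on.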

I’d like to illustrate some of these ideas by way of an example. Here is some sample code, written in Delite, which calculates an iteration of (1-dimensional) k-means clustering:

    (0 :: numClusters, *) { j =>
      val weightedPoints = sumRowsIf(0,m){i => c(i) == j}{i => x(i)};
      val points = c.count(_ == j);
      val d = if (points == 0) 1 else points
      weightedPoints / d
    }

You can read it as follows: we are computing a result array containing the position of each cluster, and the outermost block is looping over the clusters by index variable `j`. To compute the position of a cluster, we have to get all of the points in `x` which were assigned to cluster `j` (that’s the `c(i) == j` condition) and sum them together, finally dividing the sum by the number of points in the cluster to get the true location.

The big problem with this code is that it iterates over the entire dataset *numClusters* times, when we’d like to only ever do one iteration. The optimized version which does just that looks like this:

    val allWP = hashreduce(0,m)(i => c(i), i => x(i), _ + _)
    val allP = hashreduce(0,m)(i => c(i), i => 1, _ + _)
    (0::numClusters, *) { j =>
      val weightedPoints = allWP(j);
      val points = allP(j);
      val d = if (points == 0) 1 else points
      weightedPoints / d
    }

That is to say, we have to precompute the weighted points and the point count (note the two hashreduces can and should be fused together) before generating the new coordinates for each of the clusters: generating *more* intermediate data structures is a win, in this case.
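
The difference between the two versions can be sketched in plain Haskell. This is not Delite code: `naiveStep` and `onePassStep` are hypothetical names, lists stand in for Delite's arrays, and `accumArray` from `Data.Array` plays the role of the two fused `hashreduce` calls.

```haskell
import Data.Array (accumArray, elems)

-- Naive update: for every cluster j, rescan all points
-- (numClusters passes over the data).
naiveStep :: Int -> [Int] -> [Double] -> [Double]
naiveStep numClusters c x =
  [ let mine   = [xi | (ci, xi) <- zip c x, ci == j]
        points = length mine
        d      = if points == 0 then 1 else points
    in sum mine / fromIntegral d
  | j <- [0 .. numClusters - 1] ]

-- Single-pass update: one traversal accumulates a (sum, count) pair per
-- cluster. Assumes every entry of c lies in 0 .. numClusters - 1.
onePassStep :: Int -> [Int] -> [Double] -> [Double]
onePassStep numClusters c x = map finish (elems acc)
  where
    acc = accumArray add (0, 0) (0, numClusters - 1) (zip c x)
    add (s, n) xi = (s + xi, n + 1 :: Int)
    finish (s, n) = s / fromIntegral (if n == 0 then 1 else n)
```

Both functions compute the same centroids; the second builds an extra per-cluster table and in exchange touches each point only once.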

Let us now calculate our way to the optimized version of the program. First, however, we have to define some functors:

- `D_i[X]` is an array of `X` of a size specified by `i` (concretely, we’ll use `D_i` for arrays of size `numPoints` and `D_j` for arrays of size `numClusters`). This family of functors is also known as the diagonal functor, generalized for arbitrary size products. We also will rely on the fact that `D` is representable, that is to say `D_i[X] = Loc_D_i -> X` for some type `Loc_D_i` (in this case, it is the index set `{0 .. i}`).
- `List[X]` is a standard list of `X`. It is the initial algebra for the functor `F[R] = 1 + X * R`. Any `D_i` can be embedded in `List`; we will do such conversions implicitly (note that the reverse is not true).

There are a number of functions, which we will describe below:

- `tabulate` witnesses one direction of the isomorphism between `Loc_D_i -> X` and `D_i[X]`, since `D_i` is representable. The other direction is `index`, which takes `D_i[X]` and a `Loc_D_i` and returns an `X`.
- `fold` is the unique function determined by the initial algebra on `List`. Additionally, we assume a function `*` which combines two algebras by taking their cartesian product.
- `bucket` is a natural transformation which takes a `D_i[X]` and buckets it into `D_j[List[X]]` based on some function which assigns elements in `D_i` to slots in `D_j`. This is an example of a function that is not a natural transformation until it is partially applied: if we compute `D_i[Loc_D_j]`, then we can create a natural transformation that doesn’t ever look at `X`; it simply “knows” where each slot of `D_i` needs to go in the resulting structure.
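
As a rough sketch of these definitions, assume `D_i[X]` is modelled as a Haskell list of length `i` and `Loc_D_i` as `Int` indices. The argument order of `index` follows the derivation below (index first), and `naturalFor` is a hypothetical helper that checks the naturality square on a given input.

```haskell
-- Assumption: D_i[X] is a list of length i, Loc_D_i is an Int index.
tabulate :: Int -> (Int -> x) -> [x]
tabulate n f = map f [0 .. n - 1]

index :: Int -> [x] -> x
index j xs = xs !! j

-- bucket, already partially applied to the assignment c :: D_i[Loc_D_j].
-- It only moves elements around and never inspects them.
bucket :: Int -> [Int] -> [x] -> [[x]]
bucket numBuckets c xs =
  tabulate numBuckets (\j -> [x | (cj, x) <- zip c xs, cj == j])

-- Naturality check: mapping before or after bucketing gives the same result.
naturalFor :: Eq y => (x -> y) -> Int -> [Int] -> [x] -> Bool
naturalFor f n c xs = bucket n c (map f xs) == map (map f) (bucket n c xs)
```

Because `c` is fixed up front, `bucket n c` works for any element type, which is what makes it natural in `X`. This specification-style `bucket` rescans the input once per slot; it is meant to pin down the meaning, not to be efficient.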

Let us now rewrite the loop in more functional terms:

    tabulate (\j ->
      let weightedPoints = fold plus . filter (\i -> c[i] == j) $ x
          points = fold inc . filter (\i -> c[i] == j) $ x
      in divide (weightedPoints, points)
    )

(Where `divide` is just a function which divides its arguments but checks that the divisor is not zero.) Eliminating some common sub-expressions and fusing the two folds together, we get:

    tabulate (\j -> divide . fold (plus * inc) . filter (\i -> c[i] == j) $ x)
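
The fusion of the two folds can be sketched as follows. An algebra is represented here as a pair of a value for the empty list and a step function; that representation, and the name `prod` for the `*` of the text, are assumptions made for this sketch.

```haskell
-- An algebra for lists: a result for the empty list and a step for cons.
type Alg x r = (r, x -> r -> r)

fold :: Alg x r -> [x] -> r
fold (nil, cons) = foldr cons nil

plus :: Alg Double Double
plus = (0, (+))

inc :: Alg x Int
inc = (0, \_ n -> n + 1)

-- Product of two algebras: both results are computed in lock step.
prod :: Alg x a -> Alg x b -> Alg x (a, b)
prod (n1, c1) (n2, c2) = ((n1, n2), \x (a, b) -> (c1 x a, c2 x b))

-- Divide, guarding against an empty cluster as in the Delite code.
divide :: (Double, Int) -> Double
divide (s, n) = s / fromIntegral (if n == 0 then 1 else n)

-- Two traversals of the same list.
twoFolds :: [Double] -> Double
twoFolds xs = divide (fold plus xs, fold inc xs)

-- One traversal using the product algebra.
fused :: [Double] -> Double
fused = divide . fold (plus `prod` inc)
```

For any finite list, `twoFolds` and `fused` agree; the fused version walks the list once.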

At this point, it is still not at all clear that there are any rewrites we can carry out: the `filter` is causing problems for us. However, because filter is testing on equality, we can rewrite it in a different way:

    tabulate (\j -> divide . fold (plus * inc) . index j . bucket c $ x)

What is happening here? Rather than directly filtering for just items in cluster `j`, we can instead view this as *bucketing* `x` on `c` and then indexing out the single bucket we care about. This shift in perspective is key to the whole optimization.
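
The equivalence between filtering and bucket-then-index can be checked on small inputs. In this sketch the buckets live in a `Data.Map` keyed by cluster id, which is an illustrative choice; `filterByCluster` is a hypothetical name for the `filter (\i -> c[i] == j)` step.

```haskell
import qualified Data.Map.Strict as Map

-- Select the elements of xs whose cluster assignment equals j.
filterByCluster :: [Int] -> Int -> [x] -> [x]
filterByCluster c j xs = [x | (cj, x) <- zip c xs, cj == j]

-- Bucket all elements in one traversal, preserving their original order.
bucket :: [Int] -> [x] -> Map.Map Int [x]
bucket c xs = foldr step Map.empty (zip c xs)
  where
    step (cj, x) = Map.insertWith (++) cj [x]

-- Pick out one bucket; a cluster with no points yields the empty list.
index :: Int -> Map.Map Int [x] -> [x]
index = Map.findWithDefault []
```

For every `j`, `filterByCluster c j xs` and `index j (bucket c xs)` produce the same list, but `bucket c xs` does not mention `j`, so it can be computed once and shared across all clusters.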

Now we can apply the fundamental rule of natural transformations: for a natural transformation `phi` and any function `f`, we have `f . phi = phi . fmap f`. Let `phi = index j` and `f = divide . fold (plus * inc)`; then we can push `f` to the other side of `phi`:

    tabulate (\j -> index j . fmap (divide . fold (plus * inc)) . bucket c $ x)

Now we can eliminate `tabulate` and `index`:

    fmap (divide . fold (plus * inc)) . bucket c $ x
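
These last two rewrites can be sanity-checked with a small sketch, again modelling `D` as a list indexed by `Int`. The names `perSlot` and `allAtOnce` are hypothetical; they stand for the program before and after the rewrites.

```haskell
tabulate :: Int -> (Int -> x) -> [x]
tabulate n f = map f [0 .. n - 1]

index :: Int -> [x] -> x
index j xs = xs !! j

-- Before: rebuild the structure slot by slot, applying f after each index.
perSlot :: (a -> b) -> [a] -> [b]
perSlot f d = tabulate (length d) (\j -> f (index j d))

-- After: naturality moves f inside as fmap f, and tabulate applied to
-- (\j -> index j d') is just d', so a single fmap remains.
allAtOnce :: (a -> b) -> [a] -> [b]
allAtOnce = fmap
```

On lists the two agree for any `f`. With `!!` the first form is also quadratic, which is a property of this list model and not of the original arrays.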

Finally, because we know how to efficiently implement `fmap (fold f) . bucket c` (as a `hashreduce`), we split up the `fmap` and join the fold and bucket:

    fmap divide . hashreduce (plus * inc) c $ x

And we have achieved our fully optimized program.
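
For comparison, here is one way the final program might look in ordinary Haskell. The `hashreduce` below is a hypothetical stand-in built on `Data.Map`, not Delite's operator, and it takes a binary combining function in place of the `plus * inc` algebra.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical stand-in for hashreduce: combine all values that share a key,
-- in a single traversal of the input.
hashreduce :: (v -> v -> v) -> [Int] -> [v] -> Map.Map Int v
hashreduce combine keys vals = Map.fromListWith (flip combine) (zip keys vals)

divide :: (Double, Int) -> Double
divide (s, n) = s / fromIntegral (if n == 0 then 1 else n)

-- Counterpart of: fmap divide . hashreduce (plus * inc) c $ x
centroids :: [Int] -> [Double] -> Map.Map Int Double
centroids c x = fmap divide (hashreduce add c [(xi, 1) | xi <- x])
  where
    add (s1, n1) (s2, n2) = (s1 + s2, n1 + n2)
```

One difference from the tabulated version: a cluster with no points is simply absent from the resulting map, so a caller that needs a dense array has to fill in those slots separately.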

All of this is research in progress, and there are lots of open questions which we have not resolved. Still, I hope this post has given you a flavor of the approach we are advocating. I am quite curious about your comments, from “That’s cool!” to “This was all done 20 years ago by X system.” Have at it!