Posted on September 4, 2020
by Huw Campbell

Using traversals and a simple fixpoint monad, we can efficiently crunch Icicle expressions to a minimal, efficient kernel. This speeds up downstream compilation, reduces generated code size, and makes our queries run faster.

Icicle is a high level language. The source language has mode polymorphism for error handling and runtime staging, type inference, function abstractions, and syntactic sugar to make describing features fast and easy. This is great, but its large surface area makes the source language a poor target for writing compiler optimisation passes.

Fortunately though, during compilation we pass through a simple, typed expression language called *Core*.

In this post, we’ll look at how we can apply high level optimisations in *Core* to significantly reduce our amount of code and remove redundant code branches.

If we write a traversal over the *children* of an expression tree, we can pattern match on key terms, and then trivially recurse elsewhere in the tree to find all the sub-expressions where this pattern match applies.

The `Plated` type class from the Haskell lens library allows us to scrap our boilerplate by recursively traversing syntax trees in this manner. Twan van Laarhoven, who perhaps not coïncidentally invented modern lenses, has a great blog post about traversals like `plate`.

Here’s the simplified definition:

```
class Plated a where
  plate :: forall f. Applicative f => (a -> f a) -> a -> f a
```
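To see why this single method is enough, here is a minimal sketch that re-declares a `Plated` class by hand (so it needs only `base`, not lens) for a hypothetical toy `Arith` type. Running `plate` in the `Const` applicative collects the immediate children, while running it in `Identity` rewrites them; `children` and `descend` here are our own illustrative definitions, not imports from lens.

```haskell
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Hand-rolled stand-in for lens's Plated class (an assumption for this sketch).
class Plated a where
  plate :: Applicative f => (a -> f a) -> a -> f a

-- A toy arithmetic tree, used only for illustration.
data Arith = Lit Int | Add Arith Arith | Neg Arith
  deriving (Eq, Show)

instance Plated Arith where
  plate f (Add x y) = Add <$> f x <*> f y
  plate f (Neg x)   = Neg <$> f x
  plate _ x@Lit {}  = pure x

-- With f = Const [a], each child is turned into a singleton list and the
-- Applicative instance of Const concatenates them.
children :: Plated a => a -> [a]
children = getConst . plate (\x -> Const [x])

-- With f = Identity, each immediate child is rewritten (one level only).
descend :: Plated a => (a -> a) -> a -> a
descend g = runIdentity . plate (Identity . g)
```

Note that neither function recurses: `plate` only ever touches the *children*, and recursion is left to whoever calls it.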

And here is our core expression language [1].

```
data Exp
  -- | Read a variable from environment
  = XVar !Name
  -- | A predefined primitive
  | XPrim !Prim
  -- | A constant simple value with its type
  | XValue !ValType !BaseValue
  -- | Application
  | XApp !Exp !Exp
  -- | Lambda abstraction
  | XLam !Name !ValType !Exp
  -- | Let binding
  | XLet !Name !Exp !Exp
  deriving (Eq, Show)

data Prim
  -- | Minimal things like numeric, string, and date primitives.
  = PrimMinimal !PrimMinimal
  -- | Fold and return type
  | PrimFold !PrimFold !ValType
  deriving (Eq, Show)

-- | Folds for destructing things (catamorphisms)
data PrimFold
  -- | If then else fold over bools
  = PrimFoldBool
  -- | Case expression over an Option
  | PrimFoldOption !ValType
  -- | Case expression over a Sum
  | PrimFoldSum !ValType !ValType
  deriving (Eq, Show)
```

Pretty standard stuff. Our primitive type contains a number of folds, which we compile case matches to. `PrimFoldBool`, for example, is what an `if _ then _ else _` from the source language becomes.

*It’s the combination of our folds and expressions which are most interesting when it comes to optimisation rules.*

Remembering that `plate` is a traversal over the *children* of an expression tree, here is our `Plated` instance:

```
instance Plated Exp where
  plate f (XApp x y) =
    XApp <$> f x <*> f y
  plate f (XLam n t x) =
    XLam n t <$> f x
  plate f (XLet n x y) =
    XLet n <$> f x <*> f y
  plate _ x@XVar {} = pure x
  plate _ x@XPrim {} = pure x
  plate _ x@XValue {} = pure x
```

This allows for extremely concise traversals. For example, we can easily count how many times a variable is used in an expression (which can be useful when deciding whether to inline it).

```
foldExp :: Monoid x => (Exp -> x) -> (Exp -> x)
foldExp = foldMapOf plate

varCount :: Name -> Exp -> Sum Int
varCount i (XVar j) | i == j = Sum 1
varCount i x = foldExp (varCount i) x
```
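For experimenting without lens, the following is a self-contained approximation of the instance and `varCount` above. It assumes a simplified `Exp` (string names, `Int` values, no types) and a hand-rolled `Plated` class; `foldExp` is written out the way `foldMapOf` behaves for a traversal, by running `plate` in the `Const` applicative. Like the version above, it does not account for shadowing.

```haskell
import Data.Functor.Const (Const (..))
import Data.Monoid (Sum (..))

-- Hand-rolled stand-in for lens's Plated class (an assumption for this sketch).
class Plated a where
  plate :: Applicative f => (a -> f a) -> a -> f a

-- Simplified Core: names are Strings, values are Ints, types are dropped.
data Exp
  = XVar String
  | XPrim String
  | XValue Int
  | XApp Exp Exp
  | XLam String Exp
  | XLet String Exp Exp
  deriving (Eq, Show)

instance Plated Exp where
  plate f (XApp x y)   = XApp <$> f x <*> f y
  plate f (XLam n x)   = XLam n <$> f x
  plate f (XLet n x y) = XLet n <$> f x <*> f y
  plate _ x            = pure x -- XVar, XPrim and XValue have no children

-- Equivalent to `foldMapOf plate`: combine g's results over the immediate children.
foldExp :: Monoid x => (Exp -> x) -> Exp -> x
foldExp g = getConst . plate (Const . g)

-- Count uses of a variable; recursion happens because g is varCount itself.
varCount :: String -> Exp -> Sum Int
varCount i (XVar j) | i == j = Sum 1
varCount i x = foldExp (varCount i) x
```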

The key combinator we’ll use during optimisation, however, is `transformM`, which traverses every element in the tree in a bottom-up manner.

`transformM :: (Monad m, Plated a) => (a -> m a) -> a -> m a`
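As a sketch of what `transformM` does with `plate`, the definition below rebuilds every child recursively before handing the rebuilt node to `f`. It mirrors the shape of lens’s `transformMOf`, but it is our own re-implementation on a hypothetical toy `Arith` type. Logging each visit with the tuple writer monad from `base` makes the bottom-up order visible.

```haskell
-- Hand-rolled stand-in for lens's Plated class (an assumption for this sketch).
class Plated a where
  plate :: Applicative f => (a -> f a) -> a -> f a

data Arith = Lit Int | Add Arith Arith
  deriving (Eq, Show)

instance Plated Arith where
  plate f (Add x y) = Add <$> f x <*> f y
  plate _ x         = pure x

-- Bottom-up: recursively rewrite every child first, then apply f to the
-- rebuilt node itself.
transformM :: (Monad m, Plated a) => (a -> m a) -> a -> m a
transformM f = go
  where
    go t = plate go t >>= f

-- Record each node as f sees it, using the ((,) [Arith]) writer monad from base.
visitOrder :: Arith -> [Arith]
visitOrder = fst . transformM (\t -> ([t], t))
```

For `Add (Lit 1) (Add (Lit 2) (Lit 3))`, the leaves are visited first and the root last, which is what lets a parent see already-simplified children.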

There’s also a pure counterpart, `transform`, but we’re going to need a fresh name supply and a way to indicate whether we’ve reached a fixpoint. We do this with a custom monad, `FixT`. For this post we’ll use an alternative definition, as `FixT` is isomorphic to `WriterT Any` (with `Any` from `Data.Monoid`): whenever we make any progress, we indicate this with a `progress` function instead of `pure` or `return`.

```
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = do
  tell (Any True)
  return a
```

Then, to reach a fixpoint, we repeatedly run a pass until it reports no more progress:

```
fixpoint :: Monad m => (a -> FixT m a) -> a -> m a
fixpoint f a = do
  (a', madeProgress) <- runWriterT (f a)
  case madeProgress of
    Any True  -> fixpoint f a'
    Any False -> return a'
```
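Here are `FixT`, `progress` and `fixpoint` together in one runnable sketch. A hypothetical toy pass, `stripZero`, stands in for a real optimisation: it removes one trailing zero per run and reports progress when it does. `runOnce` is likewise a helper invented for this example. Because progress is tracked with the `Any` monoid, a single `progress` anywhere in a pass marks the whole run as productive. Since `fixpoint` only stops once a run makes no progress, every pass must eventually stop reporting it.

```haskell
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = do
  tell (Any True)
  return a

fixpoint :: Monad m => (a -> FixT m a) -> a -> m a
fixpoint f a = do
  (a', madeProgress) <- runWriterT (f a)
  case madeProgress of
    Any True  -> fixpoint f a'
    Any False -> return a'

-- Hypothetical toy pass: strip one trailing zero per run, if there is one.
stripZero :: Monad m => Int -> FixT m Int
stripZero n
  | n /= 0 && n `mod` 10 == 0 = progress (n `div` 10)
  | otherwise = return n

-- Run a single pass in Identity and expose the progress flag as a Bool.
runOnce :: (a -> FixT Identity a) -> a -> (a, Bool)
runOnce f a = let (a', Any p) = runIdentity (runWriterT (f a)) in (a', p)
```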

That’s all the tools we need to write our optimisation passes.

The simplest optimisation we use is constant folding of primitive functions. If a primitive function is fully saturated with known values, we can run our core evaluator and replace the expression with the result.

It looks something like this (where we have helper function type signatures as comments).

```
-- takePrimApps :: Exp -> Maybe (Prim, [Exp])
-- takeValue    :: Exp -> Maybe Value
-- simpPrim     :: Prim -> [Value] -> Maybe Exp
constantFold :: Monad m => Exp -> FixT m Exp
constantFold unsimplified
  | Just (prim, args) <- takePrimApps unsimplified
  , Just valueArgs <- traverse takeValue args
  , Just simplified <- simpPrim prim valueArgs
  = progress simplified
  | otherwise
  = return unsimplified
```

Notice that we’re matching on the whole expression here. It’s extremely unlikely that the user’s whole program is a single primitive application, so we need to traverse the expression tree, looking for every sub-expression we might be able to constant fold. Fortunately, we have just the function:

```
simpExp :: Monad m => Exp -> FixT m Exp
simpExp = transformM constantFold
```
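A runnable sketch of both steps follows, under simplifying assumptions. It uses a cut-down `Exp` with `Int` values and two hypothetical primitives, `PrimAdd` and `PrimMul`. The bodies of `takePrimApps`, `takeValue` and `simpPrim` are illustrative, since only their signatures appear above, and `Int` stands in for `Value`. `Plated` and `transformM` are hand-rolled. It shows that `constantFold` alone leaves a nested expression untouched, while `transformM constantFold` folds it bottom-up in one pass.

```haskell
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = tell (Any True) >> return a

-- Hand-rolled stand-ins for lens's Plated and transformM.
class Plated a where
  plate :: Applicative f => (a -> f a) -> a -> f a

transformM :: (Monad m, Plated a) => (a -> m a) -> a -> m a
transformM f = go where go t = plate go t >>= f

-- Simplified Core (hypothetical): Int values, two numeric primitives, no types.
data Prim = PrimAdd | PrimMul
  deriving (Eq, Show)

data Exp
  = XVar String
  | XPrim Prim
  | XValue Int
  | XApp Exp Exp
  deriving (Eq, Show)

instance Plated Exp where
  plate f (XApp x y) = XApp <$> f x <*> f y
  plate _ x          = pure x

-- Peel an application spine off a primitive: ((p a) b) gives (p, [a, b]).
takePrimApps :: Exp -> Maybe (Prim, [Exp])
takePrimApps = go []
  where
    go args (XApp g x) = go (x : args) g
    go args (XPrim p)  = Just (p, args)
    go _ _             = Nothing

takeValue :: Exp -> Maybe Int
takeValue (XValue v) = Just v
takeValue _          = Nothing

-- Evaluate a saturated primitive; Nothing for partial applications.
simpPrim :: Prim -> [Int] -> Maybe Exp
simpPrim PrimAdd [a, b] = Just (XValue (a + b))
simpPrim PrimMul [a, b] = Just (XValue (a * b))
simpPrim _ _            = Nothing

constantFold :: Monad m => Exp -> FixT m Exp
constantFold unsimplified
  | Just (prim, args) <- takePrimApps unsimplified
  , Just valueArgs <- traverse takeValue args
  , Just simplified <- simpPrim prim valueArgs
  = progress simplified
  | otherwise
  = return unsimplified

simpExp :: Monad m => Exp -> FixT m Exp
simpExp = transformM constantFold

-- Helpers for the examples: build `p x y`, and run a pass reporting progress.
prim2 :: Prim -> Exp -> Exp -> Exp
prim2 p x y = XApp (XApp (XPrim p) x) y

runPass :: (Exp -> FixT Identity Exp) -> Exp -> (Exp, Bool)
runPass f e = let (e', Any p) = runIdentity (runWriterT (f e)) in (e', p)
```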

A slightly more advanced optimisation is known as the *case of known constructor* optimisation. Let’s look at an example:

```
Sum_fold#
(\a -> False)
(\b -> eq# b b_test)
(right# (get_location# val))
```

Here, `Sum_fold#` is a primitive which acts like a Haskell `case` expression over an `Either`. The arguments are: the lambda for the left case, the lambda for the right case, and then the scrutinised expression.

Even though we don’t have a value for `val` at compile time, we can see that the result is always going to take the right branch of the fold, and we can therefore rewrite this expression to:

```
let b = get_location# val
in eq# b b_test
```

The scrutinee in the example above, `right# (get_location# val)`, is always going to produce a `Right`, even though it can’t be constant folded, because `right#` is the known constructor of a `Right` value. We refer to this as an *irrefutable* expression.
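A minimal sketch of this rewrite follows. It assumes a simplified `Exp` in which `Sum_fold#`, `left#` and `right#` are the hypothetical primitives `PrimFoldSum`, `PrimLeft` and `PrimRight`, and both fold arguments are literal lambdas. It covers only the shape in the example above and is not the compiler’s implementation. `sumFold`, `call` and `runPass` are helpers invented for the example.

```haskell
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = tell (Any True) >> return a

-- Simplified Core (hypothetical): just enough to express the example.
data Prim = PrimFoldSum | PrimLeft | PrimRight | PrimOther String
  deriving (Eq, Show)

data Exp
  = XVar String
  | XPrim Prim
  | XValue Bool
  | XApp Exp Exp
  | XLam String Exp
  | XLet String Exp Exp
  deriving (Eq, Show)

-- Sum_fold# onLeft onRight scrutinee
sumFold :: Exp -> Exp -> Exp -> Exp
sumFold l r s = XApp (XApp (XApp (XPrim PrimFoldSum) l) r) s

-- Apply a named primitive to arguments, e.g. call "eq#" [x, y].
call :: String -> [Exp] -> Exp
call name = foldl XApp (XPrim (PrimOther name))

caseOfKnownConstructor :: Monad m => Exp -> FixT m Exp
caseOfKnownConstructor (XApp (XApp (XApp (XPrim PrimFoldSum) (XLam a onL)) (XLam b onR)) s)
  -- left# e always takes the left branch: bind the payload to its binder
  | XApp (XPrim PrimLeft) e <- s = progress (XLet a e onL)
  -- right# e always takes the right branch
  | XApp (XPrim PrimRight) e <- s = progress (XLet b e onR)
caseOfKnownConstructor x = return x

runPass :: (Exp -> FixT Identity Exp) -> Exp -> (Exp, Bool)
runPass f e = let (e', Any p) = runIdentity (runWriterT (f e)) in (e', p)
```

Applied to the example, this produces `let b = get_location# val in eq# b b_test`, while a fold over an unknown scrutinee is returned unchanged without reporting progress.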

A more challenging pattern, which appears pretty regularly due to inlining and the way our modal type system handles error conditions, looks something like this:

```
Sum_fold#
(\a -> False)
(\b -> eq# b b_test)
(Sum_fold#
(\err -> left# err)
(\val -> right# (get_location# val))
scrutinee)
```

The expression above is not optimal. It’s not a constant though, and the scrutinee isn’t guaranteed to force a single branch either, so we can’t constant fold this expression or use the case of known constructor optimisation… if we look really closely though, we can see that scrutinising a second time isn’t actually required at all.

We can justifiably rewrite this expression as:

```
Sum_fold#
(\err ->
let a = err
in False)
(\val ->
let b = get_location# val
in eq# b b_test)
scrutinee
```

The key insight here is that if we scrutinise the `scrutinee` and find a `Right` value, we’re always going to exercise the `Right` case of the outer fold; similarly, for a `Left` we know which side we’ll receive. Each branch of the inner fold is *irrefutably* going to produce either a `Right` value or a `Left` value, even though we don’t know what values they will hold.

We can therefore skip the outer case expression, in what we call the *Case of Irrefutable Case* optimisation [2].

Restating what we saw above: if we *know* which cases both branches of the inner expression will produce, we don’t actually need to check them again.
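The compiler’s version isn’t reproduced here, so the following is only a minimal sketch of the idea. It uses a simplified `Exp` with hypothetical `PrimFoldSum`, `PrimLeft` and `PrimRight` primitives and handles only `Sum_fold#` with literal lambdas. It does not see through `let`s and assumes every bound name is unique, so it does no renaming. If both inner branches land on the same side, it copies that outer branch into both, with no attempt to limit the duplication. `route`, `call` and `runPass` are helpers invented for this example.

```haskell
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = tell (Any True) >> return a

-- Simplified Core (hypothetical): just enough to express the example.
data Prim = PrimFoldSum | PrimLeft | PrimRight | PrimOther String
  deriving (Eq, Show)

data Exp
  = XVar String
  | XPrim Prim
  | XValue Bool
  | XApp Exp Exp
  | XLam String Exp
  | XLet String Exp Exp
  deriving (Eq, Show)

-- Sum_fold# onLeft onRight scrutinee
sumFold :: Exp -> Exp -> Exp -> Exp
sumFold l r s = XApp (XApp (XApp (XPrim PrimFoldSum) l) r) s

takeSumFold :: Exp -> Maybe (Exp, Exp, Exp)
takeSumFold (XApp (XApp (XApp (XPrim PrimFoldSum) l) r) s) = Just (l, r, s)
takeSumFold _ = Nothing

-- Apply a named primitive to arguments, e.g. call "eq#" [x, y].
call :: String -> [Exp] -> Exp
call name = foldl XApp (XPrim (PrimOther name))

-- If an inner branch body is irrefutably left# e or right# e, bind e to the
-- binder of the matching outer branch and continue with that branch's body.
route :: (String, Exp) -> (String, Exp) -> Exp -> Maybe Exp
route (a, onL) _ (XApp (XPrim PrimLeft) e)  = Just (XLet a e onL)
route _ (b, onR) (XApp (XPrim PrimRight) e) = Just (XLet b e onR)
route _ _ _ = Nothing

caseOfIrrefutableCase :: Monad m => Exp -> FixT m Exp
caseOfIrrefutableCase x
  | Just (XLam a onL, XLam b onR, inner) <- takeSumFold x
  , Just (XLam l lBody, XLam r rBody, s) <- takeSumFold inner
  , Just lBody' <- route (a, onL) (b, onR) lBody
  , Just rBody' <- route (a, onL) (b, onR) rBody
  = progress (sumFold (XLam l lBody') (XLam r rBody') s)
  | otherwise
  = return x

runPass :: (Exp -> FixT Identity Exp) -> Exp -> (Exp, Bool)
runPass f e = let (e', Any p) = runIdentity (runWriterT (f e)) in (e', p)
```

On the example above, this produces exactly the rewritten expression: the `scrutinee` is inspected once, and each outer branch body moves under the inner lambda that feeds it.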

This function as implemented in our compiler is actually a little hairy, as it also sees through let bindings, handles other types of folds, and takes care of shadowing and renaming, so I won’t reproduce it here. Crucially, though, it does not worry about recursion or tree traversal, and it reports progress to our fixpoint monad. We can therefore compose it with our constant folding and case of known constructor passes.

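```
import Control.Monad ((>=>))

simpExp :: Monad m => Exp -> FixT m Exp
simpExp =
  transformM transformations
  where
    -- Kleisli-compose the single-node passes; any of them may report progress.
    transformations =
      constantFold
        >=> caseOfKnownConstructor
        >=> caseOfIrrefutableCase
        >=> inline
```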

*This insight is what makes our optimisations efficient.* We have a little bundle of optimisation functions which act in this manner, each of which applies to a single node of the tree. By combining them here, at a single leaf or node, we can crunch individual leaves to their optimal state before repeating the process for their parents. We then only need to repeat the whole traversal a few times to reach a fixpoint. If we were to run each optimisation function independently from leaf to root, we’d need to traverse the whole expression many more times.

Our whole *Core* optimisation pipeline is essentially this:

```
crunch :: Exp -> Fresh Exp
crunch = fixpoint simpExp
```
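To see the pieces working together, here is a self-contained miniature of `crunch`. Everything in it is a simplified stand-in: `Identity` replaces the compiler’s `Fresh` monad, and `constantFold`, `caseOfKnownConstructor` and `inline` are toy versions written for this sketch (this `inline` only substitutes `let`s bound to values). On the example, the first pass turns the fold into `add# 40 2` only after constant folding has already run on that node. It therefore reports progress, and `fixpoint` runs a second pass to finish the job.

```haskell
import Control.Monad ((>=>))
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Data.Functor.Identity (Identity (..))
import Data.Monoid (Any (..))

type FixT m a = WriterT Any m a

progress :: Monad m => a -> FixT m a
progress a = tell (Any True) >> return a

fixpoint :: Monad m => (a -> FixT m a) -> a -> m a
fixpoint f a = do
  (a', madeProgress) <- runWriterT (f a)
  case madeProgress of
    Any True  -> fixpoint f a'
    Any False -> return a'

-- Hand-rolled stand-ins for lens's Plated and transformM.
class Plated a where
  plate :: Applicative f => (a -> f a) -> a -> f a

transformM :: (Monad m, Plated a) => (a -> m a) -> a -> m a
transformM f = go where go t = plate go t >>= f

-- Simplified Core (hypothetical): Int values and four primitives.
data Prim = PrimAdd | PrimFoldSum | PrimLeft | PrimRight
  deriving (Eq, Show)

data Exp
  = XVar String
  | XPrim Prim
  | XValue Int
  | XApp Exp Exp
  | XLam String Exp
  | XLet String Exp Exp
  deriving (Eq, Show)

instance Plated Exp where
  plate f (XApp x y)   = XApp <$> f x <*> f y
  plate f (XLam n x)   = XLam n <$> f x
  plate f (XLet n x y) = XLet n <$> f x <*> f y
  plate _ x            = pure x

sumFold :: Exp -> Exp -> Exp -> Exp
sumFold l r s = XApp (XApp (XApp (XPrim PrimFoldSum) l) r) s

-- Toy pass: fold `add# v1 v2` when both arguments are values.
constantFold :: Monad m => Exp -> FixT m Exp
constantFold (XApp (XApp (XPrim PrimAdd) (XValue a)) (XValue b)) = progress (XValue (a + b))
constantFold x = return x

-- Toy pass: a fold over left#/right# becomes a let of the payload.
caseOfKnownConstructor :: Monad m => Exp -> FixT m Exp
caseOfKnownConstructor (XApp (XApp (XApp (XPrim PrimFoldSum) (XLam a l)) (XLam b r)) s)
  | XApp (XPrim PrimLeft) e <- s = progress (XLet a e l)
  | XApp (XPrim PrimRight) e <- s = progress (XLet b e r)
caseOfKnownConstructor x = return x

-- Toy pass: substitute a let bound to a value into its body, stopping where
-- the name is rebound. Values are closed, so nothing can be captured.
inline :: Monad m => Exp -> FixT m Exp
inline (XLet n v@XValue {} body) = progress (subst body)
  where
    subst (XVar m) | m == n = v
    subst e@(XLam m _) | m == n = e
    subst (XLet m x y) | m == n = XLet m (subst x) y
    subst e = runIdentity (plate (Identity . subst) e)
inline x = return x

simpExp :: Monad m => Exp -> FixT m Exp
simpExp = transformM (constantFold >=> caseOfKnownConstructor >=> inline)

-- Identity stands in for the compiler's Fresh name-supply monad.
crunch :: Exp -> Exp
crunch = runIdentity . fixpoint simpExp

runPass :: (Exp -> FixT Identity Exp) -> Exp -> (Exp, Bool)
runPass f e = let (e', Any p) = runIdentity (runWriterT (f e)) in (e', p)
```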

And using it, we reduced code sizes by up to 70%, with a commensurate reduction in downstream compilation time and a healthy boost to our runtime performance.

[1] This is actually slightly simplified, in that I’ve specialised or removed unnecessary type parameters. All of the constructors of `Exp` are present, while some folds and primitives are elided.

[2] GHC, being a highly optimising compiler, has a similar optimisation called the *Case of Case* optimisation. The difference is that in GHC, the outer case expression is duplicated in its entirety. In the cases where we would do *our* optimisation, a case of known constructor would then apply and reach the same result. Practically speaking, in GHC, some heuristics optimistically decide when to perform a *Case of Case*, while in Icicle, a *Case of Irrefutable Case* is always a good idea.