
Scrap your boilerplate, with recursive continuations [1]

The core trick I want to introduce is simple: adding some knot-tying to continuations lets us add a recurse operator, which is really useful when writing generic traversals. Weirdly, the resulting continuation-passing style closely mirrors the vtables that implement OOP inheritance. This blog post focuses on a concrete implementation based on scrap your boilerplate, though the idea would also work for KURE- or GHC.Generics-based traversals. The code for this post can be found in this gist. I have some minor open design questions, so it hasn’t quite made it into a library yet.

What?

Our end goal is to write queries and transformations over mutually recursive types, while only requiring a deriving Data on each type:

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref Var
   deriving (Eq, Ord, Show, Data)
 
type Var = String
data Lang
   = Let { var :: Var, expr :: Expr, body :: Lang }
   | If { ifHead :: Expr, trueBranch :: Lang, falseBranch :: Lang } 
   | Return Expr
   deriving (Eq, Ord, Show, Data)
   
test :: Lang
test 
  = If 
    (Plus (Lit 1) 
          (Minus (Ref "x") (Ref "x"))) 
    (Return (Ref "a"))
    (Return (Ref "b"))

bottomUp :: Monad m => Trans m
bottomUp =
   recurse >>> 
     (   tryTrans_ @Expr \case
              Minus x y
                | y == x -> Just (Lit 0)
              Plus (Lit a) (Lit b) -> Just (Lit (a + b))
              _ -> Nothing
      ||| tryTrans_ @Lang \case
              If (Lit i) a b -> Just (if i == 0 then b else a)
              _ -> Nothing
   )
   
>>> run bottomUp test
Return (Ref "a")

We will transform Minus (Ref "x") (Ref "x") into Lit 0, then Plus (Lit 1) (Lit 0) into Lit 1, and finally the entire if-statement into Return (Ref "a"). Note that the transformation doesn’t cover all constructors. The default base case is the identity transformation, and recurse automatically targets all sub-fields.

Here the datatypes are fairly small, so a manual implementation would be easy. Even small real languages are much larger, though, and GHC’s typechecking AST has over a hundred constructors! Imagine the transformation above with an extra 98 cases whose only job is to recurse into child terms. No wonder Haskell has so many approaches to generic programming.
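For comparison, a hand-written version of bottomUp for just these two types might look like the sketch below. The definitions are restated (without Data) so it compiles on its own, and simplifyExpr/simplifyLang are illustrative names. Every constructor needs a descend case even when there is nothing to rewrite, and that is the part that grows with the size of the AST.

```haskell
type Var = String

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref Var
  deriving (Eq, Show)

data Lang
  = Let Var Expr Lang
  | If Expr Lang Lang
  | Return Expr
  deriving (Eq, Show)

test :: Lang
test = If (Plus (Lit 1) (Minus (Ref "x") (Ref "x"))) (Return (Ref "a")) (Return (Ref "b"))

-- Bottom-up: first rebuild the children (descend), then rewrite the node.
simplifyExpr :: Expr -> Expr
simplifyExpr e = rewrite (descend e)
  where
    descend (Plus a b) = Plus (simplifyExpr a) (simplifyExpr b)
    descend (Minus a b) = Minus (simplifyExpr a) (simplifyExpr b)
    descend other = other -- Lit and Ref have no Expr children
    rewrite (Minus x y) | x == y = Lit 0
    rewrite (Plus (Lit a) (Lit b)) = Lit (a + b)
    rewrite other = other

simplifyLang :: Lang -> Lang
simplifyLang l = rewrite (descend l)
  where
    -- Every constructor needs a case, even when nothing interesting happens.
    descend (Let v x body) = Let v (simplifyExpr x) (simplifyLang body)
    descend (If c t f) = If (simplifyExpr c) (simplifyLang t) (simplifyLang f)
    descend (Return x) = Return (simplifyExpr x)
    rewrite (If (Lit i) t f) = if i == 0 then f else t
    rewrite other = other
```

Here simplifyLang test evaluates to Return (Ref "a"), the same result as run bottomUp test.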

Because we abstract over the monad, queries are just a special kind of transformation with a MonadWriter constraint:

-- | Collect all references which are used but not bound in the code block
freeVarsQ :: (MonadWriter (Set.Set Var) m) => Trans m
freeVarsQ =
     tryQuery_ @Expr \case
       Ref v -> Just (Set.singleton v)
       _ -> Nothing
 ||| tryQuery @Lang (\rec -> \case
       Let {var, expr, body} -> Just (rec expr <> Set.delete var (rec body))
       _ -> Nothing)
 ||| recurse

>>> runQ freeVarsQ test
fromList ["a","b","x"]
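A minimal sketch of the "query as transformation" idea, assuming only base's Data.Data.gmapM and the Writer monad from the transformers package instead of the post's Trans and runQ. The names everywhereM', collectRefs and allRefs are hypothetical. There is no Let case, so unlike freeVarsQ this collects every referenced variable, bound or not; each node is returned unchanged and the results travel through tell.

```haskell
{-# LANGUAGE DeriveDataTypeable, RankNTypes, ScopedTypeVariables #-}
import Control.Monad.Trans.Writer.Strict (Writer, execWriter, tell)
import Data.Data
import qualified Data.Set as Set

type Var = String

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref Var
  deriving (Eq, Show, Data)

data Lang = Let Var Expr Lang | If Expr Lang Lang | Return Expr
  deriving (Eq, Show, Data)

test :: Lang
test = If (Plus (Lit 1) (Minus (Ref "x") (Ref "x"))) (Return (Ref "a")) (Return (Ref "b"))

-- Hypothetical bottom-up monadic traversal built on base's gmapM.
everywhereM' :: forall m. Monad m => (forall x. Data x => x -> m x) -> (forall x. Data x => x -> m x)
everywhereM' f = go
  where
    go :: Data x => x -> m x
    go x = gmapM go x >>= f

-- The "query": return every node unchanged, but tell the Ref names.
collectRefs :: Data x => x -> Writer (Set.Set Var) x
collectRefs x = case (cast x :: Maybe Expr) of
  Just (Ref v) -> tell (Set.singleton v) >> pure x
  _ -> pure x

allRefs :: Data a => a -> Set.Set Var
allRefs x = execWriter (everywhereM' collectRefs x)
```

allRefs test gives fromList ["a","b","x"], while allRefs (Let "y" (Lit 1) (Return (Ref "y"))) still reports "y"; dropping that bound name is what the tryQuery @Lang case in freeVarsQ is for.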

Again, we only mention the constructors that matter. Because we use generic programming, runQ freeVarsQ :: Data a => a -> Set.Set Var works on any type implementing Data. This makes it easy to extend types, add extra data such as source locations, or transform collections.

Why

So we’ll write composable transformations. How is this approach different from existing approaches such as scrap-your-boilerplate?
There are three big points:

How?

To implement this API, we will use the scrap your boilerplate approach to generic programming. The type signatures can be a bit confusing, and we are not going to go in depth. See here for a full-fledged introduction. Data.Data is also notoriously slow, but we will borrow a neat optimization from the lens library.

The scrap your boilerplate approach is based on two key pieces:

-- We can use Typeable to cast types at runtime.
-- Internally, this compares the TypeReps and performs an unsafeCoerce if they match.
tryCast :: forall a b. (Typeable a, Typeable b) => (a -> String) -> b -> Maybe String
tryCast f x = case eqT @a @b of
   Just Refl -> Just (f x) -- Here, type @a@ equals @b@
   Nothing -> Nothing -- Here they don't, @f x@ would be a type error!

-- We can use the `gfoldl` method in Data.Data to visit child terms.
-- (This Applicative variant shadows Data.Data.gmapM, which requires Monad.)
gmapM :: forall m a. (Data a, Applicative m) => (forall d. Data d => d -> m d) -> a -> m a
gmapM visitChild = gfoldl k pure
  where
    k :: Data d => m (d -> b) -> d -> m b
    k holeWithoutField focusedField = holeWithoutField <*> visitChild focusedField
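Both pieces can be tried in isolation. The sketch below restates tryCast and renames the Applicative traversal to gtraverse so it does not clash with Data.Data's own Monad-based gmapM; childTypes is a hypothetical helper that uses the Const applicative to list the types of a value's immediate children.

```haskell
{-# LANGUAGE GADTs, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Data.Data
import Data.Functor.Const (Const (..))

tryCast :: forall a b. (Typeable a, Typeable b) => (a -> String) -> b -> Maybe String
tryCast f x = case eqT @a @b of
  Just Refl -> Just (f x) -- a ~ b is known here, so f x typechecks
  Nothing -> Nothing

-- Same definition as the post's gmapM, under a non-clashing name.
gtraverse :: forall f a. (Data a, Applicative f) => (forall d. Data d => d -> f d) -> a -> f a
gtraverse visitChild = gfoldl k pure
  where
    k :: Data d => f (d -> b) -> d -> f b
    k hole field = hole <*> visitChild field

-- Hypothetical helper: Const ignores the rebuilt value and only
-- accumulates the shown TypeRep of each immediate child.
childTypes :: Data a => a -> [String]
childTypes = getConst . gtraverse (\d -> Const [show (typeOf d)])
```

For example, childTypes (1 :: Int, "hi") gives ["Int","[Char]"], and tryCast @Int show True is Nothing.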

Data.Typeable is quite magic: GHC derives it automatically for all types, and we do not even get the opportunity to write instances by hand! For Data.Data we require the -XDeriveDataTypeable extension and an explicit deriving Data. There are good reasons to write these instances manually: for GADTs we usually have to, and even for normal types we may want to ban some constructors from being generated, or some fields from being visited.

The gfoldl implementation has quite a confusing type signature. The idea is that we use a z function to wrap the constructor, and repeatedly apply a k function to visit each argument. The k function can use the Data constraint to recurse further, and use the Typeable super-class to branch on the current type.

instance Data Lang where
    gfoldl :: (forall d b. Data d => m (d -> b) -> d -> m b)
           -> (forall g. g -> m g)
           -> Lang
           -> m Lang
    gfoldl k z (Let a b c) = z Let `k` a `k` b `k` c
    gfoldl k z (If a b c) = z If `k` a `k` b `k` c
    ...
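To see the roles of z and k concretely, here is a small sketch that instantiates gfoldl with a hypothetical phantom functor Count: z (start) begins the count at zero for the bare constructor, and k (step) adds one per field. It relies only on a derived Data instance; countFields is an illustrative name.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref String
  deriving (Show, Data)

-- A phantom "functor": it ignores the value being rebuilt and only
-- carries a running count of the fields that k has visited.
newtype Count a = Count Int

countFields :: Data a => a -> Int
countFields x = case gfoldl step start x of Count n -> n
  where
    -- z: wrap the bare constructor; no fields seen yet
    start :: g -> Count g
    start _ = Count 0
    -- k: called once per field, left to right
    step :: Count (d -> b) -> d -> Count b
    step (Count n) _ = Count (n + 1)
```

countFields (Plus (Lit 1) (Lit 2)) is 2 and countFields (Ref "x") is 1, while countFields (3 :: Int) is 0 because Int's instance applies z and never calls k.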

Data.Data makes it easy to throw all transformations into one simple shape:

type Trans1 m = forall x. Data x => x -> m x

tryTrans1 :: forall a m. (Typeable a, Monad m) => (a -> m a) -> Trans1 m
tryTrans1 f (x :: tx) = case eqT @a @tx of
   Just Refl -> f x -- apply the transformation
   Nothing -> pure x -- keep the old value here
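As a sketch of how such a Trans1 might be used on its own, the block below applies tryTrans1 to every node bottom-up. bottomUpM, simplifyStep and simplify are hypothetical helpers; the traversal uses base's Data.Data.gmapM and, unlike the post's version, has no way to skip irrelevant sub-terms or to let a rule decide when to recurse.

```haskell
{-# LANGUAGE DeriveDataTypeable, GADTs, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Data.Data
import Data.Functor.Identity (Identity (..))

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref String
  deriving (Eq, Show, Data)

type Trans1 m = forall x. Data x => x -> m x

tryTrans1 :: forall a m. (Typeable a, Monad m) => (a -> m a) -> Trans1 m
tryTrans1 f (x :: tx) = case eqT @a @tx of
  Just Refl -> f x -- x has type a here, so f applies
  Nothing -> pure x -- any other type passes through unchanged

-- Hypothetical helper: apply a Trans1 to every node, children first.
bottomUpM :: forall m. Monad m => Trans1 m -> Trans1 m
bottomUpM t = go
  where
    go :: Trans1 m
    go x = gmapM go x >>= t

simplifyStep :: Expr -> Expr
simplifyStep (Minus x y) | x == y = Lit 0
simplifyStep (Plus (Lit a) (Lit b)) = Lit (a + b)
simplifyStep e = e

simplify :: Data a => a -> a
simplify x = runIdentity (bottomUpM (tryTrans1 (pure . simplifyStep)) x)
```

Because the dispatch happens at runtime, simplify works on any Data type, e.g. a list of Exprs as well as a single Expr.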

Is SYB enough?

Do we need anything on top of Data.Data? The recursive CPS style does not add expressiveness, so strictly speaking, no. We could rewrite the freeVarsQ example using plain (if awkward) SYB:

import Data.Data (Data, gmapQ)
import Data.Generics (extQ) -- from the syb package
import qualified Data.Set as Set

freeVarsSYB :: Data a => a -> Set.Set Var
freeVarsSYB = (mconcat . gmapQ freeVarsSYB) `extQ` freeVarsExpr `extQ` freeVarsLang

freeVarsExpr :: Expr -> Set.Set Var
freeVarsExpr (Ref v) = Set.singleton v
freeVarsExpr a = mconcat (gmapQ freeVarsSYB a)

freeVarsLang :: Lang -> Set.Set Var
freeVarsLang (Let v expr body) = freeVarsExpr expr <> Set.delete v (freeVarsLang body)
freeVarsLang a = mconcat (gmapQ freeVarsSYB a)

In my experience, SYB with explicitly recursive functions invites subtle bugs or infinite loops when we miss a case. We will use a CPS transformation to factor the recursive mconcat (gmapQ freeVarsSYB a) out into recurse.

Additionally, the gmapQ call doesn’t know what we can match so it must visit irrelevant sub-terms and take much more time. Of course we could build a smarter gmapQ and pass the possible targets (Set.fromList [typeRep @Lang, typeRep @Expr]) everywhere, but manually passing this set would be another opportunity for hard-to-debug mistakes.
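To make the cost concrete: with type Var = String, a naive gmapQ traversal also walks every list cell and Char of every variable name. The hypothetical visits function below counts the nodes such a query touches.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}
import Data.Data

data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref String
  deriving (Show, Data)

-- Hypothetical helper: count every node a naive gmapQ query visits,
-- including the Int inside Lit and each (:) cell and Char of a name.
visits :: Data a => a -> Int
visits x = 1 + sum (gmapQ visits x)
```

visits (Ref "abc") is 8, although only the Ref node itself could ever match freeVarsExpr.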

What’s that about vtables?

When a transformation recurses, succeeds, or fails, we have to decide what to do next. To keep things composable, we want to delay this decision.

The classic approach to abstract over some implementation is vtables, aka records of functions. Haskell can do vtables: We could literally pass around records of functions, or make GHC do the legwork by using type-classes. At a surface level these require very different styles:

Type-classes optimize better, but would require a lot of type-level programming to be as expressive. We will just use functions, ending up with a slightly weird continuation-passing-style.
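As an illustration of records of functions as vtables (not code from the post), here is a sketch with a cut-down Expr. Each method calls the others through self and fix ties the knot, so overriding printLit in hexPrinter also changes how the inherited printExpr prints literals. Printer, basePrinter and hexPrinter are hypothetical names; this late binding through a knot is the same shape the Ctx record uses for onRecurse.

```haskell
import Data.Function (fix)
import Numeric (showHex)

-- Cut-down Expr, just enough to print something.
data Expr = Plus Expr Expr | Lit Int

-- The "vtable": one field per method.
data Printer = Printer
  { printExpr :: Expr -> String
  , printLit :: Int -> String
  }

-- The "base class". Methods call each other through self, never directly.
basePrinter :: Printer -> Printer
basePrinter self = Printer
  { printExpr = \e -> case e of
      Lit n -> printLit self n
      Plus a b -> printExpr self a ++ " + " ++ printExpr self b
  , printLit = show
  }

-- The "subclass": inherit every method, override printLit.
hexPrinter :: Printer -> Printer
hexPrinter self = (basePrinter self) {printLit = \n -> "0x" ++ showHex n ""}
```

printExpr (fix basePrinter) (Plus (Lit 10) (Lit 255)) gives "10 + 255", while the same call with fix hexPrinter gives "0xa + 0xff".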

We stash every possible continuation into a struct for readability:

type Trans1 m = (forall x. Data x => x -> m x)
-- | VTable for our traversal
data Ctx m = Ctx {
  -- | Transformation when case matched
  onSuccess :: Trans1 m,
  -- | Transformation when case fails
  onFailure :: Trans1 m,
  -- | Top-level transformation for recursion on child-terms
  onRecurse :: Trans1 m
  }

A transformation now looks like Ctx m -> Trans1 m. Typically, CPS in Haskell is curried, but Trans1 m -> Trans1 m -> Trans1 m -> Trans1 m makes compiler errors a living nightmare, so let’s not go there. As an optimization, we also track some metadata such as the relevant types, so we can skip any sub-expressions whose types the transformation can never touch. I brazenly stole the idea and implementation from the lens library, only changing a couple of lines so that it works for multiple target types.

data Trans m = T {
    relevant :: !(S.HashSet TypeRep),
    toplevelRecursion :: Bool,
    withCtx :: Ctx m -> Trans1 m
}

This allows us to chain Trans m types sequentially or alternatively:

-- | Alternative composition of transformations
-- In @a ||| b@, we only run @b@ if @a@ fails.
(|||) :: forall m. Monad m => Trans m -> Trans m -> Trans m
l ||| r = T relevantTypes containsRecursion trans
  where
    relevantTypes = relevant l `S.union` relevant r
    containsRecursion = toplevelRecursion l || toplevelRecursion r
    trans :: Ctx m -> Trans1 m
    trans ctx = withCtx l (ctx { onFailure = withCtx r ctx })
infixl 1 |||

-- | Sequential composition of transformations
-- In @a >>> b@, we only run @b@ if @a@ succeeds.
(>>>) :: forall m. Monad m => Trans m -> Trans m -> Trans m
l >>> r = T relevantTypes containsRecursion trans
  where
    relevantTypes = relevant l `S.union` relevant r
    containsRecursion = toplevelRecursion l
    trans :: Ctx m -> Trans1 m
    trans ctx = withCtx l ctx{ onSuccess = withCtx r ctx }
infixl 1 >>>

-- Apply transformation to each child-term, always succeeds
recurse :: Monad m => Trans m
recurse = T mempty True $ \Ctx{..} -> onSuccess <=< gmapM onRecurse

And wrap simple transformation functions into Trans:

tryTrans :: forall a m. (Monad m, Data a) => (Trans1 m -> a -> Maybe a) -> Trans m
tryTrans f = T relevantTypes containsRecursion transformation
  where
    relevantTypes = S.singleton (typeRep @a)
    containsRecursion = False
    transformation :: Ctx m -> Trans1 m
    transformation Ctx{..} (a :: a') = case eqT @a @a' of
      Just Refl -> case f onRecurse a of
           Just a' -> onSuccess a'
           Nothing -> onFailure a
      Nothing -> onFailure a

To run traversals, we have to tie the context knot. This is where all the recursion lives: f uses ctx, and ctx refers to f. Finally, we use the collected metadata to drive the hitTest function from lens.

runT :: forall m a. (Monad m, Data a) => Trans m -> a -> m a
runT trans a0 = f a0
  where
    Oracle oracle = hitTest a0 (relevant trans)
    ctx = Ctx { onSuccess = pure, onFailure = pure, onRecurse = f }
    f :: forall x. Data x => x -> m x
    f x = case oracle x of
      -- When the type is relevant, apply the transformation
      Hit _ -> withCtx trans ctx x
      -- If the type contains relevant types
      -- and the transformation would `recurse`, recurse
      Follow 
        | toplevelRecursion trans -> gmapM f x
      -- otherwise short-circuit
      _ -> pure x
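Putting the pieces together, here is a compact, self-contained sketch of the combinators that reproduces the bottomUp example from the start of the post. It simplifies Trans to just the withCtx function: there is no relevant-type set, no toplevelRecursion flag and no hitTest, so every node is visited. Ctx values are built with the constructor rather than by record update, and runPure is a hypothetical stand-in for the post's run, specialised to Identity.

```haskell
{-# LANGUAGE DeriveDataTypeable, GADTs, LambdaCase, RankNTypes,
             ScopedTypeVariables, TypeApplications #-}
import Control.Monad ((<=<))
import Data.Data
import Data.Functor.Identity (Identity (..))

type Trans1 m = forall x. Data x => x -> m x

-- The "vtable" of continuations.
data Ctx m = Ctx
  { onSuccess :: Trans1 m
  , onFailure :: Trans1 m
  , onRecurse :: Trans1 m
  }

-- Simplified Trans: no metadata, so no pruning of irrelevant sub-terms.
newtype Trans m = T {withCtx :: Ctx m -> Trans1 m}

infixl 1 |||, >>>

-- l ||| r: run r only if l fails.
(|||) :: Trans m -> Trans m -> Trans m
l ||| r = T (\ctx -> withCtx l (Ctx
  { onSuccess = onSuccess ctx
  , onFailure = withCtx r ctx
  , onRecurse = onRecurse ctx }))

-- l >>> r: run r only if l succeeds.
(>>>) :: Trans m -> Trans m -> Trans m
l >>> r = T (\ctx -> withCtx l (Ctx
  { onSuccess = withCtx r ctx
  , onFailure = onFailure ctx
  , onRecurse = onRecurse ctx }))

-- Visit every child with the top-level traversal, then succeed.
recurse :: Monad m => Trans m
recurse = T (\ctx -> onSuccess ctx <=< gmapM (onRecurse ctx))

-- Match values of type a; Just means success, Nothing means failure.
tryTrans_ :: forall a m. Typeable a => (a -> Maybe a) -> Trans m
tryTrans_ f = T go
  where
    go :: Ctx m -> Trans1 m
    go ctx (x :: tx) = case eqT @a @tx of
      Just Refl | Just x' <- f x -> onSuccess ctx x'
      _ -> onFailure ctx x

-- Tie the knot: onRecurse points back at the traversal being defined.
run :: forall m a. (Monad m, Data a) => Trans m -> a -> m a
run t = top
  where
    top :: Trans1 m
    top = withCtx t ctx
    ctx :: Ctx m
    ctx = Ctx {onSuccess = pure, onFailure = pure, onRecurse = top}

runPure :: Data a => Trans Identity -> a -> a
runPure t = runIdentity . run t

type Var = String
data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref Var
  deriving (Eq, Show, Data)
data Lang = Let Var Expr Lang | If Expr Lang Lang | Return Expr
  deriving (Eq, Show, Data)

bottomUp :: Monad m => Trans m
bottomUp =
  recurse >>>
    (   tryTrans_ @Expr (\case
          Minus x y | x == y -> Just (Lit 0)
          Plus (Lit a) (Lit b) -> Just (Lit (a + b))
          _ -> Nothing)
    ||| tryTrans_ @Lang (\case
          If (Lit i) a b -> Just (if i == 0 then b else a)
          _ -> Nothing))

test :: Lang
test = If (Plus (Lit 1) (Minus (Ref "x") (Ref "x"))) (Return (Ref "a")) (Return (Ref "b"))
```

With these definitions, runPure bottomUp test evaluates to Return (Ref "a"), and the same bottomUp also runs directly on an Expr.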

Conclusion

I accidentally re-discovered this pattern several times and have found it incredibly useful for quickly prototyping transformations. I also have not seen it in the wild before and figured others might find it useful.

However, I have not done much benchmarking. The HitTest optimization seems to help, but a thorough comparison against handwritten traversals would be nicer. I also have two open design questions:

If anyone has ideas, or has experience with similar patterns, I’m all ears!

Also, there is a reason fast OOP languages tend to run on a JIT compiler: specializing the indirect calls away could make this as fast as hand-written code. The approach from the “Optimizing SYB is easy!” paper may work, though last I checked, Hermit (the GHC transformation DSL it was written in) was stuck at the GHC 7.4 API.

Thanks for reading!

Postscript: A larger example

I thought I’d add a bigger example to show how this scales to larger transformations.

Occasionally, we may want to relabel variables so that every name is globally unique. This can be useful for pretty-printing or to simplify a code analysis. Our logic will be the following:

Using a state monad, we can turn this into two recursive blocks which we execute in sequence:

locally :: (MonadState s m) => m a -> m a
locally m = do
  old <- get
  a <- m
  put old
  pure a

compactVarsT :: (MonadVar m, MonadState (M.Map Var Var) m) => Trans m
compactVarsT
  =   block (refreshGlobalVar ||| recurse)
  >>> loggingM "Global var mappings: " (gets M.toList)
  >>> block (refreshLocalBinder ||| lookupRenamedVar ||| recurse)
 where
  refreshGlobalVar = transM_ \(Source v) -> Source <$> refreshVar v
  
  refreshLocalBinder
    =  tryTransM @Lang (\rec -> \case
         Bind expr var body -> Just $ do
              expr <- rec expr
              locally $ do
                  var <- refreshVar var
                  body <- rec body
                  pure (Bind expr var body)
         AsyncBind binders body -> Just $ do
              -- async binders are independent;
              -- rewrite the bound expressions before the new names are in scope!
              binders <- traverseOf (each . _2) rec binders
              locally $ do
                  binders <- traverseOf (each . _1) refreshVar binders
                  body <- rec body
                  pure (AsyncBind binders body)
         _ -> Nothing)
    ||| tryTransM @OpLang (\rec -> \case
         Let var expr body -> Just $ do
              expr <- rec expr
              locally $ do
                  var <- refreshVar var
                  body <- rec body
                  pure (Let var expr body)
         _ -> Nothing)
         
  lookupRenamedVar
     = tryTransM_ @Lang \case
         LRef r -> Just $ gets (LRef . (M.! r))
         _ -> Nothing
     ||| tryTransM_ @Expr \case
          Ref r -> Just $ gets (Ref . (M.! r))
          _ -> Nothing
          
  refreshVar v = do
     gets (M.!? v) >>= \case
       Nothing -> do
         v' <- genVar (name v)
         modify (M.insert v v')
         pure v'
       Just v' -> pure v'
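The code above depends on a larger language (Bind, AsyncBind, OpLang) and on helpers such as block, MonadVar and genVar that are not shown. As a hand-written sketch of the same scoping idea for this post's small Lang, the block below uses StateT from transformers and keeps the fresh-name counter in a separate inner State, so that locally restores the scope but not the counter and names stay globally unique. Rename, fresh, renameExpr, renameLang and runRename are hypothetical names, and building names as v ++ show n could clash with an existing free variable such as "x0", so treat it as illustrative only.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict
import qualified Data.Map.Strict as M

type Var = String
data Expr = Plus Expr Expr | Minus Expr Expr | Lit Int | Ref Var
  deriving (Eq, Show)
data Lang = Let Var Expr Lang | If Expr Lang Lang | Return Expr
  deriving (Eq, Show)

-- Outer state: current renaming of in-scope variables.
-- Inner state: a counter for fresh names (a stand-in for genVar).
type Rename = StateT (M.Map Var Var) (State Int)

-- Run an action, then restore the outer state (the scope) it started with.
locally :: Monad m => StateT s m a -> StateT s m a
locally m = do
  old <- get
  a <- m
  put old
  pure a

-- Invent a new name for v and bring it into scope.
fresh :: Var -> Rename Var
fresh v = do
  n <- lift get
  lift (put (n + 1))
  let v' = v ++ show n
  modify (M.insert v v')
  pure v'

renameExpr :: Expr -> Rename Expr
renameExpr (Ref v) = gets (Ref . M.findWithDefault v v) -- free variables keep their name
renameExpr (Plus a b) = Plus <$> renameExpr a <*> renameExpr b
renameExpr (Minus a b) = Minus <$> renameExpr a <*> renameExpr b
renameExpr (Lit n) = pure (Lit n)

renameLang :: Lang -> Rename Lang
renameLang (Let v x body) = do
  x' <- renameExpr x -- the bound expression sees the outer scope
  locally $ do
    v' <- fresh v
    body' <- renameLang body
    pure (Let v' x' body')
renameLang (If c t f) = If <$> renameExpr c <*> renameLang t <*> renameLang f
renameLang (Return x) = Return <$> renameExpr x

runRename :: Lang -> Lang
runRename l = evalState (evalStateT (renameLang l) M.empty) 0
```

For example, runRename (Let "x" (Lit 1) (Let "x" (Ref "x") (Return (Ref "x")))) gives Let "x0" (Lit 1) (Let "x1" (Ref "x0") (Return (Ref "x1"))).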

1. I cut out a section about OOP vtables because it wasn’t terribly relevant to SYB and made people think this was an OOP blog post.