Abstracting over Branch and Bound
Introduction
Wordle is a Mastermind-like browser game where you guess five-letter words. The game answers which letters are correct, in the wrong position, or flat out wrong. You have 6 guesses to win. This is a great balance because most people win most of the time, but I still feel like an absolute genius whenever I do.
Here is an interesting meta-puzzle: What is the optimal Wordle strategy? We can figure this out with Branch and Bound (BnB). BnB describes a class of algorithms I hadn’t heard of for the longest time, probably because the literature tends to be really vague.
Maybe I should mention that my final solution ended up being in Rust, because it turned out immensely easier to write memory-efficient code there: the first Haskell version took 20GB+ of memory, while the Rust version using bitsets used 50MB. But while reimplementing the code I really wanted to abstract over the pruning search, and I came up with an (albeit hacky) approach - but it only works in Haskell. This post is a brief sketch of the abstraction.
All descriptions of BnB I could find concur on the following steps:
- We have some representation for subproblems
- We can split unsolved subproblems into smaller subproblems
- We can prune some subproblems which won’t lead to an optimal solution
Admittedly that’s more concrete than “we use computers and algorithms to find a solution”, but not by much. So, I will focus on one specific type of pruning which can be found in most BnB implementations:
- We try to give subproblems a minimum and maximum cost
- If we have two subproblems a and b, when max_cost(a) < min_cost(b) then we can prune b because we know it won't lead to an optimal solution
Much more concrete, and quite close to alpha-beta pruning for game trees. For solved subproblems the max_cost and actual cost coincide, so we don't have to bother writing a max_cost heuristic - we can use the smallest finished solution found so far. All intuitions about A* heuristics work for min_cost because it's the same idea; it turns out A* can be seen as a BnB algorithm. Beyond these basics we could use search strategies other than depth-first, use smarter goal expansion, and add further filtering like cutting planes or dominance pruning. For today let's stick to the basics, though.
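Spelled out as a plain recursive function, without any abstraction yet, the rule looks something like this. split and minCost are hypothetical stand-ins for the problem-specific parts:

-- depth-first branch and bound: keep the cheapest finished solution as the
-- running max_cost bound and prune anything whose min_cost can't beat it
bnb :: (p -> Either Int [p])  -- Left cost = solved, Right = smaller subproblems
    -> (p -> Int)             -- min_cost heuristic (a lower bound)
    -> Int                    -- best finished cost found so far (our max_cost)
    -> p -> Int
bnb split minCost = go
  where
    go best p
      | minCost p >= best = best            -- prune: cannot beat the incumbent
      | otherwise = case split p of
          Left cost  -> min best cost       -- solved subproblem, maybe a new best
          Right subs -> foldl go best subs  -- depth-first over the children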
In Haskell
To abstract over all this in Haskell we will use a monad transformer. We can express alternatives with the Alternative typeclass, a <|> b. Sequential steps naturally fit as a monad. The monad can keep track of the slack we have left - whenever we emit a cost we reduce the slack, and once we go below 0 we can prune the current computation:
-- GeneralizedNewtypeDeriving lifts all of these instances from StateT s []
newtype Bound s a = Bound { unBound :: StateT s [] a }
  deriving (Monad, Applicative, Alternative, Functor, MonadState s, MonadFail)

class Monad m => MonadBound s m where
  tellCost :: s -> m ()

class Slack s where
  dropSlack :: s -> s -> Maybe s
  addSlack  :: s -> s -> s

instance Slack s => MonadBound s (Bound s) where
  tellCost cost = do
    slack <- Bound get
    case dropSlack slack cost of
      Nothing     -> empty
      Just slack' -> Bound (put slack')
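To get a feel for it, a small helper (runBound is my name for it, not part of the abstraction) can start with some slack and collect every branch that survives pruning, together with its leftover slack:

runBound :: s -> Bound s a -> [(a, s)]
runBound slack = flip runStateT slack . unBound

-- with the WordleSlack instance defined further below:
--   runBound (WSlack 3) (tellCost (WSlack 2))  gives  [((), WSlack 1)]
--   runBound (WSlack 1) (tellCost (WSlack 2))  gives  []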
To keep things interesting, I will add another requirement: the abstraction should support the following optimization:
If we can split our subgoal into n independent steps child_1...child_n, then we can give a better (i.e. lower) slack value:
min_costs = children.map(min_cost)
child_slack[i] = slack - (sum(min_costs) - min_costs[i])
In words, to compute the slack for subgoal i we can safely subtract the min_cost of all other independent subgoals from the current slack in advance. This often gives us a much better estimate! Thankfully Haskell has a typeclass for independent steps (Applicative) and a language extension to rewrite Monadic code to Applicative steps (-XApplicativeDo). The extension is mostly used for implicit parallelism, but collecting as much min_cost information as possible works perfectly fine.
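For example, with a slack of 10 and three independent children whose min_costs are 2, 3 and 4, the minimum costs sum to 9, so the children get slacks of 10 - (9 - 2) = 3, 10 - (9 - 3) = 4 and 10 - (9 - 4) = 5 instead of the full 10 each.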
However, wiring up this cost information and the monadic flow sounds complicated. Instead, we will build a stack of monad transformers, each with a narrow purpose. For real code we probably would want to inline everything as a final step because deeply nested monad transformers do not optimize well.
Memoization (planning stages)
As a final complication I want to support memoization. This makes the slack computation harder because cost might be context sensitive. Let's use Wordle as an example. For the following tree (six guesses at depths 1, 2, 3, 2, 3, 3) our total cost is 1+2+3+2+3+3=14.

If we use this cached tree in another position, and this position is at depth 3, then we must update the cost to 4+5+6+5+6+6=32! We can work around this by splitting context, cost, and slack into three types:
newtype WordleSlack = WSlack Int

instance Slack WordleSlack where
  addSlack = coerce (+)
  dropSlack (WSlack l) (WSlack r)
    | leftover >= 0 = Just (WSlack leftover)
    | otherwise     = Nothing
    where leftover = l - r
class (Slack s, Semiring o) => Cost c o s where
  inContext :: c -> o -> s

-- | The Monoid instance merges sequential costs,
-- Semiring merges alternative costs, e.g. Maybe (Sum Int):
-- zero is failure, plus takes the smaller non-failed cost
class Monoid o => Semiring o where
  zero :: o
  plus :: o -> o -> o
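-- e.g. the Maybe (Sum Int) instance hinted at above could be written roughly as:
-- Nothing is a pruned/failed branch, plus keeps the cheaper surviving cost
instance (Num a, Ord a) => Semiring (Maybe (Sum a)) where
  zero = Nothing
  plus (Just l) (Just r) = Just (min l r)
  plus l        r        = l <|> r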
data WordleCost = WCost { totalCost :: Int, nodeCount :: Int }
  deriving (Semigroup, Monoid) via GenericAs WordleCost (Sum Int, Sum Int)
  deriving (Eq, Ord)

newtype WordleContext = Depth Int

instance Cost WordleContext WordleCost WordleSlack where
  -- adjust for the cost of shifting `nodeCount` nodes down by `depth`
  inContext (Depth depth) (WCost {totalCost, nodeCount}) =
    WSlack (totalCost + nodeCount * depth)
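Plugging in the numbers from the tree example above checks out: inContext (Depth 3) (WCost 14 6) gives WSlack (14 + 6 * 3), i.e. WSlack 32.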
Now we can build our next monad transformer. Annoyingly, slack must be global because we want to adjust it for future passes whenever we find a new and better solution. Context could live in a ReaderT monad, and Cost in a WriterT, but we will put them in a single StateT to keep things simple. Stream is some version of ListT-done-right.
newtype BoundM c o s m a = BoundM { unBoundM :: StateT (c,o) (Stream (StateT s m)) a }
We also need a monad to track the minimum of every Applicative branch we can see, and we do this using a fake monad:
newtype LowerBound o a = LowerBound o
  deriving (Functor)

instance (Monoid o) => Monad (LowerBound o) where
  LowerBound l >>= _ = LowerBound l

instance (Monoid o) => Applicative (LowerBound o) where
  pure _ = LowerBound mempty
  LowerBound l <*> LowerBound r = LowerBound $ l <> r

instance (Semiring o) => Alternative (LowerBound o) where
  empty = LowerBound zero
  LowerBound l <|> LowerBound r = LowerBound $ l `plus` r
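This is where -XApplicativeDo matters: the Applicative instance adds up the bounds of all independent steps, while >>= never looks past its left-hand side.

-- with o ~ Sum Int:
--   (,) <$> LowerBound (Sum 2) <*> LowerBound (Sum 3)   gives  LowerBound (Sum 5)
--   LowerBound (Sum 2) >>= \_ -> LowerBound (Sum 3)     gives  LowerBound (Sum 2)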
We can then combine the monads by running them Both:
data Both m n a = MB { bFirst :: m a, bSecond :: n a }
  deriving Functor

instance (Monad m, Monad n) => Monad (Both m n) where
  return a = MB (return a) (return a)
  MB m n >>= f = MB (m >>= bFirst . f) (n >>= bSecond . f)

instance (Applicative m, Applicative n) => Applicative (Both m n) where
  pure a = MB (pure a) (pure a)
  MB m n <*> MB m' n' = MB (m <*> m') (n <*> n')

instance (Alternative m, Alternative n) => Alternative (Both m n) where
  empty = MB empty empty
  MB m n <|> MB m' n' = MB (m <|> m') (n <|> n')

newtype BnB c o s m a = BnB { unBnB :: Both (BoundM c o s m) (LowerBound o) a }
  deriving (Functor, Alternative)
Before we run a Monadic bind, we pre-pay the minimum cost for the left hand side.
instance (Monad m, Cost c o s) => Monad (BnB c o s m) where
  return = pure
  l >>= r = let (cost, l') = collectCost l
            in reduceSlack cost *> BnB (unBnB l' >>= unBnB . r)
When we reach a withMinCost annotation, which gives a heuristic cost for the containing block, we emit this minimum cost so it will be paid in advance. But before entering the block we refund this cost so it can be paid for real this time. During execution we might find that the cost is higher than expected, which either prunes the branch or at least further reduces the slack for following steps.
withMinCost :: Cost c o s => o -> BnB c o s m a -> BnB c o s m a
withMinCost o m = liftRight (LowerBound o) *> (liftLeft (increaseSlack o) >> m)
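The instances above lean on a few plumbing helpers I haven't defined. Roughly, collectCost, liftLeft, and liftRight shuffle data between the two halves of Both (assuming BoundM gets the obvious instances lifted from its underlying stack); reduceSlack and increaseSlack additionally have to reach through the Stream into the slack state, so I'll spare you those:

-- the LowerBound half only ever carries the statically known minimum cost
collectCost :: BnB c o s m a -> (o, BnB c o s m a)
collectCost b@(BnB (MB _ (LowerBound o))) = (o, b)

-- run a real computation, contributing nothing to the collected lower bound
liftLeft :: Monoid o => BoundM c o s m a -> BnB c o s m a
liftLeft m = BnB (MB m (LowerBound mempty))

-- record a lower bound without doing any real work
liftRight :: Monad m => LowerBound o a -> BnB c o s m ()
liftRight (LowerBound o) = BnB (MB (pure ()) (LowerBound o))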
A small example of withMinCost might be useful here:
test fooBar barFoo = do
  withMinCost 5 $ do
    when fooBar (tellCost 1)
    tellCost 5
  withMinCost 3 $ do
    when barFoo (tellCost 1)
    tellCost 3
Now if we run test True False, after -XApplicativeDo rewrites the definition into *>, we execute:
reduceSlack 8 -- `>>=` (or the run function) pre-pays (-8)
increaseSlack 5 -- `withMinCost` refunds (-3)
when True (tellCost 1) -- (-4)
tellCost 5 -- (-9)
increaseSlack 3 -- `withMinCost` refunds (-6)
when False (tellCost 1) -- (-6)
tellCost 3 -- (-9)
If at any step our incoming slack is insufficient, we abort without having to run the rest. Because the code is trivial, withMinCost doesn't carry more information than the bodies, but for branching code or recursive functions we need to supply a heuristic.
We can then write a rather ugly loop which keeps track of the best solution found so far:
pickBest :: (Monad m, Cost c o s, Semiring o, Ord o)
         => BnB c o s m a -> c -> s -> m (Maybe (a, o))
pickBest (BnB (MB m0 (LowerBound bound0))) ctx0 slack0 =
    flip evalStateT slack0 $
      go slack0 Nothing $
        flip runStateT (ctx0, mempty) $
          unBoundM $ reduceSlack bound0 *> m0
  where
    go oldSlack oldBest m = do
      put oldSlack
      step <- runStream m
      case step of
        Done -> pure oldBest
        Yield (a, (ctx, newCost)) n -> case oldBest of
          Just (_, oldCost) | newCost >= oldCost -> go oldSlack oldBest n
          _ -> go (inContext ctx newCost) (Just (a, newCost)) n
There is a slight complication: we need an initial value for the slack. Either we do a first heuristic pass to find some reasonable guess, or we instantiate with Maybe WordleSlack and skip pruning when the cost is Nothing. If we want to add additional pruning, like not going past depth 6, we can similarly adjust the cost/slack/context types.
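For the Maybe route, one possible Slack instance treats Nothing as "no bound yet", so nothing gets pruned until pickBest installs a real bound:

instance Slack s => Slack (Maybe s) where
  addSlack (Just l) (Just r) = Just (addSlack l r)
  addSlack _        _        = Nothing                  -- anything plus unlimited stays unlimited
  dropSlack Nothing  _        = Just Nothing            -- unlimited slack absorbs any cost
  dropSlack (Just l) (Just r) = Just <$> dropSlack l r  -- finite slack behaves as before
  dropSlack (Just l) Nothing  = Just (Just l)           -- a Nothing cost is free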
Memoization (for real this time)
We can add caching with yet another state monad; we only need to produce a cache key for the arguments to the cached function. For Wordle we can do this as 5 Word32 arguments that encode a bitset: if letter a can still occur in position 1, then the first bit in the first Word32 is set. This becomes rather janky if we guess a word in which a letter occurs multiple times, because we cannot store frequency information like "letter a must appear twice in unknown positions". It works well enough to find an optimal solution and is reasonably memory efficient, though.
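The per-position masks could be built along these lines (positionMask is a made-up name; bit 0 stands for a, bit 1 for b, and so on):

import Data.Bits (setBit)
import Data.Word (Word32)

-- bitset of letters still allowed at one position; five of these form the cache key
positionMask :: [Char] -> Word32
positionMask = foldl (\mask c -> setBit mask (fromEnum c - fromEnum 'a')) 0

-- positionMask ['a' .. 'z'] sets all 26 bits; positionMask "a" == 1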
But the context sensitivity strikes again. If we first compute a solution at depth 4, we prune whenever we go above depth 6, so we only consider solutions of height 2 or less. If we later encounter the same arguments at depth 3 then we must reconsider previously pruned solutions - maybe some solution is mostly flat but has a single length-three guess chain. On the other hand, if we found an amazingly great solution that has only depth 2 then we should use the cached result.
A simple solution is to cache by key and context, and to allow a single call to emit solutions for multiple contexts. This means a result with depth 2 is inserted in the cache at depths [1..4], and independently merged at each level.
newtype Caching k c o a = Caching { unCaching :: State (M.Map k (M.Map c o)) a }
  deriving (Functor, Applicative, Monad)
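A lookup-or-compute wrapper over it could look roughly like this (cached is a made-up name, and it only covers the plain per-key-and-context case; assumes qualified Data.Map as M and (>=>) from Control.Monad):

cached :: (Ord k, Ord c) => k -> c -> Caching k c o o -> Caching k c o o
cached key ctx compute = do
  hit <- Caching (gets (M.lookup key >=> M.lookup ctx))
  case hit of
    Just o  -> pure o          -- already solved in this context
    Nothing -> do
      o <- compute
      Caching (modify (M.insertWith M.union key (M.singleton ctx o)))
      pure o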
Some subproblems (guessing pizza first) take longer to compute, but because they usually get pruned before finishing, we don't actually cache their cost. We can exploit the polymorphism by inserting Either MinCost RealCost in the cache, storing some lower bound as Left if we prune early. This makes it more likely that we can use the cache when computing the min_cost heuristic, letting us prune earlier in bad search areas.
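One way the merge of old and new cache entries could go, assuming costs are Ord and Right is a finished cost while Left is only a lower bound:

mergeEntry :: Ord o => Either o o -> Either o o -> Either o o
mergeEntry (Right a) (Right b) = Right (min a b)  -- two finished costs: keep the cheaper
mergeEntry (Left a)  (Left b)  = Left  (max a b)  -- two lower bounds: keep the tighter one
mergeEntry (Right a) _         = Right a          -- a finished cost beats a mere bound
mergeEntry _         (Right b) = Right b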
Rambly Ending
And that’s everything I wanted to cover.
I do not think I would use this approach. -XApplicativeDo is immensely hard to reason about, and the wrong batching could cause an orders-of-magnitude slowdown in a search. Using newtypes in the style of the async library, on the other hand, seems reasonable.
The caching strategy, creating multiple entries in a map, seems inefficient. It'd be much worse for contexts with partial orderings; maybe storing the cache entries in an LRU cache or compressing them with a chain decomposition could work.
We could add a fancier search monad than Stream to mix in best-first search or cyclic best-first search, or add fancier pruning strategies.
For a deep dive into the world of branch-and-bound algorithms, this paper seems like a good overview.
If there are existing approaches to encoding automatic pruning in Haskell (or approaches for hacky context gathering) out there, please tell me - I'd love to hear about them!