Friday, November 7, 2008

Beautiful folding


> {-# LANGUAGE ExistentialQuantification #-}

> import Data.List (foldl')

If you're not a Haskeller, and were thus hoping to learn how to fold a shirt beautifully, I'm afraid you're out of luck. I don't know either.

Much has been said about writing a Haskell function to calculate the mean of a list of numbers. For example, see Don Stewart's "Write Haskell as fast as C". Basically, one wants to write "nice, declarative" code like this:

> naiveMean :: Fractional a => [a] -> a
> naiveMean xs = sum xs / fromIntegral (length xs)

but if xs is large, sum will bring the whole list into memory, and the garbage collector won't be able to collect it, since we still need it to calculate the length.

The solution is to calculate both the sum and the length in one pass, and it's usually written something like this:

> uglyMean :: Fractional a => [a] -> a
> uglyMean xs = divide $ foldl' f (P 0 0) xs
>   where
>     f :: Num a => Pair a Int -> a -> Pair a Int
>     f (P s l) x = P (s + x) (l + 1)
>     divide (P x y) = x / fromIntegral y

where P is a strict pair constructor. This works, but where is the elegance, abstraction and modularity that Haskell is supposed to be famous for? Don's solution is even uglier (sorry, Don): not only does he write the reductor (our f) explicitly, but also the fold itself.

What I hope to do here is to abstract this pattern away, by making "combinable folds". I only do foldl', although foldl1' could be handy.

To make folds combinable, we need to turn folds into data: a fold is a function (the reductor) with an initial value. To make folds more readily combinable, we add a post-processing function (in uglyMean, this is divide). Now that we have the post-processor, we don't need to look at the accumulator directly, so we make its type existential. The type Fold b c is for folds over lists of type [b], with results of type c.

> data Fold b c = forall a. F (a -> b -> a) a (a -> c)
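
To see what hiding the accumulator buys us, here is a minimal sketch. The runner runFold and the two example folds are hypothetical names made up for this illustration. The two folds use different accumulator types (an Int and a Bool), yet both have type Fold Int Bool and can sit in the same list.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')

-- Same shape as the Fold type above: reductor, initial value, post-processor.
data Fold b c = forall a. F (a -> b -> a) a (a -> c)

-- Hypothetical runner, defined here only so the sketch stands alone.
runFold :: Fold b c -> [b] -> c
runFold (F f x c) = c . foldl' f x

-- Accumulator is an Int (a running count); the result is a Bool.
evenCountF :: Fold b Bool
evenCountF = F (\n _ -> n + 1) (0 :: Int) even

-- Accumulator is a Bool; the post-processor is just id.
anyNegativeF :: (Ord b, Num b) => Fold b Bool
anyNegativeF = F (\seen x -> seen || x < 0) False id

-- Both folds fit in one list because the accumulator type is hidden.
checks :: [Fold Int Bool]
checks = [evenCountF, anyNegativeF]

-- Three elements (odd count), one of them negative: [False, True]
results :: [Bool]
results = map (\fold -> runFold fold [3, -1, 4]) checks
```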

We'll need a strict pair type, and I don't want to give my blog a dependency on the strict package, so I introduce my own:

> data Pair a b = P !a !b

Now that folds are data, we can start manipulating them. For example, we can combine two folds to get a pair of results (we make the result an ordinary tuple for convenience, but use strict pairs for the accumulator to get the right strictness). The (***) defined here is like the one in Control.Arrow, but takes a strict pair as input. The reductor (comb f g) is basically (first f) . (second g) for strict pairs.

> both :: Fold b c -> Fold b c' -> Fold b (c, c')
> both (F f x c) (F g y c') = F (comb f g) (P x y) (c *** c')
>   where
>     comb f g (P a a') b = P (f a b) (g a' b)
>     (***) f g (P x y) = (f x, g y)
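
Since both returns an ordinary Fold, it nests. The sketch below restates the types so it stands alone, and writes the helpers of both slightly differently (step and finish rather than comb and ***). The names runFold and sumProductCount are hypothetical, introduced only for this example. It computes a sum, a product and a count in a single pass.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

-- Same behaviour as both above; the helpers are inlined under new names.
both :: Fold b c -> Fold b c' -> Fold b (c, c')
both (F f x c) (F g y c') = F step (P x y) finish
  where
    step (P a a') b = P (f a b) (g a' b)  -- feed the element to both reductors
    finish (P a a') = (c a, c' a')        -- post-process each half

-- Hypothetical runner for this sketch.
runFold :: Fold b c -> [b] -> c
runFold (F f x c) = c . foldl' f x

-- Three results from one traversal, as a nested tuple.
sumProductCount :: [Int] -> (Int, (Int, Int))
sumProductCount =
  runFold (both (F (+) 0 id) (both (F (*) 1 id) (F (\n _ -> n + 1) 0 id)))
```

For the list [1..5] this yields (15, (120, 5)).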

Our next combinator simply adds an extra post-processor.

> after :: Fold b c -> (c -> c') -> Fold b c'
> after (F f x c) d = F f x (d . c)

The next one, bothWith, is a combination of both and after.

> bothWith :: (c -> c' -> d) -> Fold b c -> Fold b c' -> Fold b d
> bothWith combiner f1 f2 = after (both f1 f2) (uncurry combiner)
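
A side note on shapes: after looks like fmap with its arguments flipped, and bothWith looks like liftA2. The instances below are not part of the code in this post; they are a sketch of one way the same combinators could be phrased, assuming nothing beyond base. The names runFold and mean' are hypothetical.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')

data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

-- fmap plays the role of after (with the arguments flipped).
instance Functor (Fold b) where
  fmap d (F f x c) = F f x (d . c)

-- (<*>) runs two folds side by side, like both, then applies the
-- function result of the first to the result of the second.
instance Applicative (Fold b) where
  pure v = F (\() _ -> ()) () (const v)
  F f x c <*> F g y c' = F step (P x y) finish
    where
      step (P a a') b = P (f a b) (g a' b)
      finish (P a a') = c a (c' a')

-- Hypothetical runner for this sketch.
runFold :: Fold b c -> [b] -> c
runFold (F f x c) = c . foldl' f x

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (\n _ -> n + 1) 0 id

-- The same idea as bothWith (/), spelled with applicative operators.
mean' :: [Double] -> Double
mean' = runFold ((/) <$> sumF <*> fmap fromIntegral lengthF)
```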

Now that we have tools to build folds, we want to actually run them, so here is cfoldl', the foldl' for combinable folds:

> cfoldl' :: Fold b c -> [b] -> c
> cfoldl' (F f x c) = c . (foldl' f x)

Now let's see a few basic folds:

> sumF :: Num a => Fold a a
> sumF = F (+) 0 id

> productF :: Num a => Fold a a
> productF = F (*) 1 id

> lengthF :: Fold a Int
> lengthF = F (const . (+1)) 0 id

And, the moment we've all been waiting for, combining basic folds to get the mean of a list:

> meanF :: Fractional a => Fold a a
> meanF = bothWith (/) sumF (after lengthF fromIntegral)

> mean :: Fractional a => [a] -> a
> mean = cfoldl' meanF

Pretty simple, eh? Perhaps not quite as pretty as naiveMean, but best of all, it doesn't eat your memory and kill your swap like naiveMean does.
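
The same combinators stretch to other one-pass statistics. As a sketch, here is the population variance computed as E[x^2] - (E[x])^2. The definitions from the post are repeated so the block stands alone; sumSqF, varianceF and variance are hypothetical additions. This formula is known to lose precision when the mean is large relative to the spread, so take it as an illustration of combining folds rather than a recommended way to compute a variance.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.List (foldl')

-- Definitions repeated from the post.
data Fold b c = forall a. F (a -> b -> a) a (a -> c)
data Pair a b = P !a !b

both :: Fold b c -> Fold b c' -> Fold b (c, c')
both (F f x c) (F g y c') = F (comb f g) (P x y) (c *** c')
  where
    comb f' g' (P a a') b = P (f' a b) (g' a' b)
    (***) f' g' (P u v) = (f' u, g' v)

after :: Fold b c -> (c -> c') -> Fold b c'
after (F f x c) d = F f x (d . c)

bothWith :: (c -> c' -> d) -> Fold b c -> Fold b c' -> Fold b d
bothWith combiner f1 f2 = after (both f1 f2) (uncurry combiner)

cfoldl' :: Fold b c -> [b] -> c
cfoldl' (F f x c) = c . foldl' f x

sumF :: Num a => Fold a a
sumF = F (+) 0 id

lengthF :: Fold a Int
lengthF = F (const . (+1)) 0 id

-- Hypothetical addition: the sum of the squares of the elements.
sumSqF :: Num a => Fold a a
sumSqF = F (\acc x -> acc + x * x) 0 id

-- Hypothetical addition: three basic folds combined into one pass.
varianceF :: Fractional a => Fold a a
varianceF = bothWith finish (both sumF sumSqF) (after lengthF fromIntegral)
  where
    finish (s, sq) n = sq / n - (s / n) ^ (2 :: Int)

variance :: [Double] -> Double
variance = cfoldl' varianceF
```

For [2, 4, 4, 4, 5, 5, 7, 9] the sum is 40, the sum of squares is 232 and the length is 8, so variance gives 232/8 - 5^2 = 4.0.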

> main = do
>   let xs = [1..10000000]
>   print $ mean xs

Compiling with GHC 6.8.2 and -O2, this runs in about 1.2 seconds (on my three-year-old laptop) and uses less than a meg of memory. GHC generates the same code for mean and uglyMean. [Originally uglyMean was slightly faster, but this was because of type defaulting: the result of lengthF defaulted to Integer]

One thing remains. What do Haskellers do when there's a pretty way and a fast way (or at least a way that's more susceptible to optimisation) to do the same thing? We write rewrite rules. So we'd like to convert sum, length, etc. into combinable folds, and then combine them. Something like this:

> {-
> {-# RULES
> "sum/sumF" sum = cfoldl' sumF
> "product/productF" product = cfoldl' productF
> "length/lengthF" length = cfoldl' lengthF
> "multi-cfoldl'" forall c f g xs. c (cfoldl' f xs) (cfoldl' g xs)
>     = cfoldl' (bothWith c f g) xs
> #-}
> -}

So why are these commented out? Unfortunately, GHC doesn't like the all-important "multi-cfoldl'" rule: its left-hand side doesn't have a named function at its head (it has the variable c). GHC doesn't allow rules of this form, presumably for efficiency and simplicity in the compiler.

So unfortunately, we can't go back to writing pretty-but-naive code, but with these combinators at our disposal, we are at least saved from writing *ugly* code.