Introduction
A few days ago I stumbled upon a recent line of research that applies an old and somewhat obscure idea from functional programming languages (delimited continuation-passing) to an old but very much alive idea from numerical computing (automatic differentiation, AD). The result is an elegant algorithm, which remains close to the textbook treatment of reverse-mode AD (“backpropagation”) and could rightly be considered its “natural” implementation.
I will first “read aloud” the reference Scala implementation from the original paper [1], and then do the same for the corresponding parts of a Haskell library I’ve written that implements its ideas, ad-delcont. In the Haskell version all effects are made explicit and tracked at the type level without relying on any compiler plugin.
Along the way I’ll also walk through an elementary example that hopefully clarifies how delimited continuations work in general.
In a previous post I illustrated the fundamentals of automatic differentiation, as implemented in most imperative programming languages. It turns out that, in sufficiently expressive programming languages, a computational tape is not the only available option for inverting control flow. How can we implement reverse-mode automatic differentiation in a purely functional setting? Read on!
Staring at higher-order Scala code
Here are two short code snippets that originally appeared in [1].
```scala
class NumR(val x: Double, var d: Double) {
  def +(that: NumR) = shift { (k: NumR => Unit) =>
    val y = new NumR(this.x + that.x, 0.0)
    k(y)
    this.d += y.d
    that.d += y.d
  }
}
```
This is a Scala implementation of a “plus” function that sums dual numbers. It relies on delimited continuations to achieve non-local control flow and specify what to do when a continuation returns. My Scala is pretty rusty so this has been a head scratcher for a while. I’ll first document how my train of thought went while reading this code, and then try to break it down more formally.
- First we declare a dual number type `NumR`, which has fields `.x` and `.d` for the primal and the adjoint, respectively.
- The implementation of the `+` method is bracketed within a mysterious `shift` higher-order function, which declares a continuation `k`, to be used later.
- A temporary variable `y` is declared, having 0 dual value and primal value set to the function result.
- `k` is then applied to `y`, and the return value of `k` is discarded (?!). This must mean that `y` itself is mutated within the execution of `k`.
- Upon returning from `k`, the dual part of the mutated value of `y` is used to update, by accumulation, the dual parts of the input variables `this` and `that` (see the note after this list).
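Why do both `this.d` and `that.d` accumulate the same quantity `y.d`? It's the sum rule of calculus, stated in terms of adjoints: writing $y = a + b$ for a value contributing to a final result $\ell$, and $\bar{v} = \partial \ell / \partial v$ for the adjoint of a variable $v$, we have $\partial y / \partial a = \partial y / \partial b = 1$, and the chain rule then gives $\bar{a} \leftarrow \bar{a} + \bar{y}$ and $\bar{b} \leftarrow \bar{b} + \bar{y}$, which is exactly what the two lines after `k(y)` perform.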
The other interesting snippet is where all the work happens: the function value is computed, the adjoint accumulation process kicks off (in the "backwards" sweep), and the gradient is returned:
```scala
def grad(f: NumR => NumR @cps[Unit])(x: Double) = {
  val z = new NumR(x, 0.0)
  reset { f(z).d = 1.0 }
  z.d
}
```
- `grad` is a higher-order function that takes the function to be differentiated as a parameter (`f: NumR => NumR`, overloaded to act upon dual numbers `NumR`) and an evaluation point `x`.
- A temporary variable `z` is declared, having 0 adjoint part and primal part corresponding to the point of interest `x`. `z` will accumulate the partial derivative of `f` with respect to `x`.
- Within another mysterious bracket called `reset`, the function `f` is evaluated at `z`, then its adjoint part is set to 1; this seeds the backward sweep, since the derivative of the output with respect to itself is $\partial f / \partial f = 1$.
- Upon exiting from the `reset` block, the adjoint part of `z` is returned: the partial derivative $\partial_x f$ we are interested in.
Delimited continuations in Haskell with shift/reset
The `shift` and `reset` operators are one variant of a notion of "delimited continuations", which originated in the Lisp community in the late 80s: the scope of a continuation is made explicit, thus control can "bounce back" at points specified by the programmer. More specifically, `shift` "captures" a continuation, and `reset` delimits it.
I’m not a programming languages researcher so diving into the original publications didn’t exactly help. Fortunately, a bit of tinkering can save us hours of poring over old papers.
`shift`/`reset` are readily available in the `transformers` Haskell library, within module `Control.Monad.Trans.Cont`.
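For reference, these are the types of the transformer variants as exported by `Control.Monad.Trans.Cont`:

```haskell
resetT :: Monad m => ContT r m r -> ContT r' m r
shiftT :: Monad m => ((a -> m r) -> ContT r m r) -> ContT r m a
```

`shiftT` hands its body the captured continuation, reified as a function `a -> m r` in the base monad, and `resetT` delimits the scope of that capture (note the fresh result type `r'`, which is what lets a `resetT` block be embedded in a larger computation).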
Here’s a minimal snippet to use both shift
and reset
, composed with variable “mutation” in the State
monad. To be precise we will use the continuation monad transformer ContT
, and its corresponding operators shiftT
and resetT
, to compose other “effects” together with continuations:
```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT, evalContT, resetT, shiftT)
import Control.Monad.Trans.State (State, modify, runState)

t1 :: ContT Int (State [Int]) Int
t1 = resetT $ do
  let x = 1 -- input
      cons w = lift $ modify (w :)
  r <- shiftT $ \k -> do
    cons x
    let y = x + 1
    z <- lift $ k y -- 1)
    cons z          -- 4)
    pure y          -- 5)
  cons 0 -- 2)
  pure r -- 3)
```
Running the example above elucidates how the order of variable mutation is affected:

```
λ> flip runState [] $ evalContT t1
(2,[2,0,1])
```
1. As soon as the continuation `k` is invoked (applied to the value `y = 2`), control exits from the `shiftT` block,
2. continues at the next line (in this case appending a `0` to the list used as state variable),
3. and when the "boundary" defined by the lexical scope enclosed by `resetT` is encountered, control returns to the line after the one that called `k`.
4. At this point (within `shiftT`) `z` is bound to whatever was returned by the `resetT` block, which in turn is the value `k` was applied to, i.e. `y = 2`. This is why the next appended value is a 2.
5. Since `k` was resolved by a matching `resetT`, there's nothing else to do and execution terminates.
Pretty mind-bending the first time I saw it.
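To build intuition, here's an even smaller example of mine (not from the paper), using the pure `Cont` flavour of the same operators, where the captured continuation is just a function `Int -> Int`:

```haskell
import Control.Monad.Trans.Cont (evalCont, reset, shift)

-- The continuation captured by shift is everything between it and the
-- enclosing reset, i.e. the function (1 +). Applying it twice to 10
-- gives 1 + (1 + 10) = 12.
t0 :: Int
t0 = evalCont $ reset $ do
  x <- shift $ \k -> pure (k (k 10))
  pure (1 + x)
```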
Introducing ad-delcont
As it turns out, this non-local control flow (i.e. delegating to a continuation, doing something and returning to the starting point with the results) is well suited to implementing the forward-backward computation needed in reverse-mode automatic differentiation.
In order to convince myself of this, I’ve implemented the ideas of [1] in a Haskell library. Overall, I find the result pretty satisfying both from a theoretical and ergonomic standpoint.
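Before reading the code, it helps to know the small vocabulary it uses. Roughly (this is a simplified sketch of mine, not the library's exact source), a dual variable is a mutable reference to a primal-adjoint pair:

```haskell
import Control.Monad.ST (ST)
import Data.STRef (STRef, newSTRef)

-- a dual number: primal value and adjoint (sketch)
data D a da = D a da

-- a mutable dual variable living in the ST monad (sketch)
type DVar s a da = STRef s (D a da)

-- allocate a fresh dual variable (sketch)
var :: a -> da -> ST s (DVar s a da)
var x dx = newSTRef (D x dx)

-- apply a function to the adjoint part only (sketch)
withD :: (da -> da) -> D a da -> D a da
withD f (D x dx) = D x (f dx)
```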
```haskell
op1 :: (Num da, Num db) =>
       (a -> (b, db -> da)) -- ^ returns: (function result, pullback)
    -> ContT x (ST s) (DVar s a da)
    -> ContT x (ST s) (DVar s b db)
op1 f ioa = do
  ra <- ioa
  (D xa _) <- lift $ readSTRef ra
  let (xb, g) = f xa -- 1)
  shiftT $ \ k -> lift $ do
    rb <- var xb 0                                 -- 2)
    ry <- k rb                                     -- 3)
    (D _ yd) <- readSTRef rb                       -- 4)
    modifySTRef' ra (withD (\rda0 -> rda0 + g yd)) -- 5)
    pure ry
```
The above is a pretty faithful port of the Scala version (specialized to a unary function such as $\sqrt{ \cdot }$ to reduce clutter), the major difference being the explicit tracking of the effects (mutation and continuation) at the type level. How does this work?
1. Compute the function result and bind the function inputs to the adjoint-updating function (the "pullback"; a concrete example of a pullback follows this list).
2. Allocate a fresh `STRef` `rb` with the function result and 0 adjoint part.
3. `rb` is passed downstream as an argument to the continuation `k`, with the expectation that the `STRef` will be mutated.
4. Upon returning from `k` (bouncing back from the boundary of `resetT`), the mutated `STRef` is read back in.
5. The adjoint part of the input variable is updated using `rb` (accumulating the adjoints by summing them together, as this variable might be used in more than one program branch), and the result of the continuation is returned.
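To make the notion of pullback concrete: for the unary $\sqrt{ \cdot }$ mentioned above, the pullback scales the downstream adjoint by the derivative $1/(2\sqrt{x})$. A minimal sketch (the name `sqrtF` is mine, not the library's):

```haskell
-- sketch: the (result, pullback) pair that op1 expects for sqrt
sqrtF :: Floating a => a -> (a, a -> a)
sqrtF x = (sqrt x, \yd -> yd / (2 * sqrt x))
```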
In the Haskell case, we pass mutable references to dual variables within the `ST` monad (introduced in [2] and readily available in the Haskell standard library at `Control.Monad.ST`).
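For readers who haven't met `ST` before, here's a standalone example of mine, unrelated to AD: local mutation that is guaranteed (by the universally quantified `s` type parameter) not to leak outside of `runST`:

```haskell
import Control.Monad.ST (runST)
import Data.STRef (newSTRef, modifySTRef', readSTRef)

-- sum a list by mutating a local STRef; the mutation is
-- unobservable from the outside, so sumST is a pure function
sumST :: Num a => [a] -> a
sumST xs = runST $ do
  ref <- newSTRef 0
  mapM_ (\x -> modifySTRef' ref (+ x)) xs
  readSTRef ref
```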
The code computing the gradient is correspondingly succinct, and maps almost exactly (modulo "lifting" and mutation in `ST`) to its Scala counterpart from [1]:
```haskell
rad1 :: (Num a, Num b) =>
        (forall s . AD' s a -> AD' s b) -- ^ function to be differentiated
     -> a                               -- ^ function argument
     -> (b, a)                          -- ^ (result, adjoint)
rad1 f x = runST $ do
  xr <- var x 0
  zr' <- evalContT $
    resetT $ do
      let z = f (AD (pure xr))
      zr <- unAD z
      lift $ modifySTRef' zr (withD (const 1))
      pure zr
  (D z _) <- readSTRef zr'
  (D _ x_bar) <- readSTRef xr
  pure (z, x_bar)
```
`AD` is just a newtype wrapper around `ContT .. (ST s) ..`, in which the returned variables are `STRef`s containing our dual variables; it implements the `Num`, `Fractional` and `Floating` interfaces, and the library provides combinators for implementing new typeclasses as well.
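Putting it all together, a usage sketch: differentiating $f(x) = x^2 + 2x$ at $x = 3$ should yield the pair $(f(3), f'(3)) = (15, 8)$:

```
λ> rad1 (\x -> x * x + 2 * x) 3.0
(15.0,8.0)
```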
Discussion
This was a rather long and technical post. I hope I succeeded in showing how delimited continuations can be put to work to implement a purely functional version of reverse-mode AD. This is a crucial component in the modern optimization and machine learning toolbox, and I find its functional version to be particularly pleasing.
`ad-delcont` is a small but fully functional library. It's lightweight (it fits anywhere you use `transformers`) and easily extensible, as shown in its documentation, e.g. by specializing it to different number-like types. I'm looking forward to seeing what people will use it for!
Feel free to contact me on GitHub or Twitter with feedback, or just to have a chat on these and related things!
References
[1] Wang, Rompf - A Language and Compiler View on Differentiable Programming - ICLR 2018 https://openreview.net/forum?id=SJxJtYkPG
[2] Launchbury, Peyton Jones - Lazy Functional State Threads - PLDI 1994 https://www.microsoft.com/en-us/research/wp-content/uploads/1994/06/lazy-functional-state-threads.pdf