#!/usr/bin/python # -*- coding: utf-8 -*- """XXX unfinished reverse-mode automatic differentiation for Python. I had previously done forward-mode automatic differentiation in Python just using dual numbers, but reverse-mode is much better for nonlinear optimization, because in nonlinear optimization you are normally trying to minimize (or maximize) a scalar objective function, and you want to calculate the gradient of that scalar with respect to a potentially large set of decision variables that you can change to improve it. Forward-mode automatic differentiation associates a gradient vector with each expression node, its gradient with respect to the inputs, while reverse-mode automatic differentiation instead associates a vector of partial derivatives of *outputs* with each node --- in the scalar case, just a number. So if you have a thousand scalar inputs and a scalar output, you would expect reverse-mode automatic differentiation to be about a thousand times faster, about half as fast as computing the value without any differentiation. It's not quite that simple because there *is* a cost: reverse-mode differentiation can't begin until the objective function has been computed, at which point it can begin to propagate the derivatives back through the graph to the inputs. This means that you can't discard all the temporary values computed while your algorithm runs; you have to propagate derivatives back through them later. So they take up memory. For cases where this is prohibitive, as I understand it, advanced automatic-differentiation systems checkpoint the computation state periodically, and when it comes time to propagate the derivative backwards, they iterate backwards through the checkpoints, using each one to reconstruct the dependency graph until the next checkpoint, so that the derivatives can be propagated backwards over that time interval. However, I hadn't figured out how to do reverse-mode automatic differentiation *at all*, and now I think I have a viable attack on the problem in Python. I'll build a computation graph using operator overloading, so that running a normal computation with one or more magic nodes as an input will produce a magic node as an output which is really a handle to the whole computational dependency graph. (Aside from permitting reverse-mode automatic differentiation, this also allows using the graph to do incremental recomputation.) """ from __future__ import division, print_function import weakref def test(): g = Graph() # Define the independent variables: x, y = g.var(0.4, 'x'), g.var(4.2, 'y') # These are not variables, for the purpose of our automatic # differentiation; they are constants: x1, y1, r1 = 0, 0, 5 x2, y2, r2 = 6, 0, 5 # Build the computation graph: d1 = ((x-x1)**2 + (y-y1)**2 - r1)**2 d2 = ((x-x2)**2 + (y-y2)**2 - r2)**2 loss = d1 + d2 print("loss is %f" % float(loss)) # This works, but suggests that incrementally updating the graph # after changing x costs about a millisecond on my netbook: # for i in range(1000): # x += 0.01 # print("loss is %f" % float(loss)) # return i = 0 while loss > 0.01 or i < 1000: i += 1 gradient = loss.gradient() print("with (x, y) = (%.2f, %.2f), loss = %.4f, gradient=%s" % ( float(x), float(y), float(loss), gradient)) # Update the variables, resulting in recomputation: x += gradient[x] * 0.1 y += gradient[y] * 0.1 class Graph(object): def __init__(self): pass def var(self, val, name="(anonymous)"): return Var(val, name) class GraphNode(object): def __init__(self, valid): # Create a unique object to serve as an ID in dictionaries, # because we have to make `==` use the underlying value, in # order to be able to pass as a primitive type in algorithms # that do comparisons for equality. self.idtag = object() self.valid = valid # We use a weak dictionary here so that if there are a lot of # observers that get discarded but never invalidated, they # won't take up a lot of memory. self.observers = weakref.WeakValueDictionary() def observe(self, tag, observer): self.eval() self.observers[tag] = observer return self.val def invalidate(self): if self.valid: self.valid = False del self.val for observer in self.observers.values(): observer.invalidate() self.observers.clear() def observees(self): return [] def __int__(self): self.eval() return int(self.val) def __float__(self): self.eval() return float(self.val) def __str__(self): self.eval() return str(self.val) def __add__(self, other): return Addition(self, as_graph_node(other)) def __sub__(self, other): return Subtraction(self, as_graph_node(other)) def __pow__(self, other): return Power(self, as_graph_node(other)) # XXX all the other arithmetic operation overloads here # and also comparisons def gradient(self): self.eval() return Gradient(self) def as_graph_node(thing): return thing if isinstance(thing, GraphNode) else Constant(thing) class Constant(GraphNode): def __init__(self, val): super(Constant, self).__init__(valid=True) self.val = val def eval(self): pass class Var(GraphNode): def __init__(self, val, name): super(Var, self).__init__(valid=True) self.val, self.name = val, name def eval(self): pass def __iadd__(self, increment): self.set_val(self.val + increment) return self def set_val(self, new_val): self.invalidate() self.val, self.valid = new_val, True class BinaryOp(GraphNode): def __init__(self, a, b): super(BinaryOp, self).__init__(valid=False) self.rands = a, b def observees(self): return self.rands def eval(self): if not self.valid: a_val = self.rands[0].observe(self.idtag, self) b_val = self.rands[1].observe(self.idtag, self) self.val = self.op(a_val, b_val) self.valid = True class Subtraction(BinaryOp): def op(self, a, b): return a - b def chain(self): assert self.valid return [(1, self.rands[0]), (-1, self.rands[1])] class Addition(BinaryOp): def op(self, a, b): return a + b def chain(self): assert self.valid return [(1, self.rands[0]), (1, self.rands[1])] class Power(BinaryOp): def op(self, a, b): return a ** b def chain(self): assert self.valid a, b = self.rands av = a.val bv = b.val # XXX handle the general case! assert bv == 2 assert isinstance(b, Constant) return [(2, a)] class Gradient(object): """Compute and contain a gradient. The easy question is how to get the gradient values out: we store them in a dictionary by the nodes' idtags. The slightly harder part is how to do the graph traversal, because each node in the graph might be receiving derivativity from any number of dependents, and it's desirable to avoid exponential blowups. Now, we actually have that information in the .observers dictionary of the graph node, so we could in fact wait to visit each node until we have visited all of its .observers. However, what if we can't reach the whole graph from the node we're starting from? We really only want to care about observers that we will eventually visit. And actually we don't even need to use .observers in that case: we can merely do a topological sort of the relevant subgraph graph after traversing it to find all the relevant edges. """ def __init__(self, destination): self.gradients = {} # Do initial traversal to construct reverse dependency graph observers = {} candidates = {destination} while candidates: candidate = candidates.pop() for observee in candidate.observees(): if observee not in observers: observers[observee] = set() observers[observee].add(candidate) # fuck, I'm lost, maybe tomorrow pass def __getitem__(self, node): return self.gradients[node.idtag] if __name__ == '__main__': test()