From: "andrew cooke" <andrew@...>
Date: Sun, 31 Dec 2006 20:21:52 -0300 (CLST)
I read this post -
http://wmfarr.blogspot.com/2006/10/automatic-differentiation-in-ocaml.html
- and realised I had written the same code in Python. I hunted around and
found it on my (password protected) diary -
http://www.acooke.org/andrew/diary/2004/mar/4.html - so here it is in public.
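It parses an arithmetic expression, differentiates it symbolically, and then
simplifies simple numerical terms (for example, dropping additions of 0 and
multiplications by 1).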
Andrew
output:

the original expression is a+3*b*a
it was parsed to (a+(3*(b*a)))
the differential wrt a is (1+(3*b))

the original expression is a/b
it was parsed to (a/b)
the differential wrt b is (0-(a*(1/(b*b))))

the original expression is a+b+c+d
it was parsed to (a+(b+(c+d)))
the differential wrt b is 1

the original expression is a+b*c+d
it was parsed to (a+((b*c)+d))
the differential wrt b is c

code:
# calculate first derivatives of an arithmetic expression
# basic code using just +,-,/,*,(), integers and variables (lowercase),
# but functions aren't any harder conceptually

# here's the interesting bit
# walk the ast to calculate the derivative wrt some variable
# you'd add functions in the normal way (the chain rule) - for example
# sin(f) would map to cos(f) * diff(f)
def diffwrt(n, var):
    def ifn(n): return 0
    def vfn(v):
        # compare by value; 'is' only tests object identity
        if v == var: return 1
        else: return 0
    def nfn(op, n1, n2, dn1, dn2):
        if op == '+' or op == '-':
            return (op, dn1, dn2)
        elif op == '*':
            # product rule
            return ('+',
                    ('*', dn1, n2),
                    ('*', n1, dn2))
        elif op == '/':
            # quotient rule, written as dn1/n2 - n1*(dn2/(n2*n2))
            return ('-',
                    ('/', dn1, n2),
                    ('*', n1,
                     ('/', dn2, ('*', n2, n2))))
    return folddown(ifn, vfn, nfn, n)

def tidy(n):
    def ifn(n): return n
    def vfn(v): return v
    def nfn(op, n1, n2, dn1, dn2):
        if isinstance(dn1, int):
            if dn1 == 0:
                if op == '+': return dn2
                elif op == '*': return 0
                elif op == '/': return 0
            elif dn1 == 1 and op == '*': return dn2
        if isinstance(dn2, int):
            if dn2 == 0:
                if op == '+': return dn1
                elif op == '*': return 0
                elif op == '/': raise ZeroDivisionError("division by zero")
            elif dn2 == 1 and op == '*': return dn1
        if isinstance(dn1, int) and isinstance(dn2, int):
            if op == '+': return dn1 + dn2
            elif op == '-': return dn1 - dn2
            elif op == '*': return dn1 * dn2
            # fold only exact quotients so the result stays an integer
            elif op == '/' and dn1 % dn2 == 0: return dn1 // dn2
        return (op, dn1, dn2)
    return folddown(ifn, vfn, nfn, n)

def folddown(ifn, vfn, nfn, n):
    if isinstance(n, int): return ifn(n)
    elif isinstance(n, str): return vfn(n)
    else:
        (op, n1, n2) = n
        (dn1, dn2) = (folddown(ifn, vfn, nfn, n1), folddown(ifn, vfn, nfn, n2))
        return nfn(op, n1, n2, dn1, dn2)

# a "simple" recursive descent parser
# grammar:
#   expr: term ((+|-) term)*
#   term: fact ((*|/) fact)*
#   fact: '(' expr ')' | var | num
# the ast is just (operator, node, node) tuples

# utilities
def dropspace(s):
    if s and s[0] == ' ': return dropspace(s[1:])
    else: return s
def empty(s): return dropspace(s) == ""

# so these are a bunch of 'recognisers' (eg cousineau + mauny)
# (you can think of them as tokenizers - they either return a match plus
# the remaining text or None)
def add(s): return mkonechar('+')(s)
def subtract(s): return mkonechar('-')(s)
def multiply(s): return mkonechar('*')(s)
def divide(s): return mkonechar('/')(s)
def openbracket(s): return mkonechar('(')(s)
def closebracket(s): return mkonechar(')')(s)
def variable(s): return mkmanychar(lambda c : c >= 'a' and c <= 'z')(s)
def number(s): return mkmanychar(lambda c : c >= '0' and c <= '9')(s)

def mkonechar(c):
    def localonechar(s):
        ss = dropspace(s)
        if ss and ss[0] == c: return (c, ss[1:])
        else: return None
    return lambda s: localonechar(s)

def mkmanychar(p):
    def accum(s, id=""):
        if s and p(s[0]): return accum(s[1:], id + s[0])
        elif id != "": return (id, s)
        else: return None
    return lambda s: accum(dropspace(s))

# this handles the "nxt ((p1|p2) nxt)*" structure in the grammar
# (nxt is called recursively, so a chain of operators associates to the
# right: a-b-c comes out as (a-(b-c)))
def mkextend(p1, p2, nxt):
    def localextend(n1, s):
        if p1(s):
            (x, s) = p1(s)
            (n2, s) = nxt(s)
            return ((x, n1, n2), s)
        elif p2(s):
            (x, s) = p2(s)
            (n2, s) = nxt(s)
            return ((x, n1, n2), s)
        else: return (n1, s)
    return lambda n, s: localextend(n, s)

# and this is the parser itself
# (errors are raised as ValueError; raising a bare string is not valid)
def expr(s):
    if term(s):
        (n, s) = term(s)
        return extendexpr(n, s)
    else: raise ValueError("cannot parse " + s)
def extendexpr(n, s): return mkextend(add, subtract, expr)(n, s)

def term(s):
    if fact(s):
        (n, s) = fact(s)
        return extendterm(n, s)
    else: raise ValueError("cannot parse " + s)
def extendterm(n, s): return mkextend(multiply, divide, term)(n, s)

def fact(s):
    if openbracket(s):
        (x, s) = openbracket(s)
        if expr(s):
            (n, s) = expr(s)
            if closebracket(s):
                (x, s) = closebracket(s)
                return (n, s)
            else: raise ValueError("missing ): " + s)
        else: raise ValueError("cannot parse: " + s)
    elif variable(s):
        (x, s) = variable(s)
        return (x, s)
    elif number(s):
        (x, s) = number(s)
        return (int(x), s)
    else: raise ValueError("cannot parse: " + s)

# finally, a pretty printer
def asttostring(n):
    if isinstance(n, int): return str(n)
    elif isinstance(n, str): return n
    else: return astoptostring(n)
def astoptostring(n):
    # unpack in the body; tuple parameters in the signature are Python 2 only
    (op, n1, n2) = n
    return "(" + asttostring(n1) + op + asttostring(n2) + ")"

# and test
def demo(text, var):
    (ast, x) = expr(text)
    ast2 = tidy(ast)
    print("")
    print("the original expression is " + text)
    print("it was parsed to " + asttostring(ast2))
    diff = tidy(diffwrt(ast2, var))
    print("the differential wrt " + var + " is " + asttostring(diff))
    print("")

demo("a+3*b*a", "a")
demo("a/b", "b")
demo("a+b+c+d", "b")
demo("a+b*c+d", "b")
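
The comment in diffwrt about adding functions can be made concrete. What follows is only a minimal sketch and not part of the original diary code: it assumes a hypothetical 2-tuple node shape, (name, argument), for unary functions, and the names diff and show are invented for this example. It applies the chain rule to sin and cos and leaves the result unsimplified.

```python
# Assumption: unary function nodes are 2-tuples such as ('sin', arg).
# Binary nodes keep the (op, n1, n2) shape used in the listing above.

def diff(n, var):
    """Return the derivative AST of n with respect to var (no tidying)."""
    if isinstance(n, int):
        return 0
    if isinstance(n, str):
        return 1 if n == var else 0
    if len(n) == 2:
        fn, arg = n
        darg = diff(arg, var)
        if fn == 'sin':
            # chain rule: d sin(f) = cos(f) * df
            return ('*', ('cos', arg), darg)
        if fn == 'cos':
            # chain rule: d cos(f) = (0 - sin(f)) * df
            return ('*', ('-', 0, ('sin', arg)), darg)
        raise ValueError("unknown function " + fn)
    op, n1, n2 = n
    dn1, dn2 = diff(n1, var), diff(n2, var)
    if op in ('+', '-'):
        return (op, dn1, dn2)
    if op == '*':
        return ('+', ('*', dn1, n2), ('*', n1, dn2))
    if op == '/':
        return ('-', ('/', dn1, n2), ('*', n1, ('/', dn2, ('*', n2, n2))))
    raise ValueError("unknown operator " + op)

def show(n):
    """Print an AST, writing function nodes as name(argument)."""
    if isinstance(n, (int, str)):
        return str(n)
    if len(n) == 2:
        return n[0] + "(" + show(n[1]) + ")"
    return "(" + show(n[1]) + n[0] + show(n[2]) + ")"

# derivative of sin(a*b) with respect to a
print(show(diff(('sin', ('*', 'a', 'b')), 'a')))
# (cos((a*b))*((1*b)+(a*0)))
```

A tidy pass like the one above would reduce the second factor to b; teaching the parser to read function names is a separate step that this sketch leaves out.
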
Differentiating Functions is Important
From: Will M Farr <farr@...>
Date: Thu, 4 Jan 2007 10:42:26 -0500
Andrew,
I think you have it correct in your comment above, but I'd like to
emphasize that differentiating functions (as I do) is very
important. Using reflection (or lisp-style macros) to obtain a
textual representation of a function (as you do) doesn't work well
when you compose functions together, or use non-mathematical
operators in a function. For example, what's the derivative of
    fun x ->
      if x > 0 then
        x
      else
        0 - x
at x = 3? (You could do this textually if your text processor knew
enough to avoid processing the if statement, but in order to make
this work in general, you would have to process the whole language.)
How about this:
    fun x ->
      let y = 0 - x in
      if y < 0 then
        x
      else
        0 - y
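
As a rough Python illustration of why evaluating the function sidesteps the problem, here is a minimal dual-number sketch. It shows the general technique (forward-mode differentiation by operator overloading) and is not a port of the OCaml code under discussion; Dual, lift, derivative, absolute and second are names made up for this example. Each value carries its derivative along, so the if statement simply runs on the underlying number and only the branch actually taken is differentiated.

```python
class Dual:
    """A value paired with its derivative with respect to the input."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    @staticmethod
    def lift(x):
        # plain numbers are constants, so their derivative is 0
        return x if isinstance(x, Dual) else Dual(x)

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        # handles "0 - x" where the left operand is a plain number
        return Dual.lift(other) - self

    def __gt__(self, other):
        # comparisons look only at the value, not the derivative
        return self.val > Dual.lift(other).val

    def __lt__(self, other):
        return self.val < Dual.lift(other).val


def absolute(x):
    # first example: if x > 0 then x else 0 - x
    if x > 0:
        return x
    else:
        return 0 - x


def second(x):
    # second example: let y = 0 - x in if y < 0 then x else 0 - y
    y = 0 - x
    if y < 0:
        return x
    else:
        return 0 - y


def derivative(f, at):
    # seed the input with derivative 1 and read the derivative of the result
    return f(Dual(at, 1.0)).der


print(derivative(absolute, 3.0))   # 1.0
print(derivative(absolute, -3.0))  # -1.0
print(derivative(second, 3.0))     # 1.0
```
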
It gets even worse if you define a function like
    fun x ->
      let y = other_fun x in
      let z = another_fun y in
      x *. y *. z
To process this textually, you need a database which stores the text
of other_fun and another_fun so they can be re-differentiated w.r.t x
and then y (respectively). Don't forget to apply the chain rule to
the derivative of another_fun (since the argument is y)! It's really
a mess.
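
For comparison, here is a small Python sketch of the same composition done by carrying derivatives through the computation. It is an illustration under assumptions, not the OCaml implementation: the bodies of other_fun and another_fun are invented (x*x + 1 and 3*y), and Dual and C are hypothetical names, with C standing in for lifting a constant. The chain rule never has to be written out, because the product rule in each multiplication propagates the derivatives of y and z automatically.

```python
class Dual:
    """A value paired with its derivative with respect to the input x."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    def __add__(self, other):
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other):
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)


def C(x):
    # lift a numerical constant (its derivative is 0)
    return Dual(x)


def other_fun(x):
    # hypothetical body: x*x + 1
    return x * x + C(1.0)


def another_fun(y):
    # hypothetical body: 3*y
    return C(3.0) * y


def f(x):
    y = other_fun(x)
    z = another_fun(y)
    return x * y * z


# differentiate f at x = 2 by seeding the input with derivative 1
result = f(Dual(2.0, 1.0))
print(result.val, result.der)  # 150.0 315.0
```

With these bodies f(x) = 3x(x*x + 1)^2, so f'(x) = 3(x*x + 1)^2 + 12x*x(x*x + 1), which is 75 + 240 = 315 at x = 2 and matches the printed derivative.
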
By the way, your last paragraph is entirely correct---I just take
some existing code, add "open Deriv" to the top, place (C ...) in
front of all numerical constants, and then I get derivatives for free.
A disadvantage of my method, which yours doesn't share, is that it
would compute the derivative of this function:
    fun x ->
      x +. 3.0 -. x
as
    fun x ->
      1.0 +. 0.0 -. 1.0
I don't do any simplification, because I don't have the text of the
expressions available at all.
Thanks for the interesting post---it's fun to see other people doing
this kind of stuff!
Will