Traversing the AST of an expression in sympy
Question:
I'm using sympy to compute some higher-order derivatives of a complicated function expression. I'd like to traverse the AST of the expression, e.g. visit the nodes depth first. How do I do that?
Answer:
A simple depth-first traversal looks like this:
from sympy import pi, sin
from sympy.abc import a, x, y

def depth_first_traverse(expr):
    # visit all sub-expressions (children) first
    for arg in expr.args:
        depth_first_traverse(arg)
    if len(expr.args) == 0:
        # we reached a leaf of the tree
        pass  # do something with leaf expr
    else:
        pass  # do something with compound expr

depth_first_traverse(sin(a*x*pi+1.5)/y)
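If you only need to iterate over the nodes, sympy's built-in `preorder_traversal` and `postorder_traversal` iterators may be enough. The following is a minimal sketch rather than part of the original answer; the variable names `pre`, `post` and `leaves` are just illustrative.

```python
from sympy import pi, sin, preorder_traversal, postorder_traversal
from sympy.abc import a, x, y

expr = sin(a*x*pi + 1.5)/y

# Pre-order: every node is yielded before its arguments.
pre = list(preorder_traversal(expr))

# Post-order: arguments are yielded before the node that contains them,
# i.e. children first, like the recursive function above.
post = list(postorder_traversal(expr))

# Leaves are the nodes without arguments (symbols and numbers).
leaves = [node for node in post if not node.args]

for node in post:
    kind = "leaf" if not node.args else "compound"
    print(kind, node)
```

Both iterators visit every node of the tree; they differ only in whether a compound expression comes before or after its arguments.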
With additional parameters, more complex goals can be reached, for example printing the tree in depth-first order while showing how everything fits together:
from sympy import srepr, pi, sin
from sympy.abc import a, x, y

def depth_first_traverse(expr, depth=0, marks='', parent_ind=None):
    for ind, arg in enumerate(expr.args):
        # turn the parent's branch marker into plain vertical guide lines
        new_marks = marks.replace('+', '|').replace('-', ' ')
        if parent_ind == 0:
            # blank out the guide line directly above a first child
            new_marks = new_marks[:-4] + ' ' + new_marks[-3:]
        depth_first_traverse(arg, depth+1, new_marks+'+---', parent_ind=ind)
    if len(expr.args) == 0:
        # leaf: a symbol or a number
        print(marks, end="> ")
        print("symbol", srepr(expr))
    else:
        # compound expression: printed after all of its arguments
        print(marks, end="+ ")
        print("function", expr.func, "had", len(expr.args), "arguments")
    print(marks.replace('+', '|').replace('-', ' '))

depth_first_traverse(sin(a*x*pi+1.5)/y)
Output:
    +---> symbol Symbol('y')
    |
    +---> symbol Integer(-1)
    |
+---+ function <class 'sympy.core.power.Pow'> had 2 arguments
|
|       +---> symbol Float('1.5', precision=53)
|       |
|       |   +---> symbol pi
|       |   |
|       |   +---> symbol Symbol('a')
|       |   |
|       |   +---> symbol Symbol('x')
|       |   |
|       +---+ function <class 'sympy.core.mul.Mul'> had 3 arguments
|       |
|   +---+ function <class 'sympy.core.add.Add'> had 2 arguments
|   |
+---+ function sin had 1 arguments
|
+ function <class 'sympy.core.mul.Mul'> had 2 arguments
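The `depth` argument in the function above is passed along but never used. As a rough sketch of what it could be used for, the hypothetical generator `walk_with_depth` below (not part of the original answer) yields each node together with its depth, children before their parent, so the caller decides what to do with them.

```python
from sympy import pi, sin
from sympy.abc import a, x, y

def walk_with_depth(expr, depth=0):
    """Yield (depth, node) pairs, children before their parent."""
    for arg in expr.args:
        yield from walk_with_depth(arg, depth + 1)
    yield depth, expr

expr = sin(a*x*pi + 1.5)/y
nodes = list(walk_with_depth(expr))

# Deepest nesting level reached anywhere in the tree.
max_depth = max(d for d, _ in nodes)

# Simple indented listing: two spaces per level, class name of each node.
for d, node in nodes:
    print("  " * d + type(node).__name__)
```

Separating the traversal from the printing keeps the recursion short and lets the same generator feed other tasks, such as counting node types in a large derivative.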