我有一个场景,我需要在Python中动态修饰函数内的递归调用.关键要求是动态地实现这一点,而不修改当前作用域中的函数.让我解释一下目前的情况以及我目前所做的努力.
假设我有一个函数traverse_tree
,它递归遍历一个二元树并产生它的值.现在,我想修饰这个函数中的递归调用,以包含额外的信息,比如递归深度.当我直接将decorator 与函数一起使用时,它会像预期的那样工作.然而,我希望动态地实现相同的功能,而不修改当前作用域中的函数.
import functools
class Node:
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
def generate_tree():
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.right.left = Node(6)
root.right.right = Node(7)
return root
def with_recursion_depth(func):
"""Yield recursion depth alongside original values of an iterator."""
class Depth(int): pass
depth = Depth(-1)
def depth_in_value(value, depth) -> bool:
return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal depth
depth = Depth(depth + 1)
for value in func(*args, **kwargs):
if depth_in_value(value, depth):
yield value
else:
yield value, depth
depth = Depth(depth - 1)
return wrapper
# 1. using @-syntax
@with_recursion_depth
def traverse_tree(node):
"""Recursively yield values of the binary tree."""
yield node.value
if node.left:
yield from traverse_tree(node.left)
if node.right:
yield from traverse_tree(node.right)
root = generate_tree()
for item in traverse_tree(root):
print(item)
# Output:
# (1, 0)
# (2, 1)
# (4, 2)
# (5, 2)
# (3, 1)
# (6, 2)
# (7, 2)
# 2. Dynamically:
def traverse_tree(node):
"""Recursively yield values of the binary tree."""
yield node.value
if node.left:
yield from traverse_tree(node.left)
if node.right:
yield from traverse_tree(node.right)
root = generate_tree()
for item in with_recursion_depth(traverse_tree)(root):
print(item)
# Output:
# (1, 0)
# (2, 0)
# (4, 0)
# (5, 0)
# (3, 0)
# (6, 0)
# (7, 0)
问题似乎在于函数中的递归调用是如何修饰的.当动态地使用修饰器时,它只修饰外部函数调用,而不是函数内的递归调用.我可以通过重新赋值(traverse_tree = with_recursion_depth(traverse_tree)
)来实现这一点,但是现在函数在当前作用域中已经被修改了.我希望动态地实现这一点,这样我就可以使用非修饰函数,或者可选地包装它以获得递归深度的信息.
我更喜欢保持简单,如果有替代解决方案的话,我希望避免使用字节码操作等技术."不过,如果这是必要的道路,我愿意go 探索,我在那个方向上做了try ,但还没有成功."
import ast
def modify_recursive_calls(func, decorator):
def decorate_recursive_calls(node):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == func.__name__:
func_name = ast.copy_location(ast.Name(id=node.func.id, ctx=ast.Load()), node.func)
decorated_func = ast.Call(
func=ast.Name(id=decorator.__name__, ctx=ast.Load()),
args=[func_name],
keywords=[],
)
node.func = decorated_func
for field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
decorate_recursive_calls(item)
elif isinstance(value, ast.AST):
decorate_recursive_calls(value)
tree = ast.parse(inspect.getsource(func))
decorate_recursive_calls(tree)
ast.fix_missing_locations(tree)
modified_code = compile(tree, filename="<ast>", mode="exec")
modified_function = types.FunctionType(modified_code.co_consts[1], func.__globals__)
return modified_function