使用抽象语法树修改Python 3代码

我现在正在玩抽象语法树,使用ast和astor模块.该文档教我如何检索和漂亮打印各种功能的源代码,网上的各种示例显示如何通过将一行的内容替换为另一行或将所有出现的更改为*来修改代码的各个部分.

但是,我想在各个地方插入其他代码,特别是当函数调用另一个函数时.例如,以下假设功能:

def some_function(param):
    if param == 0:
       return case_0(param)
    elif param < 0:
       return negative_case(param)
    return all_other_cases(param)

会产生(一旦我们使用了astor.to_source(modified_ast)):

def some_function(param):
    if param == 0:
       print ("Hey, we're calling case_0")
       return case_0(param)
    elif param < 0:
       print ("Hey, we're calling negative_case")
       return negative_case(param)
    print ("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)

抽象语法树可以实现吗? (注意:我知道在运行代码时,调用的装饰函数会产生相同的结果,但这不是我所追求的;我需要实际输出修改后的代码,并插入比print语句更复杂的内容).

最佳答案 从你的问题中不清楚你是在询问如何将节点插入到低级别的AST树中,或者更具体地说是如何使用更高级别的工具进行节点插入来遍历AST树(例如,ast的子类) .NodeVisitor或astor.TreeWalk).

以低级别插入节点非常容易.您只需在树中的适当列表中使用list.insert.例如,这里有一些代码可以添加你想要的三个打印调用中的最后一个(另外两个几乎一样容易,他们只需要更多的索引).大多数代码都是为打印调用构建新的AST节点.实际插入很短:

source = """
def some_function(param):
    if param == 0:
       return case_0(param)
    elif param < 0:
       return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.parse(source) # parse an ast tree from the source code

# build a new tree of AST nodes to insert into the main tree
message = ast.Str("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)

tree.body[0].body.insert(1, print_statement) # doing the actual insert here!

# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))

输出将是:

def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)

(请注意,ast.Call的参数在Python 3.4和3.5之间发生了变化.如果您使用的是旧版本的Python,则可能需要添加两个额外的None参数:ast.Call(print_func,[message],[],没有,没有))

如果你正在使用更高级别的方法,事情会有点棘手,因为代码需要找出插入新节点的位置,而不是使用你自己的输入知识来硬编码.

这是TreeWalk子类的快速而又脏的实现,它在任何具有Call节点的语句之前添加一个print调用作为语句.请注意,Call节点包括对类的调用(用于创建实例),而不仅仅是函数调用.此代码仅处理嵌套调用的最外层,因此如果代码具有foo(bar()),则插入的打印将仅提及foo:

class PrintBeforeCall(astor.TreeWalk):
    def pre_body_name(self):
        body = self.cur_node
        print_func = ast.Name("print", ast.Load())
        for i, child in enumerate(body[:]):
            self.__name = None
            self.walk(child)
            if self.__name is not None:
                message = ast.Str("Calling {}".format(self.__name))
                print_statement = ast.Expr(ast.Call(print_func, [message], []))
                body.insert(i, print_statement)
        self.__name = None
        return True

    def pre_Call(self):
        self.__name = self.cur_node.func.id
        return True

你这样称呼它:

source = """
def some_function(param):
    if param == 0:
       return case_0(param)
    elif param < 0:
       return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.parse(source)

walker = PrintBeforeCall()   # create an instance of the TreeWalk subclass
walker.walk(tree)   # modify the tree in place

print(astor.to_source(tree)

这次的输出是:

def some_function(param):
    if param == 0:
        print('Calling case_0')
        return case_0(param)
    elif param < 0:
        print('Calling negative_case')
        return negative_case(param)
    print('Calling all_other_cases')
    return all_other_cases(param)

这不是你想要的确切信息,但它很接近. walker无法详细描述正在处理的案例,因为它只查看被调用的名称函数,而不是查看它的条件.如果你有一套非常明确的东西需要寻找,你可以改变它来看看ast.If节点,但我怀疑这将更具挑战性.

点赞