diff --git a/luaparser/ast.py b/luaparser/ast.py index 03cb44c..79ea0be 100644 --- a/luaparser/ast.py +++ b/luaparser/ast.py @@ -94,6 +94,42 @@ def default(self, o): def to_pretty_json(root: Node) -> str: return json.dumps(root, cls=JSONEncoder, indent=4) +class ASTReplaceVisitor: + def visit(self, root): + if root is None: + return + node_stack = [(root, None)] + + while len(node_stack) > 0: + node, parent = node_stack.pop() + + if isinstance(node, astnodes.Node): + name = "visit_" + node.__class__.__name__ + tree_visitor = getattr(self, name, None) + if tree_visitor: + new_node = tree_visitor(node) + if new_node != None: + for key in parent.__dict__.keys(): + # if key == 'values' or key == 'args': + obj = getattr(parent, key, None) + if isinstance(obj, list): + for i, item in enumerate(obj): + # print(item) + if item is node: + obj[i] = new_node + + elif obj is node: + setattr(parent, key, new_node) + + children = [ + attr for attr in node.__dict__.keys() if not attr.startswith("_") + ] + for child in reversed(children): + node_stack.append((node.__dict__[child], node)) + + elif isinstance(node, list): + for n in reversed(node): + node_stack.append((n, parent)) class ASTVisitor: def visit(self, root):