
A Practical Introduction to the Python Abstract Syntax Tree

What is an Abstract Syntax Tree?

An Abstract Syntax Tree (AST) is a representation of the source code structure as data. It is called 'abstract' because it removes surface details like comments, formatting or extra parentheses. It can be thought of as a nested dictionary or JSON-like document that describes the code.

For example, the expression (1 + 2) can be represented in different ways:


[{"type": "id", "value": "+"},
  [{"type": "literal", "value": 1},
   {"type": "literal", "value": 2}]
]


{
  "type": "Add",
  "left": {"type": "Number", "value": 1},
  "right": {"type": "Number", "value": 2}
}

[Figure: a tree whose root node is + and whose two children are 1 and 2]

Even a tiny expression becomes a tree
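The "abstract" part can be checked with Python's built-in ast module (covered in the next section). This is only a small sketch: the list of variants is made up for illustration, and it relies on ast.dump leaving out line and column information by default.

```python
import ast

# Four spellings of the same expression: extra parentheses, missing spaces
# and a trailing comment are surface details only.
variants = ["1 + 2", "(1 + 2)", "1+2  # a comment", "((1) + (2))"]

# ast.dump omits line/column attributes by default, so it shows structure only.
unique_dumps = {ast.dump(ast.parse(source, mode="eval")) for source in variants}

print(len(unique_dumps))  # 1: every variant parses to the same tree
for dump in unique_dumps:
    print(dump)
```

All four strings produce the same dump, which is what "removing surface details" means in practice.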

Representing code as data makes all sorts of software possible: Abstract Syntax Trees are used in interpreters, compilers and transpilers (such as Babel in JavaScript), linters and formatters (flake8/black/ruff in Python, eslint in JavaScript) and many more tools.

These programs can be surprisingly approachable, at least conceptually.

def evaluate(node):
    """An ast interpreter for basic arithmetic expressions."""
    if node["type"] == "Add":
        return evaluate(node["left"]) + evaluate(node["right"])
    if node["type"] == "Subtract":
        return evaluate(node["left"]) - evaluate(node["right"])
    if node["type"] == "Number":
        return node["value"]
    raise ValueError(f"unsupported node type {node['type']}")
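To see such an interpreter run, here is a self-contained variant together with a sample tree. The lookup table (BINARY_OPERATIONS) is a stylistic choice made for this sketch, not part of the original function, and the tree is a hand-written example in the dictionary representation shown earlier.

```python
import operator

# Maps node types to the Python function that implements them.
BINARY_OPERATIONS = {"Add": operator.add, "Subtract": operator.sub}


def evaluate(node):
    """Interpreter for basic arithmetic, written with a lookup table."""
    if node["type"] == "Number":
        return node["value"]
    if node["type"] in BINARY_OPERATIONS:
        operation = BINARY_OPERATIONS[node["type"]]
        return operation(evaluate(node["left"]), evaluate(node["right"]))
    raise ValueError(f"unsupported node type {node['type']}")


# (1 + 2) - 4 in the dictionary representation
tree = {
    "type": "Subtract",
    "left": {
        "type": "Add",
        "left": {"type": "Number", "value": 1},
        "right": {"type": "Number", "value": 2},
    },
    "right": {"type": "Number", "value": 4},
}
print(evaluate(tree))  # -1
```

Supporting a new operator then only takes a new entry in the table.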

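A linter, meanwhile, might do something like:

def lint(node):
    """Report function definitions that take more than 5 parameters."""
    if node["type"] == "FunctionDefinition" and len(node["args"]) > 5:
        # report_error stands for whatever reporting hook the linter provides
        report_error(node["location"], "too many parameters in function")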
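A runnable version of this idea needs a concrete tree shape and a way to collect errors. The sketch below assumes both: the "body", "name" and "location" keys (with the location as a (line, column) tuple) are hypothetical, and errors are appended to a plain list instead of going through a reporting function.

```python
def lint(node, errors):
    """Collect an error for every function definition with more than 5 parameters."""
    if node["type"] == "FunctionDefinition" and len(node["args"]) > 5:
        errors.append((node["location"], "too many parameters in function"))
    # Recurse so that nested definitions are checked too.
    for child in node.get("body", []):
        lint(child, errors)
    return errors


# Hand-written tree: a module containing two function definitions.
module = {
    "type": "Module",
    "body": [
        {"type": "FunctionDefinition", "name": "ok", "args": ["a", "b"],
         "location": (1, 0), "body": []},
        {"type": "FunctionDefinition", "name": "busy", "args": list("abcdef"),
         "location": (5, 0), "body": []},
    ],
}
print(lint(module, []))  # [((5, 0), 'too many parameters in function')]
```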

Visualizing ASTs makes them much less mysterious.

In Python

When executing Python code with CPython:

  1. the source code is parsed into an AST (a tree of Python objects where nodes are instances of classes like ast.FunctionDef or ast.Call)
  2. the AST is compiled into code objects (sometimes cached on disk as .pyc files)
  3. the compiled code is executed by the bytecode interpreter

When people say "the Python interpreter", they usually refer to this whole pipeline.

[Figure: Source code -> (Parser) -> AST -> (Compiler) -> Code objects -> (Bytecode interpreter) -> Output]

Python uses the AST as an intermediate step when compiling code

The Python standard library directly exposes this process:

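import ast

# 1. parse the source text into an AST
tree = ast.parse("a = 1 + 2; print(a)")
# 2. compile the AST into a code object
code_object = compile(tree, filename="<ast>", mode="exec")
# 3. run the code object with the bytecode interpreter (prints 3)
exec(code_object)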

This direct access to the AST means we can modify the tree before calling compile, changing the program before it runs. These modifications are known as AST transforms.
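As a minimal sketch of that idea, the tree can be edited by hand between the parse and compile steps. The variable names here are illustrative only, and swapping the operator is just the smallest change that visibly alters the result.

```python
import ast

tree = ast.parse("a = 1 + 2")

# tree.body[0] is the Assign node; its value is the BinOp for 1 + 2.
binary_operation = tree.body[0].value
binary_operation.op = ast.Mult()  # swap + for *

namespace = {}
exec(compile(tree, filename="<ast>", mode="exec"), namespace)
print(namespace["a"])  # 2, because the program now computes 1 * 2
```

The source string never changed, yet the executed program did.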

AST Transforms

AST transforms make a program behave as if its source code had been written differently. A fun example is what pytest does to display nice error messages when an assert fails.

A real-world example: pytest assertions

Python's default behavior on an uncaught AssertionError is to print a traceback:

Traceback (most recent call last):
  File "/home/laurent/test/assert.py", line 2, in <module>
    assert a == b
AssertionError

Want to know the values of a and b? Add print statements and run the code again.

Running the same code under pytest shows the values directly:

>       assert a == b
E       assert 1 == 2

assert.py:3: AssertionError

It is a little magical: we did not change the code, we just ran pytest test_module.py instead of python test_module.py. Python is still executing the code, but when running under pytest, instead of:

assert a == b

Python executes something like:

try:
    assert a == b
except AssertionError:
    [pytest-generated code to display a nice error message]
    raise

Pytest does that without modifying the source code (your files on disk), by transforming the AST before Python gets to execute it.

[Figure, two pipelines side by side]
Regular Python: Source code -> (Python parser) -> AST -> (Compiler + Bytecode interpreter) -> Output
With pytest: Source code -> (Python parser) -> AST -> (pytest transform) -> Transformed AST -> (Compiler + Bytecode interpreter) -> Output

Pytest loads test files, transforms their AST, and passes the result to Python

The pytest code for this uses the parse/transform/compile pipeline we described (and is even easier to read than the diagram):

tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)

So all that remains is to write the rewrite_asserts transform logic.

Mechanics of an AST transform

An AST transform takes a tree as input and modifies it. ast.NodeTransformer is a helper class that traverses the tree for us.

An AST is made of many different node types (ast.Assign, ast.Call, ast.Name and many more). NodeTransformer walks the tree and calls the visit_<node_type> methods when they exist.

The example from the official documentation replaces every variable (name lookup) with data["variable_name"]:

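from ast import Constant, Load, Name, NodeTransformer, Subscript


class RewriteName(NodeTransformer):

    def visit_Name(self, node):
        # Replace the name lookup `foo` with the subscript `data['foo']`,
        # keeping the original context (Load or Store).
        return Subscript(
            value=Name(id='data', ctx=Load()),
            slice=Constant(value=node.id),
            ctx=node.ctx
        )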

This turns b = a + 1 into data['b'] = data['a'] + 1. Not too useful in practice, but it shows NodeTransformer in action:
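The claim can be checked by running the transformer end to end. The sketch below repeats the class so that it runs on its own, and assumes Python 3.9 or later (for ast.unparse and for passing a Constant directly as the slice). The contents of the data dictionary are made up for the demo.

```python
import ast


class RewriteName(ast.NodeTransformer):
    def visit_Name(self, node):
        return ast.Subscript(
            value=ast.Name(id="data", ctx=ast.Load()),
            slice=ast.Constant(value=node.id),
            ctx=node.ctx,
        )


tree = ast.parse("b = a + 1")
# The new nodes carry no line numbers; fix_missing_locations fills them in
# so that the tree can be compiled.
new_tree = ast.fix_missing_locations(RewriteName().visit(tree))

rewritten_source = ast.unparse(new_tree)
print(rewritten_source)  # data['b'] = data['a'] + 1

namespace = {"data": {"a": 41}}
exec(compile(new_tree, filename="<ast>", mode="exec"), namespace)
print(namespace["data"]["b"])  # 42
```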

  • it walks the syntax tree and, for each node, calls the matching visit_<node_type> method when one is defined
  • returning a new node replaces the original in the tree; returning None deletes it
  • nodes without a specific method are handled by generic_visit, which visits their children; a visit_<node_type> method must call generic_visit explicitly if the children of the node it handles should be processed too
  • by default, nodes are left untouched
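The deletion and generic_visit rules are easier to see in a small experiment. The transformer below is a hypothetical example written for this purpose (StripAsserts and the unchecked_ prefix are invented names): it renames functions and removes assert statements.

```python
import ast


class StripAsserts(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        node.name = "unchecked_" + node.name
        # Without this call the function body would not be traversed,
        # and visit_Assert would never run for the assert inside it.
        self.generic_visit(node)
        return node

    def visit_Assert(self, node):
        return None  # returning None removes the statement from its parent


source = """
def increment(x):
    assert x >= 0
    return x + 1
"""
tree = ast.fix_missing_locations(StripAsserts().visit(ast.parse(source)))
stripped_source = ast.unparse(tree)
print(stripped_source)

namespace = {}
exec(compile(tree, filename="<ast>", mode="exec"), namespace)
print(namespace["unchecked_increment"](-1))  # 0: the assert is gone
```

Commenting out the generic_visit call leaves the assert in place, since the traversal then stops at the function definition.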

Building a pytest-like assert transformer

Say we want to write a transformer that converts:

assert a == b

into:

try:
    assert a == b
except AssertionError:
    raise AssertionError(f"a == b failed\na = {a}\nb = {b}")

This achieves a behavior similar to pytest: it improves on Python's assert by showing us the values of variables in the error message.

General advice for writing such a transformer:

  • We don't need to know all the node types ahead of time. It's easy to pick them up as needed.
  • It helps a lot to visualize the source and target ASTs, with a web tool or ast.dump.
  • ast.unparse can be used to convert an AST to source code: useful to test the transformer.
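The last two tips can be tried directly on our source and target snippets. This is a sketch assuming Python 3.9 or later (the indent argument of ast.dump and ast.unparse both need it); the exact dump text varies between Python versions, so it is printed rather than quoted here.

```python
import ast

source = "assert a == b"
# Raw string, so that \n stays an escape sequence inside the f-string.
target = r'''
try:
    assert a == b
except AssertionError:
    raise AssertionError(f"a == b failed\na = {a}\nb = {b}")
'''

source_tree = ast.parse(source)
target_tree = ast.parse(target)

# indent= makes the dump readable, one field per line.
print(ast.dump(source_tree, indent=2))
print(ast.dump(target_tree, indent=2))

# ast.unparse goes the other way, from a tree back to source code.
print(ast.unparse(source_tree))  # assert a == b
```

Reading the second dump from the top shows which node types the transformer will have to build.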

Some moderately ugly visualizations of the source and target ASTs:

[Figure: source Assert node, with its test child]

[Figure: target node we want to replace the Assert node with. A Try node with children body[0] and handlers[0]; handlers[0] leads through body[0], exc and args[0] to a JoinedStr subtree (collapsed)]

Comparing the two code snippets and their trees:

  • We want to replace each Assert node with a Try node that contains the original Assert node and has an exception handler with a custom Raise node.
  • We'll need to construct a few new AST nodes like JoinedStr, Constant or FormattedValue.
  • We'll want the a == b part of assert a == b as a string literal, to include in our custom error message. The assert node exposes this expression via the test attribute, which we can turn back into a code string with ast.unparse.

Along the way we will need to collect the variables used in the Assert node (a and b in our example) so they can be included in the error message. Variables are represented as Name nodes, where the id attribute is the variable name.
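Both ingredients can be tried in isolation before writing the transformer. Note that this sketch gathers the Name nodes with ast.walk as a shortcut; a NodeTransformer would normally collect them from a visit_Name method instead.

```python
import ast

assert_node = ast.parse("assert a == b").body[0]

# The asserted condition, turned back into text.
test_as_text = ast.unparse(assert_node.test)
print(test_as_text)  # a == b

# ast.walk yields every node below assert_node; the set removes duplicates
# and sorted() gives a stable order.
names = sorted({n.id for n in ast.walk(assert_node) if isinstance(n, ast.Name)})
print(names)  # ['a', 'b']
```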

Here's an implementation:

test_runner.py

import ast


class RewriteAssertNodeTransformer(ast.NodeTransformer):

    def __init__(self):
        # Stack of lists: one list of Name nodes per Assert currently being visited
        self.assert_stack = []

    def visit_Assert(self, node: ast.Assert):
        self.assert_stack.append([])
        # Visit the children so that visit_Name records the names used in this assert
        self.generic_visit(node)
        name_nodes = self.assert_stack.pop()

        assertion_test_as_text = ast.unparse(node.test)
        assertion_msg_parts = get_assertion_message_parts(assertion_test_as_text, name_nodes)
        assertion_error_message = ast.JoinedStr(values=assertion_msg_parts)

        except_handler = ast.ExceptHandler(
            type=ast.Name(id="AssertionError", ctx=ast.Load()),
            body=[ast.Raise(ast.Call(ast.Name(id="AssertionError", ctx=ast.Load()), [assertion_error_message], []))],
        )

        return ast.Try(
            body=[node],
            handlers=[except_handler],
            orelse=[],
            finalbody=[],
        )

    def visit_Name(self, node: ast.Name):
        # Only record names while we are inside an Assert node
        if self.assert_stack:
            self.assert_stack[-1].append(node)
        return node


def get_assertion_message_parts(assertion_test_as_text, name_nodes):
    assertion_msg_parts = [ast.Constant(assertion_test_as_text + " failed")]
    for name_node in name_nodes:
        assertion_msg_parts.append(ast.Constant(f"\n{name_node.id} = "))
        assertion_msg_parts.append(ast.FormattedValue(
            ast.Name(id=name_node.id, ctx=ast.Load()),
            conversion=-1,
            format_spec=None)
        )
    return assertion_msg_parts


if __name__ == "__main__":
    source = "assert a == b"
    tree = ast.parse(source)
    transformed = RewriteAssertNodeTransformer().visit(tree)
    # Newly created nodes have no line/column information; fill it in
    transformed = ast.fix_missing_locations(transformed)
    print(ast.unparse(transformed))

This program transforms the AST for assert a == b and prints code for the transformed AST. The result matches our target source code.

This could be adapted to take a file as input, compile the modified tree and exec it. This would give an interface similar to pytest: test_runner.py test_module.py.
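One possible shape for that adaptation is sketched below. The run_file helper is hypothetical, and the demo passes a plain ast.NodeTransformer() (which changes nothing) as a stand-in so that the block runs on its own; the actual runner would pass RewriteAssertNodeTransformer() instead.

```python
import ast
import os
import tempfile


def run_file(path, transformer):
    """Parse the file at `path`, apply `transformer`, then compile and execute it."""
    with open(path, encoding="utf-8") as source_file:
        source = source_file.read()
    tree = ast.parse(source, filename=path)
    tree = ast.fix_missing_locations(transformer.visit(tree))
    code_object = compile(tree, filename=path, mode="exec")
    # Run the module as if it had been started with `python path`.
    namespace = {"__name__": "__main__"}
    exec(code_object, namespace)
    return namespace


# Demo on a temporary file, with a stand-in transformer that changes nothing.
with tempfile.TemporaryDirectory() as directory:
    module_path = os.path.join(directory, "test_module.py")
    with open(module_path, "w", encoding="utf-8") as module_file:
        module_file.write("result = 1 + 2\nassert result == 3\n")
    namespace = run_file(module_path, ast.NodeTransformer())

print(namespace["result"])  # 3
```

A command-line version would read the path from sys.argv[1] and call run_file with it.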

This example is intentionally minimal to remain approachable. If this still feels a bit overwhelming, spending time looking at the ASTs in the explorer tool should help.

An interesting exercise could be to support attribute access (assert a.b == c.d) and indexing into lists (assert a[b] == c[d]). The real-world pytest transform does a lot more, like showing intermediate values of computations/function calls.

The code pytest generates

Strictly speaking, pytest does not generate Python code. As we have seen, it transforms the AST instead. But we can still use ast.unparse to get source code from the transformed AST.

Here's what it looks like for assert a == b (using pytest 7.4.3 with Python 3.11.2):

import builtins as @py_builtins
import _pytest.assertion.rewrite as @pytest_ar

@py_assert1 = a == b
if not @py_assert1:
    @py_format3 = @pytest_ar._call_reprcompare(('==',), (@py_assert1,), ('%(py0)s == %(py2)s',), (a, b)) % {
        'py0': @pytest_ar._saferepr(a) if 'a' in @py_builtins.locals() or @pytest_ar._should_repr_global_name(a) else 'a',
        'py2': @pytest_ar._saferepr(b) if 'b' in @py_builtins.locals() or @pytest_ar._should_repr_global_name(b) else 'b'}
    @py_format5 = ('' + 'assert %(py4)s') % {'py4': @py_format3}
    raise AssertionError(@pytest_ar._format_explanation(@py_format5))
@py_assert1 = None

That's a lot of code for just a == b!

Conclusion

Manipulating code as data can make for some fun and powerful tools. Like most metaprogramming techniques, it should be used with great responsibility.

Even if you rarely use the ast module directly in application code, understanding Abstract Syntax Trees makes many of the developer tools we use less mysterious.