Useful links:

These notes are based on the deep dive video given by Edward Z. Yang. He also has a blog that hosts several PyTorch internals deep dives.

One point not covered by the video: export/deployment requires no graph breaks.

Overview

TorchDynamo wants to do something similar to TorchScript and FX tracing: capture the graph, and then operate on it. In that sense it’s similar to a JIT. TorchDynamo takes an easier route than PyPy: instead of operating directly on the Python source code itself, it operates on the parsed bytecode that CPython runs. Operating on bytecode provides a simpler interface.

What does that bytecode look like? A sequence of opcodes that Python provides a disassembler for, all documented in the dis module.

In summary, TorchDynamo converts bytecode into FX graphs (possibly multiple subgraphs, split at graph breaks), and the result can be emitted as rewritten bytecode, just like a JIT.

How does it achieve that?

CPython’s C API allows setting a custom frame evaluation handler (PEP 523). A frame here means a function frame, i.e. all the stack information. So TorchDynamo basically does a cached compilation of the code: a Python bytecode interpreter written in Python symbolically interprets the code, converting the original opcodes into an FX graph. It pays special attention to Torch tensors, so operations on them are recorded into a graph that can be compiled.

What does Python bytecode look like

def toy_example(a, b):
	return a + b

is

0 LOAD_FAST       0 (a)
2 LOAD_FAST       1 (b)
4 BINARY_ADD
6 RETURN_VALUE

This is a stack machine, so we don’t need to give register numbers to BINARY_ADD. The stack machine / register machine distinction is a common theme among VMs; e.g. the JVM is a stack machine. Stack-machine code can also be thought of as postfix notation.
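You can reproduce the listing above yourself with the stdlib dis module. Note the exact opcode names vary by Python version (e.g. on 3.11+ BINARY_ADD is replaced by the generic BINARY_OP):

```python
import dis

def toy_example(a, b):
    return a + b

# Print one line per instruction: offset, opcode name, argument repr
for ins in dis.get_instructions(toy_example):
    print(ins.offset, ins.opname, ins.argrepr)
```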

An example of how LOAD_FAST is handled: (for source, see here)

    def LOAD_FAST(self, inst):
        name = inst.argval
 
        if self.exec_recorder and name in self.f_locals:
            self.exec_recorder.add_local_var(name, self.f_locals[name])
 
        try:
            self.push(self.symbolic_locals[name].unwrap())
        except KeyError:
            if name.startswith("."):
                try:
                    # This happens in dict/list comprehensions
                    self.push(self.symbolic_locals[name.replace(".", "implicit")])
                except KeyError:
                    unimplemented("undefined LOAD_FAST (implicit)")
            else:
                unimplemented("undefined LOAD_FAST")
 
        # for continuation functions
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

You can see the important part is really self.push, which pushes the value onto a stack that the translator maintains itself.

    def push(self, val: Optional[VariableTracker]):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)  # type: ignore[arg-type]
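Putting the pieces together, here is a toy interpreter for the four opcodes of toy_example above, assuming a plain Python list as the stack (analogous to the translator’s self.stack, but evaluating concretely rather than with VariableTrackers):

```python
# Toy concrete interpreter for: LOAD_FAST a; LOAD_FAST b; BINARY_ADD; RETURN_VALUE
def run_toy_bytecode(a, b):
    stack = []
    stack.append(a)                  # LOAD_FAST 0 (a)
    stack.append(b)                  # LOAD_FAST 1 (b)
    rhs, lhs = stack.pop(), stack.pop()
    stack.append(lhs + rhs)          # BINARY_ADD: pop two operands, push result
    return stack.pop()               # RETURN_VALUE: pop and return top of stack
```

Dynamo does the same walk, but pushes VariableTrackers and records tensor ops into the FX graph instead of computing them.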

What about BINARY_ADD and friends? They’re implemented via this helper:

def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)
 
    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))
 
    return impl

So it pops n values from the stack, then calls call_function, which evaluates the result either concretely or symbolically and pushes it back.
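The same pattern can be re-created in isolation. This toy version (the class and names here are made up for illustration) derives an opcode handler from a plain Python function by reading its arity with inspect.signature, popping that many operands, and pushing the concrete result; Dynamo instead routes through fn_var.call_function to build graph nodes:

```python
import functools
import inspect

class ToyTranslator:
    """Minimal stand-in for InstructionTranslatorBase: just a stack."""
    def __init__(self):
        self.stack = []

    def popn(self, n):
        vals = self.stack[-n:]
        del self.stack[-n:]
        return vals

def stack_op(fn):
    nargs = len(inspect.signature(fn).parameters)

    @functools.wraps(fn)
    def impl(self):
        # Pop nargs operands, apply fn concretely, push the result
        self.stack.append(fn(*self.popn(nargs)))

    return impl

def add(a, b):
    return a + b

BINARY_ADD = stack_op(add)
```

Usage: push two operands, run the handler, and the sum replaces them on the stack.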

[! Graph break vs guard tradeoff] We can make tracing more generic either by inserting graph breaks, saying “it’s unclear how to handle this”, or by installing guards, saying “this generated bytecode is only valid if b is a constant that == 42”. If a guard fails on a later call, we need to compile a new version.
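The guard side of that tradeoff can be illustrated with a hypothetical pure-Python cache (all names here are made up): each compiled variant is specialized on the concrete value of b and stored behind a guard predicate; a call reuses a variant only while its guard holds, and otherwise triggers a recompile.

```python
# Hypothetical illustration of guards: variants are reused while their
# guard holds; a failed guard triggers "recompilation" of a new variant.
compiled_variants = []  # list of (guard_fn, compiled_fn) pairs

def compile_for(b):
    # Specialize on the concrete value of b, guarded by "b == <that value>"
    guard = lambda b_now, b_seen=b: b_now == b_seen
    compiled = lambda a, b_seen=b: a + b_seen
    compiled_variants.append((guard, compiled))
    return compiled

def call(a, b):
    for guard, compiled in compiled_variants:
        if guard(b):                 # guard still valid: reuse cached code
            return compiled(a)
    return compile_for(b)(a)         # guard failed: compile a new version
```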

Where is call_function defined, then? It depends on the variable type. E.g. for a constant dict, here.

elif name == "items":
	assert not (args or kwargs)
	if self.source:
		tx.output.guard_on_key_order.add(self.source.name())
	return TupleVariable(
		[TupleVariable([k.vt, v]) for k, v in self.items.items()]
	)
elif name == "__len__":
	assert not (args or kwargs)
	return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and arg_hashable and self.mutable_local:
	assert not kwargs and len(args) == 2
	tx.output.side_effects.mutation(self)
	self.items[Hashable(args[0])] = args[1]
	return ConstantVariable.create(None)

The video and the latest source already differ in how __getitem__ is handled, but as can be seen, Dynamo overloads all the methods currently exposed via the Python interface and does something with each of them: mutating the graph, and sometimes computing the result directly. The same goes for tensors, where all operations are recorded into the graph.
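The overall shape of that dispatch can be sketched with a toy class (a made-up stand-in for ConstDictVariable, not the real implementation): intercept a method by name, and either fold the result to a constant or record a mutation as a side effect.

```python
# Toy dispatch in the spirit of ConstDictVariable.call_method.
class ToyDictVariable:
    def __init__(self, items):
        self.items = dict(items)
        self.mutations = []          # stand-in for tx.output.side_effects

    def call_method(self, name, args):
        if name == "__len__":
            assert not args
            return len(self.items)   # folded directly to a constant
        elif name == "__setitem__":
            key, value = args
            self.mutations.append((key, value))  # record the side effect
            self.items[key] = value
            return None              # real code returns ConstantVariable(None)
        raise NotImplementedError(name)  # would become a graph break
```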