🐍 depyf: decompile python functions, from bytecode to source code!
This is used primarily to understand the bytecode produced by PyTorch 2.0 Dynamo (PT 2.0 compiler stack).
Installation
Stable release on PyPI: `pip install depyf`
Nightly code: `pip install git+https://github.com/youkaichao/depyf.git`
Usage
Simple usage:

```python
# obtain a callable object or code object
def func():
    print("hello, world!")

# import the `decompile` function
from depyf import decompile

# and decompile it into source code!
print(decompile(func))
```
Example output:

```python
def func():
    print('hello, world!')
    return None
```
The output source code is semantically equivalent to the function, but not syntactically identical. It spells out many details that are implicit in the original Python code. For example, the output above explicitly returns `None`, which is usually omitted when writing Python.
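To see the raw bytecode that `decompile` turns back into source, you can inspect the same function with the standard library's `dis` module (exact opcode names vary across Python versions):

```python
import dis

def func():
    print("hello, world!")

# List the instructions that a decompiler like depyf consumes.
for instr in dis.get_instructions(func):
    print(instr.opname, instr.argrepr)
```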
Used to understand PyTorch generated bytecode
First, run a PyTorch program through the `torch.compile` (Dynamo) stack:
```python
from typing import List
import torch
from torch import _dynamo as torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
```
Second, get compiled code and guard code from pytorch:
```python
from torch._dynamo.eval_frame import _debug_get_cache_entry_list

cache_entries = _debug_get_cache_entry_list(toy_example._torchdynamo_orig_callable.__code__)
guard, code = cache_entries[0]
```
Third, decompile the code to see how the code works:
```python
from depyf import decompile

print("guard code:")
print(decompile(guard))
print("compiled code:")
print(decompile(code))
```
Output on my computer:
```
guard code:
def guard(L):
    if not getattr(___guarded_code, 'valid'):
        return False
    else:
        _var0 = L['a']
        if not hasattr(_var0, '_dynamo_dynamic_indices') == False:
            return False
        else:
            _var1 = L['b']
            if not hasattr(_var1, '_dynamo_dynamic_indices') == False:
                return False
            elif not ___is_grad_enabled():
                return False
            elif ___are_deterministic_algorithms_enabled():
                return False
            elif not ___is_torch_function_enabled():
                return False
            elif not getattr(utils_device, 'CURRENT_DEVICE') == None:
                return False
            elif not ___check_tensors(_var0, _var1, tensor_check_names=tensor_check_names):
                return False
            else:
                return True
compiled code:
def toy_example(a, b):
    __temp_1 = __compiled_fn_0(a, b)
    x = __temp_1[0]
    if __temp_1[1]:
        return __resume_at_30_1(b, x)
    else:
        return __resume_at_38_2(b, x)
```
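As the decompiled code shows, each cache entry pairs a guard with compiled code: the compiled code runs only if its guard passes. Here is a minimal pure-Python sketch of that dispatch pattern (no torch; the names `toy_guard`, `toy_compiled`, and `dispatch` are made up for illustration):

```python
# Each cache entry is a (guard, compiled_code) pair, as in Dynamo's cache.
cache_entries = []

def toy_guard(local_vars):
    # Stands in for Dynamo's tensor checks: only accept int inputs.
    return isinstance(local_vars["x"], int)

def toy_compiled(x):
    return x * 2

cache_entries.append((toy_guard, toy_compiled))

def dispatch(x):
    # Run the first cache entry whose guard passes.
    for guard, compiled in cache_entries:
        if guard({"x": x}):
            return compiled(x)
    # Dynamo would recompile here and append a new cache entry.
    raise RuntimeError("no cache entry matched")
```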
Furthermore, we can see how the compiled subgraph works. In this case, we pass a simple `my_compiler` function as the backend compiler, so the subgraph code `__resume_at_38_2`, `__resume_at_30_1`, and `__compiled_fn_0` remains plain Python. We can dig into more details by decompiling them:
```python
print("source code of __compiled_fn_0:")
print(decompile(__compiled_fn_0._torchdynamo_orig_callable))
print("=" * 60)
print("source code of __resume_at_30_1:")
print(decompile(__resume_at_30_1))
print("=" * 60)
print("source code of __resume_at_38_2:")
print(decompile(__resume_at_38_2))
```
Output on my computer:
```
source code of __compiled_fn_0:
def forward(self, L_a_, L_b_):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1
    abs_1 = None
    truediv = l_a_ / add
    l_a_ = None
    add = None
    sum_1 = l_b_.sum()
    l_b_ = None
    lt = sum_1 < 0
    sum_1 = None
    return truediv, lt
============================================================
source code of __resume_at_30_1:
def <resume in toy_example>(b, x):
    b = b * -1
    return x * b
============================================================
source code of __resume_at_38_2:
def <resume in toy_example>(b, x):
    return x * b
```
Hopefully, by using this package, you can understand python bytecode now!
:warning: The above example should be run using a PyTorch nightly build. Some debug functions like `_debug_get_cache_entry_list` might not exist in stable releases yet.
Python Version Coverage
The following python major versions are tested:
- Python 3.12
- Python 3.11
- Python 3.10
- Python 3.9
- Python 3.8
- Python 3.7
You can see the coverage report by simply running `python python_coverage.py`.
Full Python Syntax Is Not Supported
This package is intended to help understand the bytecode generated by PyTorch, and does not aim to cover all of Python's syntax. For example, async operations like `async`/`await` are not supported.
I collected all the bytecode generated by PyTorch when benchmarking timm and huggingface transformers, and made several observations:
- No while loops (no backward jump instructions).
- try-except-finally only appears as try-finally.
- No complicated conditions like `if a and b or c or (d and e)`.
Then, I overfit the decompiler to the bytecode PyTorch generates. How? Pure labor: implement every bytecode instruction for each supported Python version, one by one. Yes, that's it.
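To give a flavor of that opcode-by-opcode labor, here is a toy sketch of the approach (the name `toy_decompile` is mine, not depyf's): walk the instruction stream with `dis`, keep a stack of expression strings, and handle each opcode, with per-version cases:

```python
import dis

def toy_decompile(code):
    """Toy stack-machine decompiler for a single expression + return.

    A real decompiler like depyf must handle every opcode of every
    supported Python version; this sketch covers only a handful.
    """
    stack = []
    for instr in dis.get_instructions(code):
        op = instr.opname
        if op in ("RESUME", "PRECALL"):       # bookkeeping opcodes (3.11+)
            continue
        elif op == "LOAD_CONST":
            stack.append(repr(instr.argval))
        elif op == "LOAD_FAST":
            stack.append(instr.argval)
        elif op == "LOAD_FAST_LOAD_FAST":     # superinstruction (3.13+)
            stack.extend(instr.argval)
        elif op == "BINARY_ADD":              # <= 3.10
            rhs, lhs = stack.pop(), stack.pop()
            stack.append(f"{lhs} + {rhs}")
        elif op == "BINARY_OP":               # >= 3.11; argrepr is the operator
            rhs, lhs = stack.pop(), stack.pop()
            stack.append(f"{lhs} {instr.argrepr} {rhs}")
        elif op == "RETURN_VALUE":
            return f"return {stack.pop()}"
        elif op == "RETURN_CONST":            # 3.12/3.13
            return f"return {instr.argval!r}"
        else:
            raise NotImplementedError(op)
    raise ValueError("no return instruction found")

print(toy_decompile((lambda a, b: a + b).__code__))  # prints: return a + b
```

The real work is exactly this, scaled up: loops, conditionals, calls, exception tables, and every opcode change between Python versions.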
Contributions are welcome!
If you find any error in the decompilation, feel free to open issues or pull requests to fix it!