
Jax Decompiler


The JAX decompiler takes jaxpr code and produces Python code. Some information about the original function is lost in the process (for example, variable names, so the output resembles obfuscated code), but it is still an important tool for reverse-engineering. Decompiling gradient functions is useful in many applications.

Associated PR: https://github.com/google/jax/issues/13398

Requirements

jax==0.4.0

jaxlib==0.4.0

Installation

pip3 install JaxDecompiler

Usage example

Given any function that JAX represents as jaxpr, here "df" (the gradient of f), we want to generate the equivalent Python code.

import jax

def f(x, smooth_rate):
    local_minimums = (1 - smooth_rate) * jax.numpy.cos(x)
    global_minimum = smooth_rate * x**2
    return global_minimum + local_minimums


df = jax.grad(f, (0,))

Under the hood, JAX represents df as jaxpr code. You can display it with:

from JaxDecompiler import decompiler

decompiler.display_wrapped_jaxpr(df, (1.0, 1.0))

prints:

===== HEADER =======
invars: [a, b]
outvars: [p]
constvars: []
===== CODE =======
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = sub 1.0 b
    d:f32[] = cos a
    e:f32[] = sin a
[...]
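The same jaxpr can also be inspected with plain JAX, without the decompiler, using `jax.make_jaxpr`. This is a minimal sketch; the printed text varies between JAX versions, so only the structure (two inputs, one output) is relied on here.

```python
import jax

def f(x, smooth_rate):
    local_minimums = (1 - smooth_rate) * jax.numpy.cos(x)
    global_minimum = smooth_rate * x**2
    return global_minimum + local_minimums

df = jax.grad(f, (0,))

# jax.make_jaxpr traces df on example arguments and returns a ClosedJaxpr
closed_jaxpr = jax.make_jaxpr(df)(1.0, 1.0)
print(closed_jaxpr)

# The underlying Jaxpr exposes its input variables, output variables and equations
jaxpr = closed_jaxpr.jaxpr
print(len(jaxpr.invars), len(jaxpr.outvars))  # 2 inputs, 1 output
print([eqn.primitive.name for eqn in jaxpr.eqns])
```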

The code below decompiles it automatically. It returns both the generated Python function and its source code as text.

from JaxDecompiler import decompiler

decompiled_df, python_code = decompiler.python_jaxpr_python(
    df, (1.0, 1.0), is_python_returned=True
)
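Conceptually, decompilation emits one Python assignment per jaxpr equation. The sketch below is a hypothetical, simplified illustration of that idea, not JaxDecompiler's actual implementation: the names `PRIMITIVES` and `decompile` are invented here, and the equations are the three visible in the jaxpr printed above, copied by hand.

```python
# Hypothetical sketch: one function per jaxpr primitive, named after it,
# returning the equivalent Python expression as a string.
def sub(x, y):
    return f"{x} - {y}"

def mul(x, y):
    return f"{x} * {y}"

def add(x, y):
    return f"{x} + {y}"

def cos(x):
    return f"cos({x})"

def sin(x):
    return f"sin({x})"

# Look up a translator by the primitive name that appears in the jaxpr text
PRIMITIVES = {fn.__name__: fn for fn in (sub, mul, add, cos, sin)}

def decompile(invars, equations, outvars, name="f"):
    """Build Python source from (outvar, primitive, args) triples."""
    lines = [f"def {name}({', '.join(invars)}):"]
    for outvar, primitive, args in equations:
        lines.append(f"    {outvar} = {PRIMITIVES[primitive](*args)}")
    lines.append(f"    return {', '.join(outvars)}")
    return "\n".join(lines)

# The first three equations of the jaxpr shown above, copied by hand
equations = [
    ("c", "sub", ["1.0", "b"]),
    ("d", "cos", ["a"]),
    ("e", "sin", ["a"]),
]
source = decompile(["a", "b"], equations, ["c", "d", "e"], name="prefix")
print(source)
```

A real decompiler would read the equations from the traced jaxpr instead of a hand-written list, and would also need to handle primitive parameters (such as the exponent of a power).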

Let's check that df and decompiled_df behave the same:

print("df: ", df(4.0, 0.99)) # ~7.927568
print("decompiled df: ", decompiled_df(4.0, 0.99))  # ~7.927568

They produce the same result even though one is expressed in jaxpr and the other in Python!
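As an independent cross-check in plain Python (not part of JaxDecompiler), the derivative of f with respect to x is 2 * smooth_rate * x - (1 - smooth_rate) * sin(x). Both this closed form and a central finite difference of f give the same ~7.927568 at (4.0, 0.99). The helper names below are illustrative.

```python
import math

def f(x, smooth_rate):
    # Same function as above, written with the standard math module
    return smooth_rate * x**2 + (1 - smooth_rate) * math.cos(x)

def df_analytic(x, smooth_rate):
    # d/dx [smooth_rate * x**2 + (1 - smooth_rate) * cos(x)]
    return 2 * smooth_rate * x - (1 - smooth_rate) * math.sin(x)

def df_central_difference(x, smooth_rate, h=1e-6):
    # Central finite-difference approximation of the derivative in x
    return (f(x + h, smooth_rate) - f(x - h, smooth_rate)) / (2 * h)

print(df_analytic(4.0, 0.99))            # ~7.927568
print(df_central_difference(4.0, 0.99))  # ~7.927568
```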

Now let's display what is inside decompiled_df:

print(python_code)

Output:

def f(a, b):
    c = 1.0 - b
    d = cos(a)
    e = sin(a)
    f = c * d
    g = a ** 2
    h = a ** 1
    i = 2.0 * h
    j = b * g
    _ = j + f
    k = c * 1.0
    l = -k
    m = l * e
    n = b * 1.0
    o = n * i
    p = m + o
    return p

Now the user owns the derivative code and can easily refactor or edit it. Because this is a reverse-engineering tool, the user can now, for example, improve arithmetic stability or optimize the code by hand.
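For example, the generated listing can be simplified by hand. The sketch below assumes `cos` and `sin` are bound to Python's `math` module so that the listing runs on its own with Python floats; `df_refactored` is a hypothetical hand-written name, not something the decompiler produces.

```python
from math import cos, sin

# Decompiled listing from above, unchanged except for the math imports
def f(a, b):
    c = 1.0 - b
    d = cos(a)
    e = sin(a)
    f = c * d
    g = a ** 2
    h = a ** 1
    i = 2.0 * h
    j = b * g
    _ = j + f
    k = c * 1.0
    l = -k
    m = l * e
    n = b * 1.0
    o = n * i
    p = m + o
    return p

def df_refactored(x, smooth_rate):
    # Hand-simplified: unused values dropped, "* 1.0" and "** 1" folded away,
    # leaving d/dx = 2 * s * x - (1 - s) * sin(x)
    return 2.0 * smooth_rate * x - (1.0 - smooth_rate) * sin(x)

for x, s in [(4.0, 0.99), (1.0, 1.0), (-2.5, 0.3)]:
    print(f(x, s), df_refactored(x, s))
```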

Note: python_jaxpr_python creates an out/ folder in the current directory.

Next steps

The planned next steps are:

  • More operators. More than 70 jaxpr operators are currently implemented ('add', 'mul', 'cos', ...). The exhaustive list is in the file "primitive_mapping.py", which maps each jaxpr operator (the name of a function) to Python code (the string returned by that function).

  • Automatic refactoring. The automatically produced Python code could be made easier to read and maintain. An automatic refactoring tool should be able to translate this low-level Python style into a more human-readable one.

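  • Automatic detection of useless code. In the example above, the variable "j" is useless: it only feeds "_ = j + f", which recomputes the value of the original function and is never used by the gradient. The same holds for "d", "f" and "g".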
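A minimal sketch of such a pass, assuming straight-line, side-effect-free code like the generated listing, uses Python's standard `ast` module (Python 3.9+ for `ast.unparse`): it scans the statements backwards and keeps an assignment only if a later kept statement reads its target. `eliminate_dead_code` is a hypothetical helper, not part of JaxDecompiler.

```python
import ast

# Generated listing from the example above (only parsed, never executed here)
SOURCE = """
def f(a, b):
    c = 1.0 - b
    d = cos(a)
    e = sin(a)
    f = c * d
    g = a ** 2
    h = a ** 1
    i = 2.0 * h
    j = b * g
    _ = j + f
    k = c * 1.0
    l = -k
    m = l * e
    n = b * 1.0
    o = n * i
    p = m + o
    return p
"""

def eliminate_dead_code(source: str) -> str:
    """Drop assignments whose targets are never read later (straight-line code only)."""
    module = ast.parse(source)
    func = module.body[0]
    live = set()  # names read by the statements kept so far (scanning backwards)
    kept = []
    for stmt in reversed(func.body):
        if isinstance(stmt, ast.Assign) and all(isinstance(t, ast.Name) for t in stmt.targets):
            targets = {t.id for t in stmt.targets}
            if not targets & live:
                continue  # nothing later reads this value: dead assignment
            live -= targets
        # Every name loaded by a kept statement becomes live
        live |= {n.id for n in ast.walk(stmt) if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)}
        kept.append(stmt)
    func.body = list(reversed(kept))
    return ast.unparse(module)

print(eliminate_dead_code(SOURCE))
```

On this listing the pass removes the assignments to "d", "f", "g", "j" and "_", keeping the ten statements that contribute to the returned gradient.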

Download files

Source Distribution

JaxDecompiler-0.0.7.tar.gz (18.0 kB)

Uploaded Source

Built Distribution

JaxDecompiler-0.0.7-py3-none-any.whl (12.1 kB)

Uploaded Python 3

File details

Details for the file JaxDecompiler-0.0.7.tar.gz.

File metadata

  • Download URL: JaxDecompiler-0.0.7.tar.gz
  • Upload date:
  • Size: 18.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for JaxDecompiler-0.0.7.tar.gz
Algorithm     Hash digest
SHA256        a1de9d3be7e37ace4feed4b2a0b30e693e1eaf1315c6961305ea238a252b43fa
MD5           5a9676bbaf6b81147d2cf4fa5ad14581
BLAKE2b-256   0d1cbdaafb8873caffcc78e50a9871bcb6fe556541550c66c3778c46a2f4fa04

File details

Details for the file JaxDecompiler-0.0.7-py3-none-any.whl.

File hashes

Hashes for JaxDecompiler-0.0.7-py3-none-any.whl
Algorithm     Hash digest
SHA256        3a0e6dd2718368577ee77ee77251d275d85456b03f636441f42d7d6b903698e7
MD5           40854829a0d8d461083ddeba58f00acb
BLAKE2b-256   642fa1c134dacd1e6d702edf799fb2fd61b0be0d89fe91c2422102a571160579
