deep learning with reverse derivatives

catgrad

You like category theory? You like tinygrad? You love catgrad! ❤️

catgrad is a bit different: instead of using autograd to train, you compile your model's reverse pass into static code. This means your training loop can run without a deep learning framework (not even catgrad!).

Here is a linear model in catgrad:

model = layers.linear(BATCH_TYPE, INPUT_TYPE, OUTPUT_TYPE)

catgrad can compile this model...

CompiledModel, _, _ = compile_model(model, layers.sgd(learning_rate), layers.mse)

... into static code like this...

class CompiledModel:
    backend: ArrayBackend

    def predict(self, x1, x0):
        x2 = x0 @ x1
        return [x2]

    def step(self, x0, x1, x9):
        x4, x10 = (x0, x0)
        x11, x12 = (x1, x1)
        x16 = self.backend.constant(0.0001, Dtype.float32)
        # ... snip ...
        x18 = x17 * x5
        x2 = x10 - x18
        return [x2]

... so you can train your model by just iterating step; no autograd needed:

for i in range(0, NUM_ITER):
    p = step(p, x, y)
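
To make the idea concrete, here is a hand-written sketch of what a framework-free training loop looks like for a linear model. This is not catgrad output: the function bodies, the NumPy backend, the learning rate, the MSE scaling, and the synthetic data are all assumptions chosen for illustration. It only shows that once the reverse pass is written out as plain array code, training is just repeated calls to step.

```python
import numpy as np

def predict(p, x):
    # Forward pass of a linear model: parameters p, inputs x.
    return x @ p

def step(p, x, y, learning_rate=0.1):
    # Hand-derived reverse pass (assumed loss: mean squared error).
    y_hat = x @ p
    grad_y_hat = 2.0 * (y_hat - y) / y.size   # d(mse)/d(y_hat)
    grad_p = x.T @ grad_y_hat                 # pull the gradient back through x @ p
    return p - learning_rate * grad_p         # plain SGD update

# Synthetic regression data (illustrative only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
true_p = np.array([[1.0], [-2.0], [0.5]])
y = x @ true_p

p = np.zeros((3, 1))
loss_before = float(np.mean((predict(p, x) - y) ** 2))

NUM_ITER = 500
for i in range(0, NUM_ITER):
    p = step(p, x, y)

loss_after = float(np.mean((predict(p, x) - y) ** 2))
print(loss_before, loss_after)
```

Nothing in the loop depends on an autograd engine; the only dependency is the array library that plays the role of the backend.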

catgrad isn't meant to compile only to Python: I'm working on support for other targets like C++ (ggml), CUDA, FPGAs, and more!

catgrad uses reverse derivatives and open hypergraphs to transform a model into its backward pass. For details, see this paper.
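
As a rough illustration of what a reverse derivative is, the sketch below writes one out by hand for matrix multiplication and checks it against finite differences. The function names matmul and matmul_rev are hypothetical and are not part of catgrad; the point is only that the reverse derivative of an operation takes the original inputs plus an output-side gradient dy and returns input-side gradients.

```python
import numpy as np

def matmul(x, w):
    # Forward map f(x, w) = x @ w.
    return x @ w

def matmul_rev(x, w, dy):
    # Reverse derivative of f: maps (x, w, dy) to gradients for x and w.
    return dy @ w.T, x.T @ dy

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3, 2))
dy = rng.normal(size=(4, 2))

dx, dw = matmul_rev(x, w, dy)

# Finite-difference check of dw: differentiate sum(dy * f(x, w)) w.r.t. each w[i, j].
eps = 1e-6
dw_fd = np.zeros_like(w)
for i in range(w.shape[0]):
    for j in range(w.shape[1]):
        e = np.zeros_like(w)
        e[i, j] = eps
        dw_fd[i, j] = np.sum(dy * (matmul(x, w + e) - matmul(x, w - e))) / (2 * eps)

print(np.allclose(dw, dw_fd, atol=1e-6))
```

Composing such per-operation reverse derivatives along the structure of the model is what yields a complete backward pass.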

Install

pip install catgrad

Examples

Train simple MLPs for the iris dataset:

./data/get-iris-data.sh
python3 -m examples.iris (linear|simple|dense|hidden)

Compilation Targets

Target backends we plan to support soon:


Source Distribution

catgrad-0.2.2.tar.gz (18.4 kB)

Uploaded Source

Built Distribution

catgrad-0.2.2-py3-none-any.whl (26.9 kB)

Uploaded Python 3
