
deep learning with reverse derivatives

Project description

catgrad

You like category theory? You like tinygrad? You love catgrad! ❤️

catgrad is a bit different: instead of using autograd to train, you compile your model's reverse pass into static code. This means your training loop can run without needing a deep learning framework (not even catgrad!).

Here is a linear model in catgrad:

model = layers.linear(BATCH_TYPE, INPUT_TYPE, OUTPUT_TYPE)

catgrad can compile this model...

CompiledModel, _, _ = compile_model(model, layers.sgd(learning_rate), layers.mse)

... into static code like this...

class CompiledModel:
    backend: ArrayBackend

    def predict(self, x1, x0):
        x2 = x0 @ x1
        return [x2]

    def step(self, x0, x1, x9):
        x4, x10 = (x0, x0)
        x11, x12 = (x1, x1)
        x16 = self.backend.constant(0.0001, Dtype.float32)
        # ... snip ...
        x18 = x17 * x5
        x2 = x10 - x18
        return [x2]

... so you can train your model by just iterating step; no autograd needed:

for _ in range(NUM_ITER):
    p = step(p, x, y)
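For intuition, here is a minimal hand-written NumPy sketch of the kind of computation a compiled step performs for the linear model above. It is not catgrad's generated code: the function names, the argument order (p, x, y), the loss 0.5 * sum((x @ p - y) ** 2) and the learning rate are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical hand-written analogue of a compiled `step` for a linear model
# trained with SGD on the loss 0.5 * sum((x @ p - y) ** 2).
# Only plain array arithmetic is used: no autograd, no deep learning framework.

def predict(p: np.ndarray, x: np.ndarray) -> np.ndarray:
    # p: (INPUT, OUTPUT) weights, x: (BATCH, INPUT) inputs
    return x @ p

def step(p: np.ndarray, x: np.ndarray, y: np.ndarray, lr: float = 0.01) -> np.ndarray:
    y_hat = predict(p, x)   # forward pass
    err = y_hat - y         # gradient of the loss w.r.t. y_hat
    grad = x.T @ err        # reverse pass through x @ p, w.r.t. p
    return p - lr * grad    # SGD update

# Synthetic regression problem with known weights
rng = np.random.default_rng(0)
true_p = rng.normal(size=(4, 3))
x = rng.normal(size=(32, 4))
y = x @ true_p

# Training is just iterating `step`
p = np.zeros((4, 3))
for _ in range(1000):
    p = step(p, x, y)
```

Because the reverse pass is written out as ordinary array operations, the same loop could be ported to any target that supports matrix multiplication and subtraction.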

Catgrad doesn't just compile to Python: I'm working on support for other targets like C++ (ggml), CUDA, FPGAs, and more!

Catgrad uses reverse derivatives and open hypergraphs to transform a model into its backwards pass. For details, see this paper.
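The core idea of reverse derivatives can be sketched in a few lines of NumPy. A reverse derivative sends an input x and an output change dy to J_f(x)^T dy, and reverse derivatives compose by the chain rule R[g ∘ f](x, dz) = R[f](x, R[g](f(x), dz)). This is a standalone illustration, not catgrad's implementation: the helpers linear, tanh and compose are hypothetical, and catgrad itself works on open hypergraphs rather than Python closures.

```python
import numpy as np

# Hypothetical sketch: each primitive is a (forward, reverse) pair, where
# reverse(x, dy) computes J_f(x)^T dy. Composite reverse passes are built
# from the chain rule alone, with no tape or autograd.

def linear(A):
    fwd = lambda x: A @ x
    rev = lambda x, dy: A.T @ dy
    return fwd, rev

def tanh():
    fwd = np.tanh
    rev = lambda x, dy: (1.0 - np.tanh(x) ** 2) * dy
    return fwd, rev

def compose(f, g):
    # "f then g": R[g . f](x, dz) = R[f](x, R[g](f(x), dz))
    f_fwd, f_rev = f
    g_fwd, g_rev = g
    fwd = lambda x: g_fwd(f_fwd(x))
    rev = lambda x, dz: f_rev(x, g_rev(f_fwd(x), dz))
    return fwd, rev

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
h_fwd, h_rev = compose(linear(A), tanh())

x = rng.normal(size=4)
dz = rng.normal(size=3)
analytic = h_rev(x, dz)

# Check against central finite differences of the scalar dz . h(x)
eps = 1e-6
numeric = np.array([
    (dz @ h_fwd(x + eps * e) - dz @ h_fwd(x - eps * e)) / (2 * eps)
    for e in np.eye(4)
])
```

Since the reverse pass of the composite is assembled from the reverse passes of its parts, it can be built once ahead of time and emitted as static code.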

Install

pip install catgrad

Examples

Train simple MLPs on the iris dataset:

./data/get-iris-data.sh
python3 -m examples.iris (linear|simple|dense|hidden)

Compilation Targets

Target backends we plan to support soon:



Download files


Source Distribution

catgrad-0.2.1.tar.gz (29.4 kB)

Built Distribution

catgrad-0.2.1-py3-none-any.whl (38.0 kB)

File details

Details for the file catgrad-0.2.1.tar.gz.

File metadata

  • Download URL: catgrad-0.2.1.tar.gz
  • Upload date:
  • Size: 29.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.8

File hashes

Hashes for catgrad-0.2.1.tar.gz
Algorithm Hash digest
SHA256 af46f9f521bb59923367c40ddaf8f9342d0d208b329cc146685036e952802313
MD5 0d5240c3ca67c26d2e5dbabecce730a5
BLAKE2b-256 bc1432f890203caa704109254ec6bf145d8b0ace97a686c2247d64bcbbd90707


File details

Details for the file catgrad-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: catgrad-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 38.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.11.8

File hashes

Hashes for catgrad-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 5393b18150a5d7d071e4439d7ca175c6312ca963796a2666aee13ccaec1a3f2c
MD5 22cbde722e2d41f390c6f30f9ae21eaf
BLAKE2b-256 e79f48a8905f53bc8a4db620c7954781da6de351cf09e680f5b0b637d68c43c0

