
deep learning with reverse derivatives


catgrad

You like category theory? You like tinygrad? You love catgrad! ❤️

catgrad is a bit different: instead of using autograd to train, you compile your model's reverse pass into static code. This means your training loop can run without a deep learning framework (not even catgrad!).

Here is a linear model in catgrad:

model = layers.linear(BATCH_TYPE, INPUT_TYPE, OUTPUT_TYPE)

catgrad can compile this model...

CompiledModel, _, _ = compile_model(model, layers.sgd(learning_rate), layers.mse)

... into static code like this...

class CompiledModel:
    backend: ArrayBackend

    def predict(self, x1, x0):
        x2 = x0 @ x1
        return [x2]

    def step(self, x0, x1, x9):
        x4, x10 = (x0, x0)
        x11, x12 = (x1, x1)
        x16 = self.backend.constant(0.0001, Dtype.float32)
        # ... snip ...
        x18 = x17 * x5
        x2 = x10 - x18
        return [x2]
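For the linear model with a squared-error loss and SGD, the elided body of step comes down to one gradient-descent update of the weight matrix. The NumPy sketch below writes that update out by hand to show what such static code computes. It is not catgrad's output. It assumes the following: the argument order (p, x, y) and the x @ p convention are read off the generated code above, 0.0001 is taken to be the learning rate, and the loss is taken to be 0.5 * sum((y_hat - y)**2), so catgrad's layers.mse may scale the gradient differently. Unlike the generated code, this step returns the new parameters directly rather than wrapped in a list.

```python
import numpy as np

# Assumed learning rate: presumably what the constant x16 = 0.0001 is in the generated code
LEARNING_RATE = 0.0001


def predict(p, x):
    # Forward pass of the linear layer: (batch, in) @ (in, out) -> (batch, out)
    return x @ p


def step(p, x, y):
    # Forward pass
    y_hat = x @ p
    # Reverse derivative of the assumed loss 0.5 * sum((y_hat - y)**2) w.r.t. y_hat
    dy = y_hat - y
    # Reverse derivative of x @ p with respect to p: x^T @ dy
    dp = x.T @ dy
    # SGD update: move the parameters against the gradient
    return p - LEARNING_RATE * dp


# Usage: recover a known weight matrix from synthetic data
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4))
true_p = rng.normal(size=(4, 2))
y = x @ true_p

p = np.zeros((4, 2))
NUM_ITER = 5000
for i in range(0, NUM_ITER):
    p = step(p, x, y)
```

Every operation in step is ordinary array arithmetic, so the loop needs nothing beyond NumPy. This is the property the compiled code is meant to have.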

... so you can train your model by just iterating step; no autograd needed:

for i in range(0, NUM_ITER):
    p = step(p, x, y)

Catgrad doesn't just compile to Python: I'm working on support for other targets like C++ (ggml), CUDA, FPGAs, and more!

Catgrad uses reverse derivatives and open hypergraphs to transform a model into its backwards pass. For details, see this paper.
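To make the idea concrete: on Euclidean spaces, the reverse derivative of a smooth map f : A → B takes a point and a change in the output and returns the transposed Jacobian applied to that change. Reverse derivatives compose by a chain rule, which is what lets a whole model's backwards pass be assembled from its layers. The lines below state this standard definition and work it through for the linear model. They assume the loss ℓ(ŷ) = ½‖ŷ − y‖² and SGD with learning rate η. The notation is illustrative and not necessarily the paper's.

```latex
\begin{aligned}
R[f] &: A \times B \to A, \qquad R[f](a, \delta b) = J_f(a)^{\top}\,\delta b \\
R[g \circ f](a, \delta c) &= R[f]\big(a,\; R[g](f(a), \delta c)\big) \\[4pt]
f(p, x) &= x\,p, \qquad \ell(\hat y) = \tfrac{1}{2}\lVert \hat y - y \rVert^2 \\
R[\ell](\hat y, 1) &= \hat y - y \\
R[f]\big((p, x), \delta \hat y\big) &= \big(x^{\top}\delta \hat y,\; \delta \hat y\, p^{\top}\big) \\
p &\leftarrow p - \eta\, x^{\top}(x\,p - y)
\end{aligned}
```

The last line follows from the chain rule: feeding R[ℓ](xp, 1) = xp − y into R[f] gives the parameter gradient xᵀ(xp − y). An SGD step then subtracts η times that gradient.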

Install

pip install catgrad

Examples

Train simple MLPs on the iris dataset:

./data/get-iris-data.sh
python3 -m examples.iris (linear|simple|dense|hidden)

Compilation Targets

Target backends we plan to support soon:



Download files


Source Distribution

catgrad-0.2.1.tar.gz (29.4 kB)

Built Distribution

catgrad-0.2.1-py3-none-any.whl (38.0 kB, Python 3)
