deep learning with reverse derivatives
catgrad
You like category theory? You like tinygrad? You love catgrad! ❤️
catgrad is a bit different: instead of using autograd to train, you compile your model's reverse pass into static code. This means your training loop can run without a deep learning framework (not even catgrad!).
Here is a linear model in catgrad:
model = layers.linear(BATCH_TYPE, INPUT_TYPE, OUTPUT_TYPE)
catgrad can compile this model...
CompiledModel, _, _ = compile_model(model, layers.sgd(learning_rate), layers.mse)
... into static code like this...
class CompiledModel:
    backend: ArrayBackend

    def predict(self, x1, x0):
        x2 = x0 @ x1
        return [x2]

    def step(self, x0, x1, x9):
        x4, x10 = (x0, x0)
        x11, x12 = (x1, x1)
        x16 = self.backend.constant(0.0001, Dtype.float32)
        # ... snip ...
        x18 = x17 * x5
        x2 = x10 - x18
        return [x2]
... so you can train your model by just iterating step; no autograd needed:
for i in range(0, NUM_ITER):
    p = step(p, x, y)
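To make the idea of a framework-free training loop concrete, here is a minimal hand-written sketch in NumPy of what a static step function can look like for a linear model with mean squared error and an SGD update. It is not catgrad's generated output: the function bodies, the `LEARNING_RATE` value, and the toy data are assumptions made for illustration.

```python
import numpy as np

LEARNING_RATE = 0.1  # hypothetical value, chosen for the toy data below
NUM_ITER = 500


def predict(p, x):
    # Forward pass of a linear model: a single matrix multiply.
    return x @ p


def step(p, x, y):
    # Hand-derived reverse pass for mean squared error, then an SGD update.
    err = predict(p, x) - y             # residual, shape (batch, outputs)
    grad = 2.0 * x.T @ err / err.size   # gradient of mean(err ** 2) w.r.t. p
    return p - LEARNING_RATE * grad


# Toy regression problem with noiseless targets (illustrative data only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
true_p = np.array([[1.0], [-2.0], [0.5]])
y = x @ true_p

p = np.zeros((3, 1))
for i in range(0, NUM_ITER):
    p = step(p, x, y)
```

Nothing in the loop records a computation graph; the gradient is ordinary array arithmetic written out ahead of time, which is the property the compiled `step` above relies on.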
Catgrad doesn't just compile to Python: I'm working on support for other targets like C++ (ggml), CUDA, FPGAs, and more!
Catgrad uses reverse derivatives and open hypergraphs to transform a model into its backwards pass. For details, see this paper.
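For readers new to the term, the reverse derivative can be written out in the ordinary setting of smooth maps between Euclidean spaces. The formulas below are the standard definition in that setting, not a quotation from the paper, and the symbols (`J_f` for the Jacobian, `\delta y` for the incoming gradient) are chosen here for illustration.

```latex
% Reverse derivative of f : A -> B at a point x, applied to an output gradient \delta y
R[f](x, \delta y) = J_f(x)^{\top} \, \delta y

% Chain rule: the reverse derivative of a composite runs backwards through f after g
R[g \circ f](x, \delta z) = R[f]\big(x,\; R[g](f(x), \delta z)\big)

% Example: matrix multiplication f(x, w) = x w
R[f]\big((x, w), \delta y\big) = \big(\delta y \, w^{\top},\; x^{\top} \delta y\big)
```

The second component of the last line, the gradient with respect to the weights, is the quantity an SGD update subtracts (scaled by the learning rate) for a linear layer.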
Install
pip install catgrad
Examples
Train simple MLPs for the iris dataset:
./data/get-iris-data.sh
python3 -m examples.iris (linear|simple|dense|hidden)
Compilation Targets
Target backends we plan to support soon: