A Lightweight Neural Network Library Using Only NumPy, with a PyTorch-like API
gradipy: A Lightweight Neural Network Library
gradipy is an evolving project that aims to provide a PyTorch-like API for building and training neural networks. Some features are under active development, and others are planned for future releases.
Desired features to add:
- PyTorch-like API for the most important building blocks for training NNs
- Convolutional layers for image processing
- Recurrent layers for sequence data
- GPU acceleration (potentially)
Please note that the library is currently in its early stages, and these features are expected in future updates.
Sample Usage
Here's a basic example of using gradipy to create and train a simple neural network on MNIST (please refer to the example usage):
```python
import gradipy.nn as nn
from gradipy import datasets
from gradipy.nn import optim

X_train, y_train, X_val, y_val, X_test, y_test = datasets.MNIST()

# illustrative hyperparameters; 784 assumes flattened 28x28 images, 10 digit classes
input_dim, hidden_dim, output_dim = 784, 128, 10
lr, epochs = 1e-3, 10

# define some utility functions here, e.g. get_batch() returning a
# random mini-batch (xb, yb) drawn from X_train, y_train...

class DenseNeuralNetwork:
    def __init__(self, input_dim, hidden_dim, output_dim):
        self.W1 = nn.init_kaiming_normal(input_dim, hidden_dim, nonlinearity="relu")
        self.W2 = nn.init_kaiming_normal(hidden_dim, output_dim, nonlinearity="relu")

    def forward(self, X):
        logits = X.matmul(self.W1).relu().matmul(self.W2)
        return logits

model = DenseNeuralNetwork(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([model.W1, model.W2], lr=lr)

for epoch in range(epochs + 1):
    optimizer.zero_grad()
    xb, yb = get_batch()
    logits = model.forward(xb)
    loss = criterion(logits, yb)
    loss.backward()
    optimizer.step()
    # log the results on each epoch...
```
In this example, we define a simple two-layer feedforward network and train it on random mini-batches of MNIST with Adam and cross-entropy loss. gradipy aims to provide building blocks such as linear layers, activation functions, loss functions, and optimizers for creating and training neural networks.
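To make explicit what the sample computes, here is a plain-NumPy sketch of the same two-layer network: Kaiming-normal initialization, the ReLU forward pass, cross-entropy via log-softmax, and the hand-derived gradients an autograd engine would produce. It assumes gradipy's ops follow the standard definitions; the names `kaiming_normal`, `loss_and_grads`, and `sgd_step` are hypothetical and not part of gradipy's API, and plain SGD stands in for Adam to keep the update short.

```python
import numpy as np

def kaiming_normal(fan_in, fan_out, rng):
    # Kaiming (He) normal init for ReLU layers: std = sqrt(2 / fan_in)
    return rng.standard_normal((fan_in, fan_out)) * np.sqrt(2.0 / fan_in)

def log_softmax(z):
    # Subtract the row-wise max before exponentiating for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def loss_and_grads(X, y, W1, W2):
    """Mean cross-entropy of logits = relu(X @ W1) @ W2, plus its gradients."""
    n = X.shape[0]
    # Forward pass
    h_pre = X @ W1
    h = np.maximum(h_pre, 0.0)
    logp = log_softmax(h @ W2)
    loss = -logp[np.arange(n), y].mean()
    # Backward pass: d(loss)/d(logits) = (softmax(logits) - one_hot(y)) / n
    dlogits = np.exp(logp)
    dlogits[np.arange(n), y] -= 1.0
    dlogits /= n
    dW2 = h.T @ dlogits
    # ReLU passes the gradient through only where its input was positive
    dh_pre = (dlogits @ W2.T) * (h_pre > 0)
    dW1 = X.T @ dh_pre
    return loss, dW1, dW2

def sgd_step(X, y, W1, W2, lr):
    # Plain SGD, updating the weight arrays in place; returns the pre-update loss
    loss, dW1, dW2 = loss_and_grads(X, y, W1, W2)
    W1 -= lr * dW1
    W2 -= lr * dW2
    return loss
```

For MNIST-shaped data you would call `kaiming_normal(784, hidden_dim, rng)` and `kaiming_normal(hidden_dim, 10, rng)`, then call `sgd_step` on each mini-batch. Writing the gradient of log-softmax plus negative log-likelihood as `softmax - one_hot` in a single step is also the usual reason libraries fuse the two into one cross-entropy op.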
Feature Roadmap
Here's a list of features we plan to implement in gradipy, along with their current status:
To-Do
- Backward passes for: mul (problem with broadcasting), tanh
- Add more operations and their gradients
- Batchnorm
- Convolutional layers for image processing
- PyTorch's nn.Module
- More loss functions (nn.MSELoss and nn.NLLLoss)
- Recurrent layers for sequence data
- GPU acceleration (no idea how to do that)
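One common cause of broadcasting problems in a `mul` backward pass is that an operand broadcast in the forward pass receives a gradient with the *output's* shape; it has to be summed back over the broadcast axes to match the operand's original shape. The sketch below shows that reduction in NumPy. `unbroadcast` and `mul_backward` are hypothetical helper names, not gradipy functions, and this is one standard approach rather than the project's planned fix.

```python
import numpy as np

def unbroadcast(grad, shape):
    """Sum `grad` down to `shape`, undoing NumPy broadcasting."""
    # Sum over leading axes that broadcasting prepended
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Sum over axes where the original operand had size 1
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def mul_backward(a, b, out_grad):
    # out = a * b (possibly broadcast), so d(out)/da = b and d(out)/db = a
    grad_a = unbroadcast(out_grad * b, a.shape)
    grad_b = unbroadcast(out_grad * a, b.shape)
    return grad_a, grad_b
```

For example, with `a` of shape `(3, 4)` and `b` of shape `(4,)`, `grad_b` comes back with shape `(4,)` (the column sums of `out_grad * a`) instead of `(3, 4)`. The same helper applies to `add` and `sub` whenever their operands are broadcast.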
Done
- Basic Tensor wrapper around NumPy ndarray
- Forward and backward passes implemented for: add, matmul, softmax, relu, sub, log, exp, log softmax, cross entropy
- Autograd just like PyTorch's backward method, using topological sort
- nn.CrossEntropyLoss function
- Train MNIST with gradipy
- Kaiming, Xavier init (normal + uniform)
- Implemented Adam and added momentum to SGD
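The autograd-with-topological-sort approach can be illustrated with a toy version: each operation records its parent nodes plus a closure that pushes gradients to them, and `backward()` orders the graph with a depth-first topological sort before applying the chain rule in reverse. This is a self-contained sketch, not gradipy's actual `Tensor` class; the class layout, the `_parents`/`_backward` attributes, and the small op set are assumptions, and broadcasting is deliberately not handled.

```python
import numpy as np

class Tensor:
    """Minimal autograd node wrapping a NumPy array (illustrative only)."""

    def __init__(self, data, _parents=()):
        self.data = np.asarray(data, dtype=float)
        self.grad = np.zeros_like(self.data)
        self._parents = _parents
        self._backward = lambda: None  # pushes self.grad into the parents

    def __add__(self, other):
        out = Tensor(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def matmul(self, other):
        out = Tensor(self.data @ other.data, (self, other))
        def _backward():
            self.grad += out.grad @ other.data.T
            other.grad += self.data.T @ out.grad
        out._backward = _backward
        return out

    def relu(self):
        out = Tensor(np.maximum(self.data, 0.0), (self,))
        def _backward():
            self.grad += out.grad * (self.data > 0)
        out._backward = _backward
        return out

    def sum(self):
        out = Tensor(self.data.sum(), (self,))
        def _backward():
            self.grad += out.grad * np.ones_like(self.data)
        out._backward = _backward
        return out

    def backward(self):
        # Depth-first topological sort: each node is appended only after
        # every node it depends on
        topo, visited = [], set()
        def build(node):
            if node not in visited:
                visited.add(node)
                for parent in node._parents:
                    build(parent)
                topo.append(node)
        build(self)
        # Seed d(self)/d(self) = 1, then apply the chain rule from output to inputs
        self.grad = np.ones_like(self.data)
        for node in reversed(topo):
            node._backward()
```

Gradients are accumulated with `+=`, so a tensor used twice (as in `a + a`) correctly receives both contributions; that accumulation is also why a training loop needs a `zero_grad()` step before each `backward()`.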
Download files
Source Distribution: gradipy-0.1.0.tar.gz
Built Distribution: gradipy-0.1.0-py3-none-any.whl
File details
Details for the file gradipy-0.1.0.tar.gz.
File metadata
- Download URL: gradipy-0.1.0.tar.gz
- Upload date:
- Size: 7.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 4e763aac225a285263f6d40184c850cebd3726ad7b3bd481e132a2f7ba2d0160 |
| MD5 | dd57bb7ce46cd7975add4cebe0af6247 |
| BLAKE2b-256 | bcce1df61c37dc1129db5ef579ddcc25b543132003b6117625fde6ca15e13f66 |
File details
Details for the file gradipy-0.1.0-py3-none-any.whl.
File metadata
- Download URL: gradipy-0.1.0-py3-none-any.whl
- Upload date:
- Size: 7.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ba14a546e8fe8dc048e8bcd149c707f9f99e41a4aaaca066e6e463a8cfc645ca |
| MD5 | 8993f62e4be2692fb5e11de45bd843e7 |
| BLAKE2b-256 | 782c26eada0cccc16e3ea73a87d90d78f0ea22c56217e2aeb55b8ce18dea175f |
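The digests above can be used to check a downloaded file before installing it. A minimal sketch with Python's standard `hashlib` follows; `file_digest` and `verify` are hypothetical helper names, and only SHA256 and MD5 are covered here (BLAKE2b-256 would need `hashlib.blake2b(digest_size=32)`).

```python
import hashlib

def file_digest(path, algorithm="sha256", chunk_size=1 << 16):
    """Return the hex digest of a file, reading it in fixed-size chunks."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b""
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path, expected_hex, algorithm="sha256"):
    # Compare case-insensitively, since hex digests may be printed in either case
    return file_digest(path, algorithm).lower() == expected_hex.lower()
```

For example, `verify("gradipy-0.1.0-py3-none-any.whl", "ba14a546e8fe8dc048e8bcd149c707f9f99e41a4aaaca066e6e463a8cfc645ca")` should return `True` for an unmodified download of the wheel. pip can enforce the same check at install time through `--hash=sha256:...` entries in a requirements file.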