
A Lightweight Neural Network Library using only NumPy, with a PyTorch-like API

Project description

gradipy: A Lightweight Neural Network Library


gradipy is an evolving project that aims to provide a PyTorch-like API for building and training neural networks. Several features are under active development and planned for future releases.

Desired features to add:

  • PyTorch-like API for the most important building blocks for training NNs
  • Convolutional layers for image processing
  • Recurrent layers for sequence data
  • Potentially GPU acceleration

Please note that the library is currently in its early stages, and these features are expected in future updates.

Sample Usage

Here's a basic example of using gradipy to create and train a simple neural network on MNIST (see the project's example usage for the full script):

import gradipy.nn as nn
from gradipy import datasets
from gradipy.nn import optim

X_train, y_train, X_val, y_val, X_test, y_test = datasets.MNIST()

# define hyperparameters (input_dim, hidden_dim, output_dim, lr, epochs)
# and utility functions here; a sketch of get_batch follows this example...

class DenseNeuralNetwork:
    def __init__(self, input_dim, hidden_dim, output_dim):
        self.W1 = nn.init_kaiming_normal(input_dim, hidden_dim, nonlinearity="relu")
        self.W2 = nn.init_kaiming_normal(hidden_dim, output_dim, nonlinearity="relu")

    def forward(self, X):
        # two-layer MLP: ReLU after the first matmul, raw logits out
        logits = X.matmul(self.W1).relu().matmul(self.W2)
        return logits

model = DenseNeuralNetwork(input_dim, hidden_dim, output_dim)  # e.g. 784, 128, 10 for MNIST
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([model.W1, model.W2], lr=lr)

for epoch in range(epochs + 1):
    optimizer.zero_grad()
    xb, yb = get_batch()
    logits = model.forward(xb)
    loss = criterion(logits, yb)
    loss.backward()
    optimizer.step()

    # log the results on each epoch...
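
The example above leaves the batching utility undefined. Below is a minimal sketch of what get_batch might look like; it assumes datasets.MNIST() returns NumPy arrays and that the Tensor wrapper mentioned in the roadmap is importable as gradipy.tensor.Tensor (both are assumptions, so check the project's examples for the actual helper).

import numpy as np
from gradipy.tensor import Tensor  # assumed import path for the Tensor wrapper

batch_size = 128  # example value

def get_batch():
    # Hypothetical helper: sample a random mini-batch from the training
    # split and wrap it in gradipy tensors for the forward pass.
    idx = np.random.randint(0, X_train.shape[0], size=batch_size)
    return Tensor(X_train[idx]), Tensor(y_train[idx])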

In this example, we define a simple feedforward neural network and train it on MNIST. gradipy will provide building blocks like linear layers, activation functions, loss functions, and optimizers for creating and training neural networks.
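
For context, the init_kaiming_normal call in the example follows the standard Kaiming (He) scheme for ReLU networks: weights are drawn from a zero-mean normal distribution with standard deviation sqrt(2 / fan_in). A plain-NumPy sketch of that rule (illustrative only, not necessarily gradipy's exact implementation):

import numpy as np

def kaiming_normal(fan_in, fan_out):
    # He initialization for ReLU: std = sqrt(2 / fan_in) keeps the
    # variance of activations roughly constant across layers.
    std = np.sqrt(2.0 / fan_in)
    return np.random.randn(fan_in, fan_out) * std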

Feature Roadmap

Here's a list of features we plan to implement in gradipy, along with their current status:

To-Do

  • Backward passes for: mul (problem with broadcasting; see the sketch after this list), tanh
  • Add more operations and their gradients
  • Batchnorm
  • Convolutional layers for image processing
  • PyTorch's nn.Module
  • More loss functions (nn.MSELoss and nn.NLLLoss)
  • Recurrent layers for sequence data
  • GPU acceleration (no idea how to do that)
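
On the broadcasting problem noted above: when an operand is broadcast in the forward pass, its gradient must be summed back down to the operand's original shape during the backward pass. A generic NumPy sketch of that reduction (one common fix, not gradipy's code):

import numpy as np

def unbroadcast(grad, shape):
    # Sum out the extra leading axes NumPy added while broadcasting...
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # ...then sum over axes that were stretched from size 1.
    for axis, size in enumerate(shape):
        if size == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

def mul_backward(grad_out, a, b):
    # d(a*b)/da = b and d(a*b)/db = a; each gradient is reduced back
    # to its operand's original shape.
    return unbroadcast(grad_out * b, a.shape), unbroadcast(grad_out * a, b.shape)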

Done

  • Basic Tensor wrapper around NumPy ndarray
  • Forward and backward passes implemented for: add, matmul, softmax, relu, sub, log, exp, log softmax, cross entropy
  • Autograd just like PyTorch's backward method, using topological sort (a generic sketch follows this list)
  • nn.CrossEntropyLoss function
  • Train MNIST with gradipy
  • Kaiming, Xavier init (normal + uniform)
  • Implemented Adam and added momentum to SGD
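
The topological-sort autograd mentioned above works the same way as PyTorch's backward (and micrograd's): order the compute graph so every node comes after its inputs, seed the output gradient with ones, then run each node's local backward rule in reverse order. A minimal sketch, assuming each tensor records its parents in _prev and a gradient closure in _backward (hypothetical attribute names):

import numpy as np

def backward(root):
    # Build a topological ordering of the graph rooted at `root`.
    topo, visited = [], set()

    def build(t):
        if t not in visited:
            visited.add(t)
            for parent in t._prev:
                build(parent)
            topo.append(t)

    build(root)
    root.grad = np.ones_like(root.data)  # seed: d(root)/d(root) = 1
    for t in reversed(topo):
        t._backward()  # accumulate gradients into each parent's .grad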

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

gradipy-0.1.0.tar.gz (7.4 kB)

Uploaded Source

Built Distribution

gradipy-0.1.0-py3-none-any.whl (7.4 kB)

Uploaded Python 3

File details

Details for the file gradipy-0.1.0.tar.gz.

File metadata

  • Download URL: gradipy-0.1.0.tar.gz
  • Upload date:
  • Size: 7.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.0

File hashes

Hashes for gradipy-0.1.0.tar.gz:

  • SHA256: 4e763aac225a285263f6d40184c850cebd3726ad7b3bd481e132a2f7ba2d0160
  • MD5: dd57bb7ce46cd7975add4cebe0af6247
  • BLAKE2b-256: bcce1df61c37dc1129db5ef579ddcc25b543132003b6117625fde6ca15e13f66


File details

Details for the file gradipy-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: gradipy-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.0

File hashes

Hashes for gradipy-0.1.0-py3-none-any.whl:

  • SHA256: ba14a546e8fe8dc048e8bcd149c707f9f99e41a4aaaca066e6e463a8cfc645ca
  • MD5: 8993f62e4be2692fb5e11de45bd843e7
  • BLAKE2b-256: 782c26eada0cccc16e3ea73a87d90d78f0ea22c56217e2aeb55b8ce18dea175f

