
A lightweight machine learning framework




picograd



Description

A PyTorch-like lightweight deep learning framework written from scratch in Python.

The library has a built-in auto-differentiation engine that dynamically builds a computational graph. The framework provides the basic features needed to train neural nets: optimizers, a training API, data utilities, metrics, and loss functions. Additional tools visualize the forward computational graph.
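To make "dynamically builds a computational graph" concrete, here is a minimal micrograd-style sketch of a scalar `Var`. It is an illustrative assumption, not picograd's actual source: only `+`, `*`, `relu()` and `backward()` are shown, and the attribute names `grad`, `_prev` and `_backward` are hypothetical (they follow micrograd's conventions). Each operation records its inputs as it runs, and `backward()` walks the recorded graph in reverse topological order, applying the chain rule at every node.

```python
class Var:
    """Hypothetical scalar value that records the operations producing it."""

    def __init__(self, data, _children=(), label=''):
        self.data = float(data)
        self.grad = 0.0                  # d(output)/d(self), filled in by backward()
        self.label = label
        self._prev = set(_children)      # nodes this value was computed from
        self._backward = lambda: None    # pushes self.grad into the children

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        out = Var(self.data + other.data, (self, other))

        def _backward():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        out = Var(self.data * other.data, (self, other))

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    __rmul__ = __mul__

    def relu(self):
        out = Var(max(0.0, self.data), (self,))

        def _backward():
            # ReLU passes the gradient through only where its output is positive
            self.grad += (1.0 if out.data > 0 else 0.0) * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topological order guarantees a node's grad is complete before it propagates
        topo, visited = [], set()

        def build(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build(child)
                topo.append(v)

        build(self)
        self.grad = 1.0
        for v in reversed(topo):
            v._backward()


x = Var(1.0, label='x')
y = (x * 2 + 1).relu()
y.backward()
print(y.data, x.grad)  # 3.0 2.0
```

Gradients are accumulated with `+=` so that a value used more than once (for example `a * a`) receives the sum of the contributions from every path through the graph.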

Features

  • PyTorch-like auto-differentiation engine (dynamically constructed computational graph)
  • Keras-like simple training API
  • Neural networks API
  • Activations: ReLU, Sigmoid, tanh
  • Optimizers: SGD, Adam
  • Loss: Mean squared error
  • Accuracy: Binary accuracy
  • Data utilities
  • Computational graph visualizer
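The two optimizers listed above differ only in how they turn a gradient into a parameter update. The sketch below is written under stated assumptions: each parameter is a hypothetical `Param` stand-in exposing `data` and `grad`, and Adam uses the standard defaults (`beta1=0.9`, `beta2=0.999`, `eps=1e-8`). picograd's own class signatures and defaults may differ.

```python
import math


class Param:
    """Hypothetical stand-in for a trainable scalar: a value plus its gradient."""

    def __init__(self, data):
        self.data = data
        self.grad = 0.0


class SGD:
    def __init__(self, params, lr=0.01):
        self.params, self.lr = list(params), lr

    def step(self):
        # Plain gradient descent: move against the gradient
        for p in self.params:
            p.data -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0


class Adam(SGD):
    def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        super().__init__(params, lr)
        self.beta1, self.beta2, self.eps = beta1, beta2, eps
        self.m = [0.0] * len(self.params)  # running mean of gradients
        self.v = [0.0] * len(self.params)  # running mean of squared gradients
        self.t = 0

    def step(self):
        self.t += 1
        for i, p in enumerate(self.params):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * p.grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * p.grad ** 2
            # Bias correction compensates for m and v starting at zero
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            p.data -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


def minimize(opt_cls, steps, **kwargs):
    """Minimize f(w) = (w - 3)^2 starting from w = 0."""
    w = Param(0.0)
    opt = opt_cls([w], **kwargs)
    for _ in range(steps):
        opt.zero_grad()
        w.grad = 2 * (w.data - 3.0)  # derivative of (w - 3)^2
        opt.step()
    return w.data


print(minimize(SGD, 100, lr=0.1))
print(minimize(Adam, 2000, lr=0.1))
```

A useful property to notice: Adam's first step has magnitude of roughly `lr` regardless of the gradient's scale, because `m_hat / sqrt(v_hat)` equals the sign of the first gradient.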

Examples

The demo notebook showcases what picograd is all about.

Example Usage

from picograd.engine import Var
from picograd.graph_viz import ForwardGraphViz

graph_builder = ForwardGraphViz()

x = Var(1.0, label='x')
y = (x * 2 + 1).relu()
y.label = 'y'
y.backward()

graph_builder.create_graph(y)
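For this graph, y = relu(2x + 1), so y = 3 at x = 1, and because 2x + 1 > 0 there, dy/dx = 2. After `y.backward()` the gradient stored on `x` should therefore be 2 (micrograd exposes it as `x.grad`; we assume picograd does the same). A central finite difference, sketched below in plain Python and independent of picograd, is a quick way to sanity-check such values:

```python
def f(x):
    # Same expression as the example above: y = relu(x * 2 + 1)
    return max(0.0, x * 2 + 1)


def numerical_grad(fn, x, h=1e-5):
    # Central difference: (f(x + h) - f(x - h)) / (2h)
    return (fn(x + h) - fn(x - h)) / (2 * h)


print(f(1.0))                   # 3.0
print(numerical_grad(f, 1.0))   # approximately 2.0
print(numerical_grad(f, -1.0))  # 0.0, ReLU is inactive here
```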

Output: [figure: forward computational graph of y]

Training MLP

from sklearn.datasets import make_moons

from picograd.nn import MLP
from picograd.engine import Var
from picograd.data import BatchIterator
from picograd.trainer import Trainer
from picograd.optim import SGD  # Adam is also available in picograd.optim
from picograd.metrics import binary_accuracy, mean_squared_error

# Generate moon-shaped, non-linearly separable data
x_train, y_train = make_moons(n_samples=200, noise=0.10, random_state=0)

model = MLP(in_features=2, layers=[16, 16, 1], activations=['relu', 'relu', 'linear'])  # 2 hidden layers
print(model)
print(f"Number of parameters: {len(model.parameters())}")

optimizer = SGD(model.parameters(), lr=0.05)
data_iterator = BatchIterator(x_train, list(map(Var, y_train)))
trainer = Trainer(model, optimizer, loss=mean_squared_error, acc_metric=binary_accuracy)

history = trainer.fit(data_iterator, num_epochs=70, verbose=True)
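Conceptually, a Keras-like `fit` repeats the same loop every epoch: draw mini-batches, run the forward pass, compute the loss, backpropagate, and let the optimizer update the parameters. The sketch below shows that loop in plain Python under loudly stated assumptions: the helper names (`batch_iterator`, `fit`), the batch size of 16, and the 0.5 decision threshold in `binary_accuracy` are hypothetical, not picograd's API. A two-weight linear model with hand-derived MSE gradients stands in for the MLP and autograd, so it is trained on linearly separable toy data rather than on the moons.

```python
import random


def batch_iterator(xs, ys, batch_size=32, shuffle=True, seed=0):
    """Hypothetical BatchIterator sketch: yields (x_batch, y_batch) lists."""
    idx = list(range(len(xs)))
    if shuffle:
        random.Random(seed).shuffle(idx)
    for start in range(0, len(idx), batch_size):
        chunk = idx[start:start + batch_size]
        yield [xs[i] for i in chunk], [ys[i] for i in chunk]


def mean_squared_error(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def binary_accuracy(preds, targets, threshold=0.5):
    # Assumed convention: predict class 1 when the output exceeds the threshold
    hits = sum((p > threshold) == (t == 1) for p, t in zip(preds, targets))
    return hits / len(targets)


def fit(xs, ys, num_epochs=50, lr=0.1, batch_size=16):
    w, b = [0.0, 0.0], 0.0  # a linear model stands in for the MLP
    history = []
    for epoch in range(num_epochs):
        for xb, yb in batch_iterator(xs, ys, batch_size, seed=epoch):
            preds = [w[0] * x[0] + w[1] * x[1] + b for x in xb]
            # dMSE/dpred for each sample in the batch
            errs = [2 * (p - t) / len(yb) for p, t in zip(preds, yb)]
            # Chain rule by hand, followed by a plain SGD update
            w[0] -= lr * sum(e * x[0] for e, x in zip(errs, xb))
            w[1] -= lr * sum(e * x[1] for e, x in zip(errs, xb))
            b -= lr * sum(errs)
        preds = [w[0] * x[0] + w[1] * x[1] + b for x in xs]
        history.append((mean_squared_error(preds, ys), binary_accuracy(preds, ys)))
    return (w, b), history


rng = random.Random(42)
xs = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(200)]
ys = [1 if x0 + x1 > 0 else 0 for x0, x1 in xs]
params, history = fit(xs, ys)
print(f"final loss={history[-1][0]:.3f}, accuracy={history[-1][1]:.2f}")
```

As a side note on the `print(model)` output above: assuming each neuron has one weight per input plus a bias (as in micrograd), `MLP(in_features=2, layers=[16, 16, 1])` has (2·16 + 16) + (16·16 + 16) + (16·1 + 1) = 48 + 272 + 17 = 337 scalar parameters.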

Decision boundary: [figure: decision boundary learned by the trained MLP on the moons data]

References

  • Andrej Karpathy's micrograd library and his introductory explanation of training neural nets, which form the foundation of picograd's autograd engine.
  • Baptiste Pesquet's pyfit library, from which the training API was borrowed.

License

MIT



Download files


Source Distribution

picograd-1.0.4.tar.gz (116.3 kB)

Uploaded Source

Built Distribution

picograd-1.0.4-py3-none-any.whl (116.9 kB)

Uploaded Python 3

File details

Details for the file picograd-1.0.4.tar.gz.

File metadata

  • Download URL: picograd-1.0.4.tar.gz
  • Upload date:
  • Size: 116.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.4.tar.gz
Algorithm Hash digest
SHA256 e7e4dab24852b35676b9b9a3c71002ca2883e45011030dd03333c71f4b83430b
MD5 6fbc418dd1c54446427cacf1827cd9a6
BLAKE2b-256 6cfc2f4cbf28e725cd593d8e1e08848ca103302b71dd50a732db01f44c9eeaac


File details

Details for the file picograd-1.0.4-py3-none-any.whl.

File metadata

  • Download URL: picograd-1.0.4-py3-none-any.whl
  • Upload date:
  • Size: 116.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 122c887edbeb1a83b4cc2a8734ec73c920415a4be9557159ac033fa96221d5b5
MD5 26b621ae834696847902a408191ccd49
BLAKE2b-256 b0a36bc5d9da5d45666b8f20524931f5f53c996bec33b8308c04bc8964be00e0

