
A lightweight machine learning framework

Project description



picograd



Description

A PyTorch-like lightweight deep learning framework written from scratch in Python.

The library has a built-in auto-differentiation engine that dynamically builds a computational graph. The framework provides the basic features needed to train neural networks: optimizers, a training API, data utilities, metrics, and loss functions. It also includes a tool for visualizing the forward computational graph.
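The core idea behind such an engine fits in a few dozen lines. The following is a minimal scalar reverse-mode autodiff sketch in the spirit of micrograd. It is not picograd's actual `Var` implementation, and the class name `Value` and its methods are illustrative assumptions. Each operation records its inputs and a closure that pushes the output's gradient back to them, and `backward()` runs those closures in reverse topological order.

```python
import math


class Value:
    """A scalar node in a dynamically built computational graph (illustrative sketch)."""

    def __init__(self, data, parents=(), label=""):
        self.data = float(data)
        self.grad = 0.0
        self.label = label
        self._parents = parents
        self._backward = lambda: None  # leaf nodes have nothing to propagate

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def _backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    __radd__ = __add__
    __rmul__ = __mul__

    def relu(self):
        out = Value(max(0.0, self.data), (self,))

        def _backward():
            # ReLU passes the gradient through only where its input was positive
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad

        out._backward = _backward
        return out

    def tanh(self):
        t = math.tanh(self.data)
        out = Value(t, (self,))

        def _backward():
            # d tanh(x)/dx = 1 - tanh(x)^2
            self.grad += (1.0 - t * t) * out.grad

        out._backward = _backward
        return out

    def backward(self):
        # Topologically order the graph so every node follows its inputs
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # dy/dy = 1
        for node in reversed(order):
            node._backward()


x = Value(1.0, label="x")
y = (x * 2 + 1).relu()
y.backward()
print(y.data, x.grad)  # 3.0 2.0
```

Gradients are accumulated with `+=` rather than assigned, so a node that feeds several operations (for example `x * x`) receives the sum of all its contributions.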

Features

  • PyTorch-like auto-differentiation engine (dynamically constructed computational graph)
  • Keras-like simple training API
  • Neural networks API
  • Activations: ReLU, Sigmoid, tanh
  • Optimizers: SGD, Adam
  • Loss: Mean squared error
  • Accuracy: Binary accuracy
  • Data utilities
  • Computational graph visualizer
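For reference, the textbook update rules behind the two listed optimizers can be written for a flat list of scalar parameters as below. This is a sketch of standard SGD and Adam, not picograd's `SGD`/`Adam` classes; the names `sgd_step` and `AdamState` and the default hyperparameters are assumptions chosen for illustration.

```python
import math


def sgd_step(params, grads, lr=0.01):
    """Vanilla SGD: move each parameter against its gradient."""
    return [p - lr * g for p, g in zip(params, grads)]


class AdamState:
    """Per-parameter running moments for Adam (hypothetical helper)."""

    def __init__(self, n, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n  # first moment: running mean of gradients
        self.v = [0.0] * n  # second moment: running mean of squared gradients
        self.t = 0          # number of steps taken so far

    def step(self, params, grads):
        self.t += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            # Bias correction compensates for m and v starting at zero
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params


# Usage: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w_sgd = [0.0]
for _ in range(100):
    w_sgd = sgd_step(w_sgd, [2 * (w_sgd[0] - 3)], lr=0.1)

adam = AdamState(n=1, lr=0.1)
w_adam = [0.0]
for _ in range(200):
    w_adam = adam.step(w_adam, [2 * (w_adam[0] - 3)])

print(w_sgd[0], w_adam[0])  # both end up close to 3.0
```

SGD scales every step by the raw gradient, while Adam divides by a running estimate of the gradient's magnitude, so its early steps have a size of roughly `lr` regardless of how large the gradient is.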

Examples

The demo notebook showcases picograd's main features.

Example Usage

from picograd.engine import Var
from picograd.graph_viz import ForwardGraphViz

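# Builds a visual graph of the forward pass
graph_builder = ForwardGraphViz()

x = Var(1.0, label='x')
y = (x * 2 + 1).relu()  # y = relu(2x + 1)
y.label = 'y'
y.backward()  # backpropagate gradients from y through the graph

graph_builder.create_graph(y)  # build the forward graph that produces y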

Output: (image of the forward computational graph for y, not reproduced here)
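As a sanity check on the example above: for x = 1.0 the inner expression is 2 * 1 + 1 = 3 > 0, so ReLU passes it through unchanged, giving y = 3 and dy/dx = 2. The gradient that `y.backward()` accumulates for `x` should therefore be 2.0. The plain-Python central-difference check below reproduces that value without picograd; the helper names `f` and `central_difference` are illustrative only.

```python
def f(x):
    """y = relu(2x + 1), the same function as in the example above."""
    return max(0.0, 2.0 * x + 1.0)


def central_difference(func, x, h=1e-6):
    """Numerically approximate df/dx at x."""
    return (func(x + h) - func(x - h)) / (2.0 * h)


y = f(1.0)
dy_dx = central_difference(f, 1.0)
print(y, round(dy_dx, 6))  # 3.0 2.0
```

For inputs where 2x + 1 < 0 the ReLU is flat, so the same check returns a derivative of 0.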

Training an MLP

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_moons

from picograd.nn import MLP
from picograd.engine import Var
from picograd.data import BatchIterator
from picograd.trainer import Trainer
from picograd.optim import SGD, Adam
from picograd.metrics import binary_accuracy, mean_squared_error

# Generate moon-shaped, non-linearly separable data
x_train, y_train = make_moons(n_samples=200, noise=0.10, random_state=0)

model = MLP(in_features=2, layers=[16, 16, 1], activations=['relu', 'relu', 'linear'])  # 2 hidden layers
print(model)
print(f"Number of parameters: {len(model.parameters())}")

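# Plain SGD with a fixed learning rate (Adam is imported as an alternative)
optimizer = SGD(model.parameters(), lr=0.05)
# Wrap each target label in a Var so it can be combined with model outputs in the loss
data_iterator = BatchIterator(x_train, list(map(Var, y_train)))
trainer = Trainer(model, optimizer, loss=mean_squared_error, acc_metric=binary_accuracy)

# Train for 70 epochs; verbose=True reports progress during training
history = trainer.fit(data_iterator, num_epochs=70, verbose=True)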
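`BatchIterator` feeds the trainer mini-batches of inputs and targets. A minimal sketch of what such an iterator typically does (shuffle the sample indices, then yield consecutive slices) is shown below. The function name `iterate_minibatches` and its `batch_size`, `shuffle`, and `seed` parameters are hypothetical and may not match picograd's `BatchIterator`.

```python
import random


def iterate_minibatches(inputs, targets, batch_size=32, shuffle=True, seed=None):
    """Yield (inputs, targets) mini-batches covering every sample exactly once.

    Hypothetical helper for illustration; not picograd's BatchIterator.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    indices = list(range(len(inputs)))
    if shuffle:
        # A fresh random order on every call unless a seed is given
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # Index inputs and targets with the same positions so pairs stay aligned
        yield [inputs[i] for i in batch_idx], [targets[i] for i in batch_idx]


xs = [[float(i), float(i + 1)] for i in range(10)]
ys = [i % 2 for i in range(10)]
sizes = [len(bx) for bx, _ in iterate_minibatches(xs, ys, batch_size=4, seed=0)]
print(sizes)  # [4, 4, 2]
```

Shuffling once per pass keeps consecutive batches from always containing the same samples, and the final batch is simply smaller when the dataset size is not a multiple of `batch_size`.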

Decision boundary: (plot of the trained MLP's decision boundary on the moons data, not reproduced here)
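The loss and metric passed to `Trainer` in the example above have simple textbook definitions. The sketch below computes mean squared error and binary accuracy on plain Python lists. Thresholding predictions at 0.5 is an assumption (a natural choice for 0/1 targets such as the `make_moons` labels) and may differ from how picograd's `binary_accuracy` is implemented.

```python
def mean_squared_error(y_true, y_pred):
    """Average of squared differences between targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of predictions on the correct side of the threshold (assumed 0.5)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(int(p >= threshold) == int(t) for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


y_true = [0, 1, 1, 0]
y_pred = [0.25, 0.75, 0.25, 0.0]
print(mean_squared_error(y_true, y_pred))  # 0.171875
print(binary_accuracy(y_true, y_pred))     # 0.75
```

Only the third prediction falls on the wrong side of 0.5, so accuracy is 3/4, while the squared-error loss still penalizes the correct but imperfect predictions 0.25 and 0.75.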

References

  • Andrej Karpathy's micrograd library and his introductory explanation of training neural nets, which form the foundation of picograd's autograd engine.
  • Baptiste Pesquet's pyfit library, from which the training API was borrowed.

License

MIT



Download files


Source Distribution

picograd-1.0.5.tar.gz (111.3 kB)

Built Distribution

picograd-1.0.5-py3-none-any.whl (105.8 kB)

File details

Details for the file picograd-1.0.5.tar.gz.

File metadata

  • Download URL: picograd-1.0.5.tar.gz
  • Size: 111.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.5.tar.gz:
  • SHA256: 64a154c48bb7f00818626f41a4f2a6643b03b1c445117145112408c6af8d952c
  • MD5: e700aa31ab6fa0b65d42a8309a8a1950
  • BLAKE2b-256: 73894f3390c0877841c63a45afad34c90dc5f4d13d9f47551a8b9b7cc4072fe2


File details

Details for the file picograd-1.0.5-py3-none-any.whl.

File metadata

  • Download URL: picograd-1.0.5-py3-none-any.whl
  • Size: 105.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.5-py3-none-any.whl:
  • SHA256: 92258a99f474ee543515f685dff7efc159d33bbaa0d1aa06717cee5daa26c7ab
  • MD5: 5edd74dbf6760c53737414063f832b2a
  • BLAKE2b-256: 96e62890d2d5340a3ba4ee2fdfd58c5efcbf8ccd5f86abdbdf86cc2cd5fad4db

