
A lightweight machine learning framework



picograd



Description

A PyTorch-like lightweight deep learning framework written from scratch in Python.

The library has a built-in auto-differentiation engine that dynamically builds a computational graph as operations are executed. On top of it, the framework provides the basic components needed to train neural nets: optimizers, a training API, data utilities, metrics, and loss functions. Additional tools visualize the forward computational graph.
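To make the idea of a dynamically built graph concrete, here is a minimal sketch of a scalar reverse-mode autodiff engine in the style of Karpathy's micrograd. The `Scalar` class and its methods are hypothetical illustrations, not picograd's `Var` API: each operation records its inputs and a local backward rule, and `backward()` visits the graph in reverse topological order, applying the chain rule and accumulating gradients.

```python
class Scalar:
    """Hypothetical scalar node; illustrates the idea, not picograd's Var."""

    def __init__(self, data, parents=(), op=""):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents        # nodes this value was computed from
        self._op = op                  # operation that produced this value
        self._backward = lambda: None  # pushes self.grad into the parents

    def __add__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data + other.data, (self, other), "+")

        def _backward():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Scalar) else Scalar(other)
        out = Scalar(self.data * other.data, (self, other), "*")

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def relu(self):
        out = Scalar(max(0.0, self.data), (self,), "relu")

        def _backward():
            # Gradient passes through only where the output is positive
            self.grad += (out.data > 0) * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topological order: every node appears after all of its inputs
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(self)/d(self)
        for node in reversed(order):
            node._backward()


# Same expression as in the Example Usage section
x = Scalar(1.0)
y = (x * 2 + 1).relu()
y.backward()
print(y.data, x.grad)  # 3.0 2.0
```

Gradients are accumulated with `+=` rather than assigned, so a value that feeds several operations receives the sum of all contributions.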

Features

  • PyTorch-like auto-differentiation engine (dynamically constructed computational graph)
  • Keras-like simple training API
  • Neural networks API
  • Activations: ReLU, Sigmoid, tanh
  • Optimizers: SGD, Adam
  • Loss: Mean squared error
  • Accuracy: Binary accuracy
  • Data utilities
  • Computational graph visualizer
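For reference, the two listed optimizers correspond to the standard update rules below. This is a minimal sketch on plain lists of floats; `sgd_step`, `AdamState`, and `adam_step` are hypothetical names and do not describe the internals of picograd's `SGD` and `Adam` classes. The Adam defaults (`beta1=0.9`, `beta2=0.999`, `eps=1e-8`) are the commonly used values from the original Adam paper.

```python
import math


def sgd_step(params, grads, lr=0.01):
    """Plain gradient descent: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


class AdamState:
    """Hypothetical container for Adam's running moment estimates."""

    def __init__(self, n, beta1=0.9, beta2=0.999, eps=1e-8):
        self.m = [0.0] * n  # first moment: running mean of gradients
        self.v = [0.0] * n  # second moment: running mean of squared gradients
        self.t = 0          # step counter used for bias correction
        self.beta1, self.beta2, self.eps = beta1, beta2, eps


def adam_step(params, grads, state, lr=0.001):
    """One Adam update with bias-corrected moment estimates."""
    state.t += 1
    b1, b2 = state.beta1, state.beta2
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        state.m[i] = b1 * state.m[i] + (1 - b1) * g
        state.v[i] = b2 * state.v[i] + (1 - b2) * g * g
        m_hat = state.m[i] / (1 - b1 ** state.t)
        v_hat = state.v[i] / (1 - b2 ** state.t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + state.eps))
    return new_params


# Minimise f(p) = (p - 3)^2, whose gradient is 2 * (p - 3)
p_sgd = [0.0]
p_adam = [0.0]
state = AdamState(1)
for _ in range(1000):
    p_sgd = sgd_step(p_sgd, [2 * (p_sgd[0] - 3)], lr=0.1)
    p_adam = adam_step(p_adam, [2 * (p_adam[0] - 3)], state, lr=0.1)
print(p_sgd[0], p_adam[0])
```

SGD scales every step directly by the gradient, while Adam divides by a running estimate of the gradient's magnitude, so its step size stays roughly on the order of `lr` regardless of how steep the loss is.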

Examples

The demo notebook walks through picograd's main features.

Example Usage

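from picograd.engine import Var
from picograd.graph_viz import ForwardGraphViz

graph_builder = ForwardGraphViz()

# Build a small expression; each operation adds a node to the graph
x = Var(1.0, label='x')
y = (x * 2 + 1).relu()
y.label = 'y'

# Backpropagate from y to compute gradients
y.backward()

# Draw the forward computational graph that produced y
graph_builder.create_graph(y)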

Output:

[Figure: simple graph of the forward computation of y]
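Conceptually, a forward-graph visualizer only needs to walk from the output back to the leaves, emitting one node per value and one edge per dependency. The sketch below shows that traversal with a hypothetical `Node` dataclass and a hypothetical `to_dot` helper that produces Graphviz DOT text; it illustrates the idea under those assumptions and is not how `ForwardGraphViz` is implemented.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity semantics for graph nodes
class Node:
    """Hypothetical graph node: a value, the op that produced it, its inputs."""
    data: float
    op: str = ""         # empty for leaf values
    parents: tuple = ()  # nodes this value was computed from
    label: str = ""


def to_dot(root):
    """Walk back from root and emit Graphviz DOT text for the forward graph."""
    nodes, edges, seen = [], [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        nodes.append(node)
        for parent in node.parents:
            edges.append((parent, node))  # value flows parent -> node
            visit(parent)

    visit(root)
    lines = ["digraph forward {", "  rankdir=LR;"]
    for n in nodes:
        text = f"{n.label} = {n.data:.4f}" if n.label else f"{n.data:.4f}"
        lines.append(f'  n{id(n)} [shape=box, label="{text}"];')
        if n.op:
            # Each operation gets its own small node feeding its result
            lines.append(f'  op{id(n)} [shape=ellipse, label="{n.op}"];')
            lines.append(f"  op{id(n)} -> n{id(n)};")
    for parent, child in edges:
        lines.append(f"  n{id(parent)} -> op{id(child)};")
    lines.append("}")
    return "\n".join(lines)


# Hand-built graph for y = relu(x * 2 + 1), mirroring the example above
x = Node(1.0, label="x")
two, one = Node(2.0), Node(1.0)
prod = Node(x.data * two.data, "*", (x, two))
total = Node(prod.data + one.data, "+", (prod, one))
y = Node(max(0.0, total.data), "relu", (total,), label="y")
dot = to_dot(y)
print(dot)
```

Saved to a file, the DOT text can be rendered with the Graphviz command-line tool, for example `dot -Tsvg graph.dot -o graph.svg`.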

Training an MLP

from sklearn.datasets import make_moons

from picograd.nn import MLP
from picograd.engine import Var
from picograd.data import BatchIterator
from picograd.trainer import Trainer
from picograd.optim import SGD, Adam
from picograd.metrics import binary_accuracy, mean_squared_error

# Generate moon-shaped, non-linearly separable data
x_train, y_train = make_moons(n_samples=200, noise=0.10, random_state=0)

# 2 inputs -> two hidden layers of 16 ReLU units -> 1 linear output
model = MLP(in_features=2, layers=[16, 16, 1], activations=['relu', 'relu', 'linear'])
print(model)
print(f"Number of parameters: {len(model.parameters())}")

# Plain SGD; Adam (imported above) is the other available optimizer
optimizer = SGD(model.parameters(), lr=0.05)
# Targets are wrapped in Var objects
data_iterator = BatchIterator(x_train, list(map(Var, y_train)))
trainer = Trainer(model, optimizer, loss=mean_squared_error, acc_metric=binary_accuracy)

# Train for 70 epochs with progress output and keep the returned history
history = trainer.fit(data_iterator, num_epochs=70, verbose=True)

Decision boundary:

[Figure: decision boundary learned by the MLP]
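Roughly, a Keras-style `fit` loop repeats the same steps for every mini-batch: forward pass, loss, gradients, optimizer step, and a metric at the end of each epoch. The sketch below shows that loop for a plain linear model in NumPy with hand-derived MSE gradients. The helpers `fit_linear`, `mse_loss`, and `binary_acc`, the 0.5 accuracy threshold, and the synthetic two-blob data are all assumptions for illustration; picograd's `Trainer` works with an MLP and obtains gradients from its autograd engine instead.

```python
import numpy as np


def mse_loss(pred, target):
    return float(np.mean((pred - target) ** 2))


def binary_acc(pred, target, threshold=0.5):
    # Assumed convention: predictions above `threshold` count as class 1
    return float(np.mean((pred > threshold) == (target == 1)))


def fit_linear(x, y, num_epochs=50, batch_size=32, lr=0.05, seed=0):
    """Mini-batch SGD on MSE for pred = x @ w + b (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    w = np.zeros(x.shape[1])
    b = 0.0
    history = []
    for _ in range(num_epochs):
        order = rng.permutation(len(x))  # reshuffle every epoch
        for start in range(0, len(x), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = x[idx], y[idx]
            pred = xb @ w + b                   # forward pass
            err = pred - yb
            grad_w = 2 * xb.T @ err / len(idx)  # d(MSE)/dw
            grad_b = 2 * err.mean()             # d(MSE)/db
            w -= lr * grad_w                    # SGD step
            b -= lr * grad_b
        pred = x @ w + b
        history.append((mse_loss(pred, y), binary_acc(pred, y)))
    return w, b, history


# Two well-separated Gaussian blobs labelled 0 and 1
rng = np.random.default_rng(0)
x = np.vstack([rng.normal(-2, 0.5, (100, 2)), rng.normal(2, 0.5, (100, 2))])
y = np.array([0] * 100 + [1] * 100, dtype=float)
w, b, history = fit_linear(x, y)
print(history[-1])
```

A linear model is enough here only because the blobs are linearly separable; the moon-shaped data in the example above needs the hidden layers of the MLP.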

References

  • Andrej Karpathy's micrograd library and his introductory explanation of training neural nets, which form the foundation of picograd's autograd engine.
  • Baptiste Pesquet's pyfit library, from which the training API was borrowed.

License

MIT



Download files


Source Distribution

picograd-1.0.1.tar.gz (115.9 kB)

Uploaded Source

Built Distribution

picograd-1.0.1-py3-none-any.whl (116.8 kB)

Uploaded Python 3

File details

Details for the file picograd-1.0.1.tar.gz.

File metadata

  • Download URL: picograd-1.0.1.tar.gz
  • Upload date:
  • Size: 115.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.1.tar.gz
Algorithm    Hash digest
SHA256       c93b5563e3ea8f61f6996ce58c1289cb7db35b25cd505b191bdfeed3d77d4d6e
MD5          f70f7e7761aa9d58941a317c7ab5115d
BLAKE2b-256  2dd454c43921c3e08539ce0ee7ab0f557d77bca455eb3608e72979291b0e54b7


File details

Details for the file picograd-1.0.1-py3-none-any.whl.

File metadata

  • Download URL: picograd-1.0.1-py3-none-any.whl
  • Upload date:
  • Size: 116.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.9.14

File hashes

Hashes for picograd-1.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       72152cdc6137619e30288454f5d85596f680c410c22cb6693fe66eaf4aa2a47e
MD5          79b4db5e29dd89db688dda37b99af548
BLAKE2b-256  3599714d08bee2f79f03f3fb0c88ed5f9102328073017cde5660715e155eb4ae

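The published digests can be checked locally after downloading a file. The following is a minimal sketch using Python's standard `hashlib` module; the helper name `sha256_of` and the assumption that the archive sits in the current directory are illustrative only.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Stream a file through SHA-256 so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digest listed above for the source distribution
EXPECTED_SHA256 = "c93b5563e3ea8f61f6996ce58c1289cb7db35b25cd505b191bdfeed3d77d4d6e"
```

Usage: `sha256_of("picograd-1.0.1.tar.gz") == EXPECTED_SHA256` should be `True` for an intact download. pip can also enforce digests itself in hash-checking mode, by adding `--hash=sha256:...` to a pinned requirement in a requirements file.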
