A minimal neural network framework with autodiff and NumPy

Project description

nnetflow

A minimal neural network framework with automatic differentiation (autodiff), inspired by micrograd and PyTorch.
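To make the micrograd-style reverse-mode autodiff idea concrete, here is a minimal scalar sketch. It is written from scratch for illustration and is not nnetflow's own code. `Value`, `relu` and `backward` here are hypothetical names, and nnetflow's `Tensor` appears to wrap NumPy arrays, which this sketch does not attempt.

```python
class Value:
    """A scalar that remembers how it was computed (hypothetical, for illustration)."""

    def __init__(self, data, parents=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents     # the Values this one was computed from
        self._backward_fn = None    # pushes self.grad into the parents' .grad

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    __radd__ = __add__  # lets `2 + v` work as well as `v + 2`

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    __rmul__ = __mul__  # lets `2 * v` work as well as `v * 2`

    def relu(self):
        out = Value(max(self.data, 0.0), (self,))

        def backward_fn():
            # ReLU passes the gradient through only where the input was positive
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically order the graph so every node's gradient is fully
        # accumulated before it is propagated to that node's parents.
        order, visited = [], set()

        def visit(node):
            if node not in visited:
                visited.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(output)/d(output)
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


# y = relu(w * x + b) with w = 3, x = 2, b = 1
w, x, b = Value(3.0), Value(2.0), Value(1.0)
y = (w * x + b).relu()
y.backward()
print(y.data, w.grad, x.grad, b.grad)  # 7.0 2.0 3.0 1.0
```

Gradients are accumulated with `+=`, so a value used more than once (for example in `x * x + x`) collects a contribution from every use. This is presumably also why the training loop in this README calls `optimizer.zero_grad()` before each backward pass.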

Installation

pip install nnetflow

From source

git clone https://github.com/lewisnjue/nnetflow.git
cd nnetflow
pip install -e .

Usage

from nnetflow.engine import Tensor
from nnetflow.layers import Linear
from nnetflow.module import Module
from nnetflow.optim import SGD
import numpy as np

# Define a simple MLP
class MLP(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

# Generate dummy data
np.random.seed(0)
X = np.random.randn(100, 3).astype(np.float32)
y = (np.random.randn(100, 1) > 0).astype(np.float32)

# Convert to Tensor
X_tensor = Tensor(X, require_grad=False)
y_tensor = Tensor(y, require_grad=False)

# Instantiate model and optimizer (the loss is computed inline in the loop)
model = MLP(3, 8, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    out = model(X_tensor)
    # Simple MSE loss
    loss = ((out - y_tensor) ** 2).mean()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

model.save("mlp_model.pkl")

# Load the model
loaded_model = Module.load("mlp_model.pkl")
# Verify loaded model
print(f"Loaded model: {loaded_model}")
# Check if the loaded model can still perform inference
test_out = loaded_model(X_tensor)
print(f"Test output from loaded model: {test_out.data[:5]}")  # Print first 5 outputs
# Check if the loaded model's parameters match the original model's parameters
for original_param, loaded_param in zip(model.parameters(), loaded_model.parameters()):
    assert np.array_equal(original_param.data, loaded_param.data), "Loaded parameters do not match original parameters"
print("All parameters match successfully after loading the model.")
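The README does not say how `save` and `Module.load` serialize a model. The `.pkl` extension suggests Python's `pickle`, and since `Module.load` returns a ready-to-use model, the file presumably stores the whole object. The following is a minimal sketch under the pickle assumption. It uses a variant that saves only the parameter arrays (a "state dict"), and `TinyLinear`, `save_params` and `load_params` are hypothetical names, not nnetflow's API.

```python
import os
import pickle
import tempfile

import numpy as np


class TinyLinear:
    """Hypothetical stand-in for a module holding NumPy parameters."""

    def __init__(self, in_dim, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((in_dim, out_dim)).astype(np.float32)
        self.bias = np.zeros(out_dim, dtype=np.float32)

    def __call__(self, x):
        return x @ self.weight + self.bias

    def state_dict(self):
        # Copy so later in-place updates do not alter a saved snapshot.
        return {"weight": self.weight.copy(), "bias": self.bias.copy()}

    def load_state_dict(self, state):
        self.weight = state["weight"].copy()
        self.bias = state["bias"].copy()


def save_params(module, path):
    # Pickle only the parameter arrays, not the module object itself.
    with open(path, "wb") as f:
        pickle.dump(module.state_dict(), f)


def load_params(module, path):
    # Only unpickle files from trusted sources: pickle.load can run arbitrary code.
    with open(path, "rb") as f:
        module.load_state_dict(pickle.load(f))
    return module


original = TinyLinear(3, 1, seed=0)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_linear.pkl")
    save_params(original, path)
    # Different initial weights, then overwritten by the saved parameters.
    restored = load_params(TinyLinear(3, 1, seed=1), path)

x = np.ones((2, 3), dtype=np.float32)
print(np.array_equal(original(x), restored(x)))  # True
```

Saving parameters rather than the whole object keeps the file loadable even if the class is later renamed or moved. The trade-off is that the caller must build a matching model before loading. In either form, only load `.pkl` files from sources you trust.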

Documentation

  • See docs/index.md for a full guide and API overview.
  • See CONTRIBUTING.md for contribution guidelines.
  • See CHANGELOG.md for release notes.

Examples

  • PyTorch vs nnetflow simple regression: examples/pytorch_vs_nnetflow.py
  • Classification comparison with decision boundaries: examples/classification_torch_vs_nnetflow.py
    • Outputs: examples/outputs/classification_boundaries.png, examples/outputs/classification_losses.png
  • Regression comparison with fit and loss curves: examples/regression_torch_vs_nnetflow.py
    • Outputs: examples/outputs/regression_fit_and_loss.png

Download files


Source Distribution

nnetflow-0.1.2.tar.gz (14.3 kB)

Built Distribution

nnetflow-0.1.2-py3-none-any.whl (12.9 kB)

File details

Details for the file nnetflow-0.1.2.tar.gz.

File metadata

  • Download URL: nnetflow-0.1.2.tar.gz
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-0.1.2.tar.gz
  • SHA256: c5339737bc788d82de428b003cf2cda7b717c2cfe85258af6052c4f0655ca644
  • MD5: 17d75cd159d1c723290ca1bbf3a8285c
  • BLAKE2b-256: f3ec2735889d369fdd67541fccddd01f2746dd9997bdbfa2073418bdfd4bcf92
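These digests let you check that a downloaded archive is intact. A minimal sketch using Python's standard `hashlib` follows. `sha256_of`, `matches_published_sdist` and `EXPECTED_SDIST_SHA256` are hypothetical names, and the path is wherever you saved the archive.

```python
import hashlib

# SHA256 digest published above for nnetflow-0.1.2.tar.gz
EXPECTED_SDIST_SHA256 = "c5339737bc788d82de428b003cf2cda7b717c2cfe85258af6052c4f0655ca644"


def sha256_of(path, chunk_size=1 << 16):
    """Return the lowercase hex SHA-256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_sdist(path):
    # hexdigest() is lowercase, matching the published digest.
    return sha256_of(path) == EXPECTED_SDIST_SHA256
```

For example, `matches_published_sdist("nnetflow-0.1.2.tar.gz")` should return `True` for an intact download of the source distribution. pip can also perform this check during installation through its hash-checking mode (`--require-hashes`).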


File details

Details for the file nnetflow-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: nnetflow-0.1.2-py3-none-any.whl
  • Size: 12.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-0.1.2-py3-none-any.whl
  • SHA256: 6fb67a6b487cd5effe45c2ae2efbd750e6d3d1fcf120b0c749db8c2533ea18fc
  • MD5: c5b1fb3d6ea0c81c61e29b9492d408ff
  • BLAKE2b-256: bb162fa35d964c311e8917f7aaef0212bf1161db22dd00719d618122fe7c2a2e

