
A minimal neural network framework with autodiff and NumPy

Project description

nnetflow

A minimal neural network framework with autodiff, inspired by micrograd and PyTorch.
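As a rough illustration of the reverse-mode autodiff idea that micrograd popularized, here is a scalar sketch. The `Value` class and its methods are hypothetical names made up for this example and are not part of nnetflow, whose `Tensor` works on NumPy arrays. Each operation records its inputs and a small function that pushes the output's gradient back to them; `backward()` then applies the chain rule over the graph in reverse topological order.

```python
class Value:
    """Hypothetical scalar node: stores a value, its gradient, and how it was computed."""

    def __init__(self, data, parents=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None  # pushes self.grad into the parents' .grad

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            # d(a * b)/da = b, d(a * b)/db = a; += accumulates when a node is reused
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))

        def backward_fn():
            # ReLU passes the gradient through only where its input was positive
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Topologically sort the graph so every node runs after all nodes that use it.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(output)/d(output)
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


w, x, b = Value(2.0), Value(3.0), Value(-1.0)
y = (w * x + b).relu()  # relu(2 * 3 - 1) = 5
y.backward()
print(y.data, w.grad, x.grad, b.grad)  # 5.0 3.0 2.0 1.0
```

A tensor framework applies the same bookkeeping to whole arrays, so each recorded gradient function works on NumPy arrays instead of single floats.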

Installation

pip install nnetflow

From source

git clone https://github.com/lewisnjue/nnetflow.git
cd nnetflow
pip install -e .

Usage

from nnetflow.engine import Tensor
from nnetflow.layers import Linear
from nnetflow.module import Module
from nnetflow.optim import SGD
import numpy as np

# Define a simple MLP
class MLP(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

# Generate dummy data
np.random.seed(0)
X = np.random.randn(100, 3).astype(np.float32)
y = (np.random.randn(100, 1) > 0).astype(np.float32)

# Convert to Tensor
X_tensor = Tensor(X, require_grad=False)
y_tensor = Tensor(y, require_grad=False)

# Instantiate model and optimizer (the loss is computed inline in the loop)
model = MLP(3, 8, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    out = model(X_tensor)
    # Simple MSE loss
    loss = ((out - y_tensor) ** 2).mean()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

model.save("mlp_model.pkl")

# Load the model
loaded_model = Module.load("mlp_model.pkl")
# Verify loaded model
print(f"Loaded model: {loaded_model}")
# Check if the loaded model can still perform inference
test_out = loaded_model(X_tensor)
print(f"Test output from loaded model: {test_out.data[:5]}")  # Print first 5 outputs
# Check if the loaded model's parameters match the original model's parameters
for original_param, loaded_param in zip(model.parameters(), loaded_model.parameters()):
    assert np.array_equal(original_param.data, loaded_param.data), "Loaded parameters do not match original parameters"
print("All parameters match successfully after loading the model.")
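To show what `loss.backward()` and `optimizer.step()` have to compute for this particular model, here is a plain-NumPy sketch of the same training loop with the gradients derived by hand. It is not nnetflow's implementation: it assumes a `Linear` layer computes `x @ W + b`, and it uses a made-up initialization (normal weights scaled by 0.5, zero biases). nnetflow's actual parameter layout and initialization may differ, so the printed losses will not match the nnetflow run.

```python
import numpy as np

# Same dummy data as the example above.
np.random.seed(0)
X = np.random.randn(100, 3).astype(np.float32)
y = (np.random.randn(100, 1) > 0).astype(np.float32)

# Assumed parameter layout out = x @ W + b, with a hypothetical initialization.
W1 = (np.random.randn(3, 8) * 0.5).astype(np.float32)
b1 = np.zeros(8, dtype=np.float32)
W2 = (np.random.randn(8, 1) * 0.5).astype(np.float32)
b2 = np.zeros(1, dtype=np.float32)
lr = 0.1

losses = []
for epoch in range(10):
    # Forward pass: Linear -> ReLU -> Linear, then mean squared error.
    h_pre = X @ W1 + b1
    h = np.maximum(h_pre, 0.0)
    out = h @ W2 + b2
    diff = out - y
    loss = float((diff ** 2).mean())
    losses.append(loss)

    # Backward pass: the chain rule applied in reverse order of the forward pass.
    d_out = 2.0 * diff / diff.size   # dL/d(out) for mean((out - y) ** 2)
    dW2 = h.T @ d_out
    db2 = d_out.sum(axis=0)
    d_h = d_out @ W2.T
    d_h_pre = d_h * (h_pre > 0)      # ReLU gradient mask
    dW1 = X.T @ d_h_pre
    db1 = d_h_pre.sum(axis=0)

    # Plain SGD update, p <- p - lr * grad, done in place.
    W1 -= lr * dW1
    b1 -= lr * db1
    W2 -= lr * dW2
    b2 -= lr * db2
    print(f"Epoch {epoch}, Loss: {loss:.4f}")
```

Writing the backward pass out like this makes it easy to check a framework's gradients numerically, for example by comparing them against finite differences on a small batch.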

Documentation

  • See docs/index.md for a full guide and API overview.
  • See CONTRIBUTING.md for contribution guidelines.
  • See CHANGELOG.md for release notes.

Examples

  • PyTorch vs nnetflow simple regression: examples/pytorch_vs_nnetflow.py
  • Classification comparison with decision boundaries: examples/classification_torch_vs_nnetflow.py
    • Outputs: examples/outputs/classification_boundaries.png, examples/outputs/classification_losses.png
  • Regression comparison with fit and loss curves: examples/regression_torch_vs_nnetflow.py
    • Outputs: examples/outputs/regression_fit_and_loss.png



Download files


Source Distribution

nnetflow-0.1.1.tar.gz (14.3 kB)

Uploaded Source

Built Distribution


nnetflow-0.1.1-py3-none-any.whl (13.1 kB)

Uploaded Python 3

File details

Details for the file nnetflow-0.1.1.tar.gz.

File metadata

  • Download URL: nnetflow-0.1.1.tar.gz
  • Upload date:
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-0.1.1.tar.gz

  • SHA256: 4eb738e9d54db45f16ea1d75c54a185e46c24c973c47d45ba7338f799a9a3d34
  • MD5: 065fa38cdd697cdd346fe2ed82ff2d62
  • BLAKE2b-256: 6c6043e50a647ea2ac52a33bf0e9224262f40ff1b49a7a8dc3fbce99ee7a74ea


File details

Details for the file nnetflow-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: nnetflow-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 13.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-0.1.1-py3-none-any.whl

  • SHA256: edc8f958fe4f7a3a6b700faff9fe1edeac968f9d61e211782b7b5e68c30571f4
  • MD5: 4fc265952ef2d0e3d198b0ac335dacd5
  • BLAKE2b-256: 319a25ddff52ae32c584c9a8b973ba608bbc38c73a93420dc290af11fa0c8743
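To check a downloaded file against the SHA256 digests listed above, a short script using only the standard library is enough. This is a sketch that assumes the files were saved under their original names in the current directory; the `sha256_of` helper and the `EXPECTED` mapping are names chosen for this example.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 digests as published for nnetflow 0.1.1.
EXPECTED = {
    "nnetflow-0.1.1.tar.gz": "4eb738e9d54db45f16ea1d75c54a185e46c24c973c47d45ba7338f799a9a3d34",
    "nnetflow-0.1.1-py3-none-any.whl": "edc8f958fe4f7a3a6b700faff9fe1edeac968f9d61e211782b7b5e68c30571f4",
}

for name, expected in EXPECTED.items():
    path = Path(name)
    if path.exists():
        status = "OK" if sha256_of(path) == expected else "MISMATCH"
        print(f"{name}: {status}")
    else:
        print(f"{name}: not found")
```

pip can produce the same SHA256 digest with `pip hash <file>`, and its hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file) enforces the check at install time.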

