A minimal neural network framework with autodiff and NumPy

Project description

nnetflow

A minimal neural network framework with autodiff, inspired by micrograd and PyTorch.
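
For readers new to micrograd-style autodiff, the idea can be sketched in a few lines of plain Python. This is only an illustration of reverse-mode differentiation on scalars; the `Value` class and its attributes are hypothetical and are not nnetflow's `Tensor` implementation.

```python
class Value:
    """Scalar node in a computation graph (hypothetical, for illustration only)."""

    def __init__(self, data, parents=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda: None  # replaced by the op that creates this node

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def _backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))

        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse order.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward()


a, b = Value(2.0), Value(3.0)
c = a * b + a          # c = a*b + a
c.backward()
print(c.data, a.grad, b.grad)  # 8.0 4.0 2.0
```

Gradients are accumulated with `+=` so that a value used more than once (here `a`) receives the sum of the contributions from each use.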

Installation

pip install nnetflow

Usage

from nnetflow.engine import Tensor
from nnetflow.layers import Linear
from nnetflow.module import Module
from nnetflow.optim import SGD
import numpy as np

# Define a simple MLP
class MLP(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

# Generate dummy data
np.random.seed(0)
X = np.random.randn(100, 3).astype(np.float32)
y = (np.random.randn(100, 1) > 0).astype(np.float32)

# Convert to Tensor
X_tensor = Tensor(X, require_grad=False)
y_tensor = Tensor(y, require_grad=False)

# Instantiate model and optimizer (the loss is computed inline in the training loop)
model = MLP(3, 8, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    out = model(X_tensor)
    # Simple MSE loss
    loss = ((out - y_tensor) ** 2).mean()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Save the trained model to disk
model.save("mlp_model.pkl")

# Load the model
loaded_model = Module.load("mlp_model.pkl")
# Verify loaded model
print(f"Loaded model: {loaded_model}")
# Check if the loaded model can still perform inference
test_out = loaded_model(X_tensor)
print(f"Test output from loaded model: {test_out.data[:5]}")  # Print first 5 outputs
# Check if the loaded model's parameters match the original model's parameters
for original_param, loaded_param in zip(model.parameters(), loaded_model.parameters()):
    assert np.array_equal(original_param.data, loaded_param.data), "Loaded parameters do not match original parameters"
print("All parameters match successfully after loading the model.")
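
The `.pkl` extension in the example suggests that models are stored with Python's `pickle` module, but that is an assumption and not something this page confirms. Under that assumption, a save/load round trip for a set of NumPy parameters might look like the sketch below; the dictionary keys are hypothetical names chosen for illustration.

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical parameter names and shapes matching the 3 -> 8 layer above.
params = {"fc1.weight": np.ones((3, 8)), "fc1.bias": np.zeros(8)}

path = os.path.join(tempfile.mkdtemp(), "params.pkl")

# Serialize the parameters to disk.
with open(path, "wb") as f:
    pickle.dump(params, f)

# Read them back and compare array by array.
with open(path, "rb") as f:
    restored = pickle.load(f)

all_equal = all(np.array_equal(params[k], restored[k]) for k in params)
print(all_equal)
```

Unpickling can execute arbitrary code, so files in this format should only be loaded from sources you trust.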

...
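
To show what the training loop in the example above asks the framework to compute, here is a rough NumPy-only sketch of the same 3-8-1 MLP with the forward pass, the MSE gradient, and the SGD update written out by hand. The weight initialisation and variable names are assumptions made for this sketch; nnetflow's `Linear` and `SGD` may differ in their details.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3))
y = (rng.standard_normal((100, 1)) > 0).astype(np.float64)

# Hypothetical initialisation; nnetflow's Linear may initialise differently.
W1 = rng.standard_normal((3, 8)) * 0.1
b1 = np.zeros(8)
W2 = rng.standard_normal((8, 1)) * 0.1
b2 = np.zeros(1)
lr = 0.1

losses = []
for epoch in range(10):
    # Forward pass: Linear -> ReLU -> Linear
    h_pre = X @ W1 + b1
    h = np.maximum(h_pre, 0.0)
    out = h @ W2 + b2
    loss = np.mean((out - y) ** 2)
    losses.append(loss)

    # Backward pass: chain rule applied by hand
    d_out = 2.0 * (out - y) / out.size   # gradient of the mean squared error
    dW2 = h.T @ d_out
    db2 = d_out.sum(axis=0)
    d_h = d_out @ W2.T
    d_h_pre = d_h * (h_pre > 0)          # ReLU passes gradient only where input > 0
    dW1 = X.T @ d_h_pre
    db1 = d_h_pre.sum(axis=0)

    # SGD step: move each parameter against its gradient
    W1 -= lr * dW1
    b1 -= lr * db1
    W2 -= lr * dW2
    b2 -= lr * db2

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

An autodiff framework replaces the hand-written backward pass with a single `loss.backward()` call, which is why the framework version of the loop is so much shorter.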

See the docs/ folder for more details.

Download files

Source Distribution

nnetflow-1.0.8.tar.gz (9.7 kB)

Uploaded Source

Built Distribution

nnetflow-1.0.8-py3-none-any.whl (9.6 kB)

Uploaded Python 3

File details

Details for the file nnetflow-1.0.8.tar.gz.

File metadata

  • Download URL: nnetflow-1.0.8.tar.gz
  • Upload date:
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-1.0.8.tar.gz

  • SHA256: aeb538b315ada46346633bf17377c859383e2fb84b1297dbad7e1131403f7a79
  • MD5: 4bac828ca7ae0fd44c4cec22654f4049
  • BLAKE2b-256: f84e823d584027c6d269279ecd637de150024fa77edd222d31cab2ab96e2627f
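
A downloaded file can be checked against the SHA256 digest listed above using only the standard library. The helper name `sha256_of` and the temporary demo file are made up for this sketch; only `hashlib` itself is standard.

```python
import hashlib
import os
import tempfile


def sha256_of(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Demo on a small temporary file so the sketch runs without a download.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.bin")
with open(demo_path, "wb") as f:
    f.write(b"hello")

demo_digest = sha256_of(demo_path)
print(demo_digest)
```

For the real archive, compare `sha256_of("nnetflow-1.0.8.tar.gz")` with the SHA256 value listed for that file; any mismatch means the download should not be trusted.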

File details

Details for the file nnetflow-1.0.8-py3-none-any.whl.

File metadata

  • Download URL: nnetflow-1.0.8-py3-none-any.whl
  • Upload date:
  • Size: 9.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.18

File hashes

Hashes for nnetflow-1.0.8-py3-none-any.whl

  • SHA256: 474f56b35929f6a1c9b1d25b1c86637c9b6e21175506e54d6463ea0c07a09804
  • MD5: a2825238936721384dfb2d058ef552df
  • BLAKE2b-256: e0cc2974b4210de19a751b8aa02efa224592c6b7e0c626837411fc8e86fb3a27
