A minimal neural network framework with autodiff and NumPy
nnetflow
A minimal neural network framework with autodiff, inspired by micrograd and PyTorch.
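The README describes nnetflow as an autodiff framework inspired by micrograd. As a rough illustration of that idea (not nnetflow's code), here is a minimal micrograd-style sketch of scalar reverse-mode automatic differentiation. The `Value` class and its methods are hypothetical names chosen for this sketch; judging from the usage example, nnetflow's own `Tensor` wraps whole NumPy arrays rather than single scalars.

```python
# Illustrative sketch only; not nnetflow's Tensor implementation.
class Value:
    """A scalar that remembers how it was computed so gradients can flow back."""

    def __init__(self, data, parents=(), backward=lambda: None):
        self.data = data
        self.grad = 0.0
        self._parents = parents      # the Values this one was computed from
        self._backward = backward    # pushes self.grad into the parents' grads

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward = backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = backward
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))

        def backward():
            # ReLU passes the gradient through only where its input was positive
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad

        out._backward = backward
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse order.
        order, seen = [], set()

        def build(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    build(p)
                order.append(v)

        build(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()


x, w, b = Value(2.0), Value(-3.0), Value(10.0)
y = (x * w + b).relu()   # relu(2 * -3 + 10) = relu(4) = 4
y.backward()
print(y.data, x.grad, w.grad, b.grad)  # 4.0 -3.0 2.0 1.0
```

Gradients are accumulated with `+=` so that a value used in more than one place (for example `a * a`) receives the sum of all its contributions.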
Installation
```bash
pip install nnetflow
```
```python
from nnetflow.engine import Tensor
from nnetflow.layers import Linear
from nnetflow.module import Module
from nnetflow.optim import SGD
import numpy as np

# Define a simple MLP
class MLP(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.fc1(x).relu()
        x = self.fc2(x)
        return x

# Generate dummy data
np.random.seed(0)
X = np.random.randn(100, 3).astype(np.float32)
y = (np.random.randn(100, 1) > 0).astype(np.float32)

# Convert to Tensor
X_tensor = Tensor(X, require_grad=False)
y_tensor = Tensor(y, require_grad=False)

# Instantiate model and optimizer
model = MLP(3, 8, 1)
optimizer = SGD(model.parameters(), lr=0.1)

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    out = model(X_tensor)
    # Simple MSE loss
    loss = ((out - y_tensor) ** 2).mean()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Save the trained model
model.save("mlp_model.pkl")

# Load the model
loaded_model = Module.load("mlp_model.pkl")

# Verify the loaded model
print(f"Loaded model: {loaded_model}")

# Check that the loaded model can still perform inference
test_out = loaded_model(X_tensor)
print(f"Test output from loaded model: {test_out.data[:5]}")  # Print first 5 outputs

# Check that the loaded model's parameters match the original model's parameters
for original_param, loaded_param in zip(model.parameters(), loaded_model.parameters()):
    assert np.array_equal(original_param.data, loaded_param.data), "Loaded parameters do not match original parameters"
print("All parameters match successfully after loading the model.")
```
...
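The training loop delegates the gradient computation to `loss.backward()` and the update to `optimizer.step()`. To show what those two calls amount to for this specific model, here is a NumPy-only sketch of the same 3-8-1 ReLU network trained with mean squared error and full-batch gradient descent, with the backward pass written out by hand. It is not nnetflow's implementation: the `(in_dim, out_dim)` weight layout, the small-Gaussian initialization, float64 data, the absence of momentum in the update, and the helper names `init_params`, `loss_and_grads`, and `sgd_step` are all assumptions made for illustration.

```python
import numpy as np


def init_params(in_dim, hidden_dim, out_dim, rng):
    # Hypothetical initialization: small Gaussian weights, zero biases.
    return {
        "W1": rng.standard_normal((in_dim, hidden_dim)) * 0.1,
        "b1": np.zeros(hidden_dim),
        "W2": rng.standard_normal((hidden_dim, out_dim)) * 0.1,
        "b2": np.zeros(out_dim),
    }


def loss_and_grads(p, X, y):
    # Forward pass: Linear -> ReLU -> Linear, then mean squared error.
    h_pre = X @ p["W1"] + p["b1"]
    h = np.maximum(h_pre, 0.0)
    out = h @ p["W2"] + p["b2"]
    diff = out - y
    loss = np.mean(diff ** 2)

    # Backward pass: apply the chain rule layer by layer, in reverse.
    d_out = 2.0 * diff / diff.size      # dL/d(out) for a mean over all elements
    grads = {"W2": h.T @ d_out, "b2": d_out.sum(axis=0)}
    d_h = d_out @ p["W2"].T
    d_h_pre = d_h * (h_pre > 0)         # ReLU gradient: 1 where the input was positive
    grads["W1"] = X.T @ d_h_pre
    grads["b1"] = d_h_pre.sum(axis=0)
    return loss, grads


def sgd_step(p, grads, lr):
    # Assumed plain SGD (no momentum): move each parameter against its gradient.
    for name in p:
        p[name] -= lr * grads[name]


rng = np.random.default_rng(0)
X = rng.standard_normal((100, 3))
y = (rng.standard_normal((100, 1)) > 0).astype(np.float64)
params = init_params(3, 8, 1, rng)

losses = []
for epoch in range(10):
    loss, grads = loss_and_grads(params, X, y)
    sgd_step(params, grads, lr=0.1)
    losses.append(loss)
    print(f"Epoch {epoch}, Loss: {loss:.4f}")
```

A finite-difference check on any single weight is a cheap way to confirm hand-derived gradients like these before trusting them.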
See the docs/ folder for more details.
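The example saves to `mlp_model.pkl`; the `.pkl` extension suggests, but does not confirm, that `Module.save` uses Python's `pickle`. The following hypothetical sketch shows that general approach with the standard library: a dict of NumPy parameter arrays is pickled, reloaded, and compared with `np.array_equal`, mirroring the final check in the example. `save_params`, `load_params`, and the parameter names are illustrative, not part of nnetflow.

```python
import os
import pickle
import tempfile

import numpy as np


def save_params(params, path):
    # Hypothetical helper: write a dict of name -> ndarray to disk with pickle.
    with open(path, "wb") as f:
        pickle.dump(params, f)


def load_params(path):
    # Unpickling can execute arbitrary code, so only load files you trust.
    with open(path, "rb") as f:
        return pickle.load(f)


rng = np.random.default_rng(0)
params = {"fc1.weight": rng.standard_normal((3, 8)), "fc1.bias": np.zeros(8)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.pkl")
    save_params(params, path)
    loaded = load_params(path)

# Same style of check as the README example: every array must match exactly.
for name, original in params.items():
    assert np.array_equal(original, loaded[name]), f"{name} does not match"
print("All parameters match after reloading.")
```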
Download files
Source Distribution
- nnetflow-1.0.8.tar.gz (9.7 kB)
Built Distribution
- nnetflow-1.0.8-py3-none-any.whl (9.6 kB)
File details
Details for the file nnetflow-1.0.8.tar.gz.
File metadata
- Download URL: nnetflow-1.0.8.tar.gz
- Upload date:
- Size: 9.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.8.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | aeb538b315ada46346633bf17377c859383e2fb84b1297dbad7e1131403f7a79 |
| MD5 | 4bac828ca7ae0fd44c4cec22654f4049 |
| BLAKE2b-256 | f84e823d584027c6d269279ecd637de150024fa77edd222d31cab2ab96e2627f |
File details
Details for the file nnetflow-1.0.8-py3-none-any.whl.
File metadata
- Download URL: nnetflow-1.0.8-py3-none-any.whl
- Upload date:
- Size: 9.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.8.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 474f56b35929f6a1c9b1d25b1c86637c9b6e21175506e54d6463ea0c07a09804 |
| MD5 | a2825238936721384dfb2d058ef552df |
| BLAKE2b-256 | e0cc2974b4210de19a751b8aa02efa224592c6b7e0c626837411fc8e86fb3a27 |
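The SHA256 digests listed above can be used to check that a downloaded file was not corrupted or altered. A minimal sketch with Python's standard `hashlib` module; `sha256_of`, `verify`, and `EXPECTED` are illustrative names, and the downloaded file is assumed to be at the given path.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    # Hash the file in chunks so large files need not fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    # Compare against a published digest, ignoring case and surrounding whitespace.
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digests published for nnetflow 1.0.8.
EXPECTED = {
    "nnetflow-1.0.8.tar.gz": "aeb538b315ada46346633bf17377c859383e2fb84b1297dbad7e1131403f7a79",
    "nnetflow-1.0.8-py3-none-any.whl": "474f56b35929f6a1c9b1d25b1c86637c9b6e21175506e54d6463ea0c07a09804",
}
```

For an intact download, `verify("nnetflow-1.0.8.tar.gz", EXPECTED["nnetflow-1.0.8.tar.gz"])` should return `True`. pip can also enforce digests itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file, installed with `--require-hashes`).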