
A package for fast, simultaneous training and inference of large numbers of small neural networks.


TurbaNet

TurbaNet is a lightweight, user-friendly API wrapper for the JAX library, designed to simplify and accelerate the setup of swarm-based training, evaluation, and simulation of small neural networks. It is based on the work presented by Will Whitney in his 2021 blog post.[^1]
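
The general idea behind swarm training can be illustrated in plain JAX: write the training step for a single network, give every parameter array a leading "swarm" axis, and let `jax.vmap` map the step across that axis while `jax.jit` compiles it once. The sketch below is a minimal illustration under those assumptions, not TurbaNet's actual implementation; the two-layer MLP, the plain SGD update, the learning rate, and all function names (`init_mlp`, `sgd_step`, `train_swarm`, and so on) are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim=10, hidden_dim=32):
    """Initialize one small two-layer MLP (hypothetical layout)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros(hidden_dim),
        "w2": jax.random.normal(k2, (hidden_dim, 1)) * 0.1,
        "b2": jnp.zeros(1),
    }


def forward(params, x):
    """Forward pass for a single network: Dense -> ReLU -> Dense -> sigmoid."""
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return jax.nn.sigmoid(h @ params["w2"] + params["b2"])


def mse_loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


def sgd_step(params, x, y, lr=0.5):
    """One plain gradient-descent step for ONE network; returns the pre-update loss."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


# Map the single-network step over the leading swarm axis of params, x and y,
# then compile the batched step once.
swarm_step = jax.jit(jax.vmap(sgd_step))


def train_swarm(n_networks=10, epochs=100, seed=0):
    key = jax.random.PRNGKey(seed)
    init_key, data_key = jax.random.split(key)
    # One independent random initialization per network in the swarm.
    params = jax.vmap(init_mlp)(jax.random.split(init_key, n_networks))
    # Toy binary dataset: 10 samples, 10 features, 1 target.
    kx, ky = jax.random.split(data_key)
    x = jax.random.bernoulli(kx, 0.5, (10, 10)).astype(jnp.float32)
    y = jax.random.bernoulli(ky, 0.5, (10, 1)).astype(jnp.float32)
    # Every network sees the same data, stacked along the swarm axis.
    xs = jnp.broadcast_to(x, (n_networks,) + x.shape)
    ys = jnp.broadcast_to(y, (n_networks,) + y.shape)
    history = []
    for _ in range(epochs):
        params, losses = swarm_step(params, xs, ys)
        history.append(losses)
    return params, jnp.stack(history)  # loss history, shape (epochs, n_networks)
```

Because the per-network code never mentions the swarm axis, changing the swarm size only changes the leading dimension of the arrays passed in; the compiled step is reused across epochs.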

Key Features

  • Simplified API: Provides an intuitive interface for configuring and managing swarm-based neural network tasks.
  • Efficiency: Leverages JAX's capabilities to accelerate computation for training and evaluation.
  • Flexibility: Supports various configurations, allowing users to tailor the swarm behavior to specific needs.

Installation

To install TurbaNet, ensure that you have Python and pip installed. Then, run:

pip install turbanet

TurbaNet train states require models from Flax and optimizers from Optax, which can be installed with:

pip install flax optax
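
To confirm that the packages resolved in the active environment, a small check using the standard library's `importlib.metadata` can report each installed version. This is a hypothetical helper (`installed_versions` is not part of TurbaNet), and the list of distribution names assumes TurbaNet, Flax, Optax, and JAX are the relevant ones.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["turbanet", "flax", "optax", "jax"]).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```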

Getting Started

Here's a basic example demonstrating how to initialize and use TurbaNet:

import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import linen as nn
from turbanet import TurbaTrainState, mse

# Set the NumPy random seed for reproducible results
np.random.seed(0)

# Toy dataset: 10 samples with 10 binary features and one binary target each
X_data = np.random.randint(0, 2, (10, 10)).astype(float)
y_data = np.random.randint(0, 2, (10, 1)).astype(float)


# Define model for the swarm
class MyModel(nn.Module):
    hidden_dim: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        x = nn.sigmoid(x)
        return x


# Define optimizer
optimizer = optax.adam(learning_rate=0.01)

# Define the size of the swarm
swarm_size = 10

# Initialize a swarm of swarm_size models; the last argument is a single sample
# input with a leading batch dimension, shape (1, 10)
swarm = TurbaTrainState.swarm(MyModel(), optimizer, swarm_size, X_data[0].reshape(1, -1))

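# Give every network its own copy of the data along a leading swarm axis:
# X has shape (swarm_size, 10, 10) and y has shape (swarm_size, 10, 1)
X = np.expand_dims(X_data, 0).repeat(len(swarm), axis=0)
y = np.expand_dims(y_data, 0).repeat(len(swarm), axis=0)

# Train the swarm, recording one loss per network per epoch
epochs = 100
losses = np.zeros((epochs, swarm_size))
for epoch in range(epochs):
    swarm, losses[epoch], predictions = swarm.train(X, y, mse)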

# Plot the training loss curves (one line per network in the swarm)
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
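
The example above stacks identical copies of the data along the leading swarm axis. If TurbaNet accepts different data for each member along that axis (an assumption inferred from the input shape, not confirmed here), each network could instead train on its own bootstrap resample, for example to build an ensemble. The helper below is hypothetical and only builds the arrays with NumPy.

```python
import numpy as np


def bootstrap_swarm_data(X, y, swarm_size, seed=0):
    """Give each swarm member its own bootstrap resample of (X, y).

    Returns X_swarm with shape (swarm_size, n_samples, n_features) and
    y_swarm with shape (swarm_size, n_samples, n_targets).
    """
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    # One row of sample indices per network, drawn with replacement; using the
    # same indices for X and y keeps every input paired with its target.
    idx = rng.integers(0, n_samples, size=(swarm_size, n_samples))
    return X[idx], y[idx]
```

Under that assumption, `X_swarm, y_swarm = bootstrap_swarm_data(X_data, y_data, swarm_size)` would replace `X` and `y` in the training loop, and `int(np.argmin(losses[-1]))` gives the index of the network with the lowest final loss.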

For more detailed tutorials and examples, please refer to the documentation.

Contributing

We welcome contributions to TurbaNet! If you'd like to contribute, please follow these steps:

  1. Fork the repository: Click the "Fork" button at the top right of the GitHub page.
  2. Clone your fork:

     git clone https://github.com/your-username/TurbaNet.git

  3. Create a new branch:

     git checkout -b feature/your-feature-name

  4. Make your changes: Implement your feature or fix the identified issue.
  5. Stage and commit your changes:

     git add -A
     git commit -m "Description of your changes"

  6. Push to your fork:

     git push origin feature/your-feature-name

  7. Submit a pull request: Navigate to the original repository and click "New pull request" to submit your changes for review.

License

This project is licensed under the MIT License. See the LICENSE file for more details.

References

[^1]: Whitney, W. (2021). Parallelizing neural networks on one GPU with JAX. Will Whitney's Blog. https://willwhitney.com/parallel-training-jax.html



Download files


Source Distribution

turbanet-0.3.0.tar.gz (9.0 kB)

Built Distribution


turbanet-0.3.0-py3-none-any.whl (7.1 kB)

File details

Details for the file turbanet-0.3.0.tar.gz.

File metadata

  • Download URL: turbanet-0.3.0.tar.gz
  • Upload date:
  • Size: 9.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.7

File hashes

Hashes for turbanet-0.3.0.tar.gz
Algorithm Hash digest
SHA256 17f69d4daebe093ea0134a550f09b355aee4d236ad3276188080fcc475291725
MD5 d80a3488f56900452e3dfb6c2b64dd88
BLAKE2b-256 3d9a12c1f822cb5aed006f5bf00c7c3db32983f47327bff7b764164fa5bf1491
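
A downloaded file can be checked against the SHA256 digest listed above with Python's standard `hashlib` module. The helper names below (`sha256_of`, `verify`) are hypothetical; the expected digest is the one published for turbanet-0.3.0.tar.gz.

```python
import hashlib

# SHA256 digest published above for turbanet-0.3.0.tar.gz
EXPECTED_SDIST_SHA256 = "17f69d4daebe093ea0134a550f09b355aee4d236ad3276188080fcc475291725"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SDIST_SHA256):
    """True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For installs, pip's hash-checking mode (`pip install --require-hashes -r requirements.txt`, with a `--hash=sha256:...` entry next to the pinned requirement) performs the same comparison automatically.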


File details

Details for the file turbanet-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: turbanet-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 7.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.7

File hashes

Hashes for turbanet-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 4b540db46013666e65cedfe5367ad4a8dedcd12fdcaefc09e95f7f7dcd47ba39
MD5 1104c8fcda4abc9af1c33dd08badbbba
BLAKE2b-256 c624bd941564736cee6f3d21385973158ed2a86fa8a283e79b564805dda2fbdf

