
A package for fast, simultaneous training and inference of large numbers of small neural networks.


TurbaNet

TurbaNet is a lightweight, user-friendly API wrapper for the JAX library, designed to simplify and accelerate the setup of swarm-based training, evaluation, and simulation of small neural networks. It is based on the work presented by Will Whitney in his 2021 blog post.[^1]
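
The core idea from that post is to vectorize a single-network function over a stack of parameter sets with jax.vmap, so that many small networks train and evaluate in one batched computation. Below is a minimal sketch of that idea, independent of TurbaNet's API; the predict function, shapes, and swarm size are illustrative only:

import jax
import jax.numpy as jnp

# One forward pass for a single tiny network.
def predict(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)

# Initialize 10 parameter sets, one per network in the "swarm".
keys = jax.random.split(jax.random.PRNGKey(0), 10)
init = lambda key: (jax.random.normal(key, (4, 1)), jnp.zeros((1,)))
swarm_params = jax.vmap(init)(keys)

# Map over the stacked parameters while sharing one input batch.
x = jnp.ones((8, 4))
outputs = jax.vmap(predict, in_axes=(0, None))(swarm_params, x)
print(outputs.shape)  # (10, 8, 1): one output per network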

Key Features

  • Simplified API: An intuitive interface for configuring and managing swarm-based neural network tasks.
  • Efficiency: Leverages JAX to accelerate training and evaluation.
  • Flexibility: Supports a variety of configurations, allowing users to tailor swarm behavior to specific needs.

Installation

To install TurbaNet, ensure that you have Python and pip installed. Then, run:

pip install turbanet

TurbaNet train states require models and optimizers from Flax and Optax, which can be installed with:

pip install flax optax

GPU Support

TurbaNet is built on top of JAX and fully supports GPU acceleration out of the box. To take advantage of GPU-based training, ensure your environment is configured correctly with the appropriate CUDA and cuDNN versions supported by JAX.

Refer to the official JAX installation guide for detailed instructions and compatibility tables, including supported CUDA versions and platform-specific recommendations.
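
For example, with recent JAX releases the CUDA-enabled wheels for Linux can typically be installed as shown below; check the guide for the current extras name and supported CUDA versions:

pip install -U "jax[cuda12]"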

If JAX detects a supported GPU and the correct environment configuration, TurbaNet will automatically execute on the GPU without any additional setup.

You can verify whether GPU support is enabled by running:

from jax.extend.backend import get_backend
print(get_backend().platform)  # Should return 'gpu' if GPU is active
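
Alternatively, you can list the devices JAX sees; the exact device names vary by JAX version and hardware:

import jax
print(jax.devices())  # e.g. [CudaDevice(id=0)] when a GPU is available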

Getting Started

Here's a basic example demonstrating how to initialize and use TurbaNet:

import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import linen as nn
from turbanet import TurbaTrainState, mse

# Set numpy random seed for reproducible results
np.random.seed(0)

# Sample input and target data
X_data = np.random.randint(0, 2, (10, 10)).astype(float)
y_data = np.random.randint(0, 2, (10, 1)).astype(float)


# Define model for the swarm
class MyModel(nn.Module):
    hidden_dim: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        x = nn.sigmoid(x)
        return x


# Define optimizer
optimizer = optax.adam(learning_rate=0.01)

# Define the size of the swarm
swarm_size = 10

# Initialize the swarm with desired parameters
swarm = TurbaTrainState.swarm(MyModel(), optimizer, swarm_size, X_data[0].reshape(1, -1))

# Train the swarm on your dataset
epochs = 100
losses = np.zeros((epochs, swarm_size))
for epoch in range(epochs):
    # Tile the same batch for every network: (swarm_size, num_samples, features)
    X = np.expand_dims(X_data, 0).repeat(len(swarm), axis=0)
    y = np.expand_dims(y_data, 0).repeat(len(swarm), axis=0)
    swarm, losses[epoch], predictions = swarm.train(X, y, mse)

# Plot the loss curves from training
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
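
Because train returns each network's predictions on the batch, you can also compare the individual networks after training. A quick sketch, assuming predictions has shape (swarm_size, num_samples, 1) as in the example above:

# Per-network accuracy on the training batch, from the final epoch
final_outputs = np.asarray(predictions).squeeze(-1)  # (swarm_size, num_samples)
per_network_accuracy = (final_outputs.round() == y_data.T).mean(axis=1)
print(per_network_accuracy)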

For more detailed tutorials and examples, please refer to the documentation.

Contributing

We welcome contributions to TurbaNet! If you'd like to contribute, please follow these steps:

  1. Fork the repository: Click the "Fork" button at the top right of the GitHub page.

  2. Clone your fork:

git clone https://github.com/your-username/TurbaNet.git

  3. Create a new branch:

git checkout -b feature/your-feature-name

  4. Make your changes: Implement your feature or fix the identified issue.

  5. Commit your changes:

git commit -m "Description of your changes"

  6. Push to your fork:

git push origin feature/your-feature-name

  7. Submit a Pull Request: Navigate to the original repository and click "New Pull Request" to submit your changes for review.

License

This project is licensed under the MIT License. See the LICENSE file for more details.

References

[^1]: Whitney, W. (2021). Parallelizing neural networks on one GPU with JAX. Will Whitney's Blog. https://willwhitney.com/parallel-training-jax.html
