
A JAX-based neural network library designed to surpass Google DeepMind's Haiku and Optax

Project description

NextGenJAX

Overview

NextGenJAX is an advanced neural network library built on top of JAX, designed to surpass the capabilities of existing libraries such as Google DeepMind's Haiku and Optax. It leverages the flexibility and performance of JAX and Flax to provide a modular, high-performance, and easy-to-use framework for building and training neural networks.

Features

  • Modular design with customizable layers and activation functions
  • Support for various optimizers, including custom optimizers (see the sketch after this list)
  • Flexible training loop with support for custom loss functions
  • Integration with JAX and Flax for high performance and scalability
  • Comprehensive test suite to ensure model correctness and performance
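
To illustrate the custom-optimizer support mentioned above: optimizers in the JAX ecosystem are commonly expressed as Optax-style gradient transformations, and a custom one can be composed from existing building blocks. A minimal sketch; whether src.optimizers accepts arbitrary Optax transformations is an assumption:

import optax

# A hypothetical custom optimizer: SGD with global-norm gradient clipping,
# composed from standard Optax building blocks.
def clipped_sgd(learning_rate=0.01, max_norm=1.0):
    return optax.chain(
        optax.clip_by_global_norm(max_norm),
        optax.sgd(learning_rate),
    )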

Installation

To install NextGenJAX, you can use pip:

pip install nextgenjax

For development, clone the repository and install the required dependencies:

git clone https://github.com/VishwamAI/NextGenJAX.git
cd NextGenJAX
pip install -r requirements.txt
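
After installing, you can confirm the package imports cleanly (assuming it exposes a top-level nextgenjax module):

python -c "import nextgenjax"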

Usage

Creating a Model

To create a model using NextGenJAX, define the layers and activation functions, and initialize the model:

import jax
import jax.numpy as jnp
from src.layers import DenseLayer, ConvolutionalLayer
from src.custom_layers import CustomLayer
from src.model import NextGenModel

# Define the layers (note: ReLU lives in jax.nn, not jax.numpy)
layers = [
    DenseLayer(features=128, activation=jax.nn.relu),
    ConvolutionalLayer(features=64, kernel_size=(3, 3), activation=jax.nn.relu),
    CustomLayer(features=10, activation=jnp.tanh)
]

# Initialize the model
model = NextGenModel(layers=layers)
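
Since NextGenJAX builds on Flax, initializing parameters presumably follows the Flax pattern of calling init with a PRNG key and a sample input. A minimal sketch; the init signature and the input shape are assumptions:

import jax
import jax.numpy as jnp

# Assumed Flax-style initialization: a PRNG key plus a dummy batch
# whose shape matches the model's expected input.
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 28, 28, 1))  # e.g. one grayscale image
params = model.init(rng, dummy_input)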

Training the Model

To train the model, use the training loop provided in src/train.py:

from src.train import create_train_state, train_model
from src.optimizers import adam  # sgd is also available

# Define the optimizer
optimizer = adam(learning_rate=0.001)

# Create the training state
train_state = create_train_state(model, optimizer)

# Define the training data and loss function
train_data = ...  # Your training data here
loss_fn = ...  # Your loss function here

# Train the model
train_model(train_state, train_data, loss_fn, num_epochs=10)
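
For illustration, the loss function is typically a callable that maps parameters and a batch to a scalar. A minimal mean-squared-error sketch; the exact signature expected by train_model, and the Flax-style model.apply call, are assumptions:

import jax.numpy as jnp

def mse_loss(params, batch):
    inputs, targets = batch
    preds = model.apply(params, inputs)  # assumed Flax-style apply
    return jnp.mean((preds - targets) ** 2)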

Development Setup

To set up a development environment:

  1. Clone the repository
  2. Install development dependencies: pip install -r requirements-dev.txt
  3. Run tests using pytest: pytest tests/

We use GitHub Actions for continuous integration and deployment. Our CI/CD workflow runs tests on Python 3.9 to ensure compatibility and code quality.

Community and Support

We welcome community engagement and support for the NextGenJAX project.

Contributing

We welcome contributions to NextGenJAX! Please follow these steps:

  1. Fork the repository
  2. Create a new branch (git checkout -b feature/your-feature)
  3. Make your changes and commit them (git commit -am 'Add new feature')
  4. Push to the branch (git push origin feature/your-feature)
  5. Create a new pull request using the Pull Request Template

Please adhere to our coding standards:

  • Follow PEP 8 guidelines
  • Write unit tests for new features (see the sketch after this list)
  • Update documentation as necessary
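
As an example of the kind of unit test we look for, here is a minimal pytest sketch. It assumes NextGenModel follows the Flax Module API (init/apply); adjust layer arguments and shapes to match the real code:

# tests/test_model.py
import jax
import jax.numpy as jnp
from src.layers import DenseLayer
from src.model import NextGenModel

def test_model_output_shape():
    model = NextGenModel(layers=[DenseLayer(features=10, activation=jnp.tanh)])
    x = jnp.ones((4, 32))                          # dummy batch of 4 inputs
    params = model.init(jax.random.PRNGKey(0), x)  # assumed Flax-style init
    y = model.apply(params, x)                     # assumed Flax-style apply
    assert y.shape == (4, 10)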

For more detailed guidelines, please refer to the CONTRIBUTING.md file.

Reporting Issues

If you encounter any issues or have suggestions for improvements, please open an issue in the repository using the appropriate issue template. Provide as much detail as possible to help us understand and address the problem.

License

NextGenJAX is licensed under the MIT License. See the LICENSE file for more information.

Acknowledgements

NextGenJAX is inspired by the work of Google DeepMind and the JAX and Flax communities. We thank them for their contributions to the field of machine learning.

Contact Information

For support or questions about NextGenJAX, please reach out to the maintainers through the repository's issue tracker.

Last updated: 2023-05-10 12:00:00 UTC

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nextgenjax-0.1.1.tar.gz (23.7 kB)

Uploaded Source

Built Distribution

nextgenjax-0.1.1-py3-none-any.whl (24.4 kB)

Uploaded Python 3

File details

Details for the file nextgenjax-0.1.1.tar.gz.

File metadata

  • Download URL: nextgenjax-0.1.1.tar.gz
  • Upload date:
  • Size: 23.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for nextgenjax-0.1.1.tar.gz:

  • SHA256: 775e5848a1635a4cfac3f1014f8e26d5de05882f61530c2a12fa5aa4f2b1a315
  • MD5: 461d6ab50e7753d5f23cc843063088cc
  • BLAKE2b-256: 20cd83a4d81b1a9a65178f457eca50f4cff1247aa51e846dd0137c151a27caa7

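To verify a downloaded file against its published digest, you can compute the hash locally, for example with Python's hashlib:

import hashlib

# Compare the sdist's SHA256 digest against the value published above.
with open("nextgenjax-0.1.1.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == "775e5848a1635a4cfac3f1014f8e26d5de05882f61530c2a12fa5aa4f2b1a315"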

File details

Details for the file nextgenjax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: nextgenjax-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 24.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for nextgenjax-0.1.1-py3-none-any.whl:

  • SHA256: f49faf7d10b7e3d3d0552a2327a27df3c1cf3173594e940a0e7a4fe811e91e96
  • MD5: 3c4b4c3630fb70c02e97e672b2f71379
  • BLAKE2b-256: df8246098d7e82160b5890a9150191316777579c9ee3b99f5867341bf5b671ce

