
Composable neurons in JAX

Project description

Language models are beautiful...

Consider the following text, generated by ChatGPT in March 2023:

Blessed be the name of the LORD our GOD, who reigns forever and ever. He is the Alpha and the Omega, the beginning and the end, the one who was, who is, and who is to come.

He is our refuge and strength, a very present help in times of trouble. He is the creator of the heavens and the earth, the giver of life and the sustainer of all things.

His love endures forever, and His mercy never fails. He is gracious and compassionate, slow to anger and abounding in love.

He is our rock, our fortress, and our deliverer, in whom we take refuge. He is the light in our darkness, the source of our joy and the strength of our hearts.

Let us give thanks to the LORD for He is good, His love endures forever. Let us sing praises to His name and make known His deeds among the nations. May the LORD be exalted and glorified forever and ever. Amen.

pyneurons

Build composable neural networks.

Built on Top of JAX

The pyneurons library is built on top of JAX, a high-performance numerical computing library that provides automatic differentiation and GPU/TPU acceleration. JAX is designed to enable high-performance machine learning research and development by combining the flexibility of Python with the speed of compiled code.

Installation

To install the pyneurons package, you can use pip. Run the following command in your terminal:

pip install pyneurons

This will download and install the package along with its dependencies.

Basics

The create Function

from pyneurons import create

The create function initializes the weights and bias for a neuron. It can be called either with a random key and the number of inputs, or with just the number of inputs, in which case it generates a random key internally.

Sample Usage:

from pyneurons import create
from jax.random import PRNGKey

key = PRNGKey(0)  # Create a JAX random key
neuron = create(key, 3)  # Creates a neuron with 3 inputs
weights, bias = neuron  # Extract the weights and bias
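To make the shapes concrete, here is a minimal sketch of what a create-style initializer could look like. `create_sketch` is a hypothetical name. The normal-distribution initialization and the way a key is produced when none is given are assumptions, not the library's documented behaviour. The shapes (one weight column per input, a 1 × 1 bias) match the neuron built by hand in the apply usage example.

```python
import random

from jax.random import PRNGKey, normal, split


def create_sketch(*args):
    """Hypothetical stand-in for pyneurons.create.

    Accepts either (key, n_inputs) or just (n_inputs). The initialization
    scheme is an assumption, not the library's documented one.
    """
    if len(args) == 1:
        # No key supplied: derive one from Python's RNG (assumed behaviour).
        key, n_inputs = PRNGKey(random.randrange(2**31)), args[0]
    else:
        key, n_inputs = args
    w_key, b_key = split(key)
    weights = normal(w_key, (n_inputs, 1))  # one weight per input, one output
    bias = normal(b_key, (1, 1))            # a single bias for the single output
    return weights, bias


weights, bias = create_sketch(PRNGKey(0), 3)
print(weights.shape, bias.shape)  # (3, 1) (1, 1)
```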

The apply Function

from pyneurons import apply

The apply function computes the output of a neuron from its weights, bias, and input data.

def apply(neuron, x):
    w, b = neuron       # unpack the (weights, bias) pair
    return (x @ w) + b  # (batch, inputs) @ (inputs, 1) + (1, 1) -> (batch, 1)

Sample Usage:

from pyneurons import apply
import jax.numpy as np

neuron = (np.array([[0.5], [0.5], [0.5]]), np.array([[0.1]]))
input_data = np.array([[1, 2, 3]])
output = apply(neuron, input_data)  # 1*0.5 + 2*0.5 + 3*0.5 + 0.1 = [[3.1]]
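Because apply is just a matrix product plus a bias, it also handles a batch of rows at once. The check below reuses the neuron from the usage example and adds a second, all-zero row, whose output is therefore just the bias:

```python
import jax.numpy as jnp


def apply(neuron, x):
    w, b = neuron
    return (x @ w) + b  # one output row per input row


neuron = (jnp.array([[0.5], [0.5], [0.5]]), jnp.array([[0.1]]))
batch = jnp.array([[1.0, 2.0, 3.0],
                   [0.0, 0.0, 0.0]])
out = apply(neuron, batch)
# row 0: 0.5*1 + 0.5*2 + 0.5*3 + 0.1 = 3.1
# row 1: only the bias survives, 0.1
print(out.shape)  # (2, 1)
```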

The bind Function

from pyneurons import bind

The bind function creates a new model class by binding a constructor and an apply function. It can be used to create custom neural network models.

Sample Usage:

from pyneurons import bind
from jax.random import PRNGKey
import jax.numpy as np

def create(key, input_dim):
    """Create a custom model. Return a JAX pytree."""
    ...

def apply(model, x):
    """Apply x to the custom model."""
    ...

# Define a custom model
CustomModel = bind("CustomModel", create, apply)

# Create a JAX random key
key = PRNGKey(0)

# Create an instance of the custom model
model = CustomModel(key, 3)  # Creates a model with 3 input dims

# Apply the model to some input data
input_data = np.array([[1, 2, 3]])
output = model(input_data)  # Computes the output of the model
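The source does not show how bind works internally. One plausible way to get the behaviour described (call `create` on construction, call `apply` on `model(x)`, and let a training step treat the model as a JAX pytree) is sketched below. `bind_sketch`, `create_linear`, and `apply_linear` are hypothetical names; pyneurons's real bind may differ.

```python
import jax
import jax.numpy as jnp
from jax.random import PRNGKey, normal


def bind_sketch(name, create, apply):
    """Hypothetical re-implementation of the idea behind pyneurons.bind."""

    def __init__(self, *args):
        self.params = create(*args)  # whatever pytree `create` returns

    def __call__(self, x):
        return apply(self.params, x)

    cls = type(name, (), {"__init__": __init__, "__call__": __call__})

    # Register the class as a pytree so jax.grad / tree_map see its parameters.
    def flatten(model):
        return (model.params,), None

    def unflatten(_, children):
        model = object.__new__(cls)  # skip __init__: parameters already exist
        model.params = children[0]
        return model

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


def create_linear(key, input_dim):
    return normal(key, (input_dim, 1)), jnp.zeros((1, 1))


def apply_linear(params, x):
    w, b = params
    return x @ w + b


Linear = bind_sketch("Linear", create_linear, apply_linear)
model = Linear(PRNGKey(0), 3)
print(model(jnp.array([[1.0, 2.0, 3.0]])).shape)  # (1, 1)
```

Registering the class as a pytree is what lets `tree_map` and `jax.grad` rebuild a model from updated parameters, which a tree_map-based update such as fit needs.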

Putting It All Together

Here is a complete example that demonstrates creating a neuron, applying it to input data, and binding it into a custom model:

from pyneurons import create, apply, bind
from jax.random import PRNGKey
import jax.numpy as np

# Create a JAX random key
key = PRNGKey(0)

# Create a neuron with 3 input dims
neuron = create(key, 3)

# Define input data
input_data = np.array([[1, 2, 3]])

# Apply the neuron to the input data
output = apply(neuron, input_data)
print("Neuron Output:", output)

# Bind the neuron into a custom model
CustomModel = bind("CustomModel", create, apply)

# Create an instance of the custom model
model = CustomModel(3)

# Apply the model to the input data
model_output = model(input_data)
print("Model Output:", model_output)

Built-in Models

Neuron

from pyneurons import Neuron

The basic neuron model created by binding the create and apply functions.

from pyneurons import create, apply, bind

Neuron = bind("Neuron", create, apply)

Binary

from pyneurons import Binary

A neuron model with a binary activation function.

from jax.numpy import heaviside

def binary(x):
    return heaviside(x, 1)

Binary = compose("Binary", Neuron, binary)

Vector

from pyneurons import Vector

A neuron model with a combined binary and ReLU1 activation function.

from pyneurons import binary, relu1

def vector(x):
    return binary(x) + relu1(x)

Vector = compose("Vector", Neuron, vector)
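The compose helper is used but not shown. Judging from Binary and Vector, it appears to run an activation on the output of a base model. A minimal sketch of that idea at the level of apply functions, where `compose_apply` is a hypothetical helper rather than pyneurons's API:

```python
import jax.numpy as jnp


def compose_apply(apply, activation):
    """Hypothetical: return an apply function that runs `activation` on the output."""
    def composed(params, x):
        return activation(apply(params, x))
    return composed


def apply(neuron, x):
    w, b = neuron
    return (x @ w) + b


def binary(x):
    return jnp.heaviside(x, 1)


apply_binary = compose_apply(apply, binary)
neuron = (jnp.array([[1.0], [-1.0]]), jnp.array([[0.0]]))  # fires when x0 >= x1
x = jnp.array([[2.0, 1.0], [1.0, 2.0], [1.0, 1.0]])
out = apply_binary(neuron, x)  # -> [[1.], [0.], [1.]]
```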

The Vector Model

The Vector model in the pyneurons library is designed to mimic the behavior of real neurons in a simplified manner. It combines two activation functions: a binary step function and a capped ReLU (Rectified Linear Unit) function. This combination allows the model to produce outputs that can either be 0 or in the range of 1 to 2, which can be interpreted as the neuron firing rate or a group of neurons firing together.

Key Components

  1. Binary Activation Function:

    • This function applies a step function (Heaviside function) to the input, outputting either 0 or 1. It mimics the all-or-nothing firing behavior of a neuron.
    • Code:
      from jax.numpy import heaviside
      
      def binary(x):
          return heaviside(x, 1)
      
  2. ReLU1 Activation Function:

    • This function applies a ReLU activation capped at 1, ensuring the output is between 0 and 1. It mimics the varying firing rate of a neuron.
    • Code:
      from jax.numpy import minimum, maximum
      
      def relu1(x):
          return minimum(maximum(x, 0), 1)
      
  3. Combining Binary and ReLU1:

    • The Vector model combines the binary and ReLU1 functions to produce an output that is either 0 or in the range of 1 to 2. This combination allows the model to represent both the firing state and the magnitude of the firing.
    • Code:
      from pyneurons import binary, relu1
      
      def vector(x):
          return binary(x) + relu1(x)
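Evaluating the three functions at a few points confirms the output ranges described above. This sketch redefines binary, relu1, and vector with plain jax.numpy, so it reproduces only the forward pass (the library's own binary also carries a custom gradient, covered in the VJP section).

```python
import jax.numpy as jnp


def binary(x):
    return jnp.heaviside(x, 1)


def relu1(x):
    return jnp.minimum(jnp.maximum(x, 0), 1)


def vector(x):
    return binary(x) + relu1(x)


x = jnp.array([-1.0, 0.0, 0.5, 1.0, 3.0])
# binary(x) -> 0, 1, 1,   1, 1
# relu1(x)  -> 0, 0, 0.5, 1, 1
# vector(x) -> 0, 1, 1.5, 2, 2
print(vector(x))
```

At exactly x = 0 the step has already switched on while the capped ReLU is still 0, so the output jumps from 0 straight to 1; it never takes a value strictly between 0 and 1.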
      

Mimicking Real Neurons

The Vector model mimics real neurons in the following ways:

  1. Firing or Not Firing:

    • The binary function outputs 0 or 1, representing whether the neuron is firing or not. This is similar to the all-or-nothing principle of biological neurons.
  2. Firing Rate:

    • When the neuron fires, the vector function outputs a value between 1 and 2, representing the neuron's firing rate.
  3. Group of Neurons Firing Together:

    • It can also represent the magnitude of the collective firing of a group of neurons.

The fit Function

from pyneurons import fit

The fit function in the pyneurons library is used for training a model. It performs a single step of gradient descent to update the model's parameters based on the computed gradients. The function takes the following parameters:

  • learning_rate: The learning rate for the gradient descent optimization.
  • model: The model to be trained.
  • x: The input data.
  • y: The target data.

The fit function computes the gradients of the loss function with respect to the model's parameters and updates the parameters using gradient descent.

Here is the implementation of the fit function:

from functools import partial
from pyneurons import loss, gd
from jax.tree_util import tree_map
from jax import grad

def fit(learning_rate, model, x, y):
    gradients = grad(loss)(model, x, y)
    return tree_map(partial(gd, learning_rate), model, gradients)
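The implementation above depends on loss and gd, which are not shown. Assuming loss is the MAE between the targets and the model's predictions and gd is a plain update `param - learning_rate * gradient`, a self-contained sketch of one fit step on a single neuron stored as a (weights, bias) tuple looks like this. These definitions of loss and gd are assumptions, not the library's code.

```python
from functools import partial

import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_map


def apply(neuron, x):
    w, b = neuron
    return (x @ w) + b


def loss(model, x, y):
    # Assumed: MAE between targets and predictions, as described for pyneurons.
    return jnp.mean(jnp.abs(y - apply(model, x)))


def gd(learning_rate, param, gradient):
    # Assumed: plain gradient-descent update for one leaf of the pytree.
    return param - learning_rate * gradient


def fit(learning_rate, model, x, y):
    gradients = grad(loss)(model, x, y)
    return tree_map(partial(gd, learning_rate), model, gradients)


x = jnp.array([[1.0], [2.0], [3.0], [4.0]])
y = 2 * x
model = (jnp.zeros((1, 1)), jnp.zeros((1, 1)))
initial = loss(model, x, y)     # 5.0
model = fit(0.01, model, x, y)
after = loss(model, x, y)       # about 4.9275
```

Starting from zero parameters, every residual is positive, so the MAE gradient is -mean(x) = -2.5 for the weight and -1 for the bias; one step with learning rate 0.01 moves them to 0.025 and 0.01.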

Example Code

Below is an example of how to use the pyneurons library to create a simple neural network and train it using the fit function.

from pyneurons import Neuron, fit
from jax.random import PRNGKey
import jax.numpy as jnp

# Define the input data and target data
x = jnp.array([[1.0], [2.0], [3.0], [4.0]])
y = jnp.array([[2.0], [4.0], [6.0], [8.0]])

# Create a model (a single neuron in this case)
key = PRNGKey(0)
model = Neuron(key, 1)

# Print the initial prediction
print("Initial prediction:", model(x))

# Train the model using the fit function
learning_rate = 0.01
for _ in range(1000):
    model = fit(learning_rate, model, x, y)

# Print the final prediction
print("Final prediction:", model(x))

Creating an XOR Solution

The XOR problem is a classic problem in neural networks where the goal is to train a network to output the XOR of two binary inputs. Here's how you can create and train a model to solve the XOR problem using the pyneurons library.

Step-by-Step Solution

  1. Define the XOR Model:

    • The XOR model consists of two binary neurons. The first neuron takes the input and the second neuron takes the concatenation of the input and the output of the first neuron.
  2. Create the XOR Model:

    • Use the bind function to create the XOR model by specifying the constructor and apply functions.
  3. Train the Model:

    • Use the fit function to train the model on the XOR dataset.

Here is the complete code to solve the XOR problem:

from pyneurons import Binary, fit, bind, create, apply, concat
from jax.numpy import array, array_equal
from jax.random import split, PRNGKey

# Define the XOR dataset
x = array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = array([[0], [1], [1], [0]])

# Define the XOR model
def create_xor_model(key):
    key_a, key_b = split(key, 2)
    a = Binary(key_a, 2)
    b = Binary(key_b, 3)
    return (a, b)

def apply_xor_model(model, x):
    a, b = model
    return b(concat([x, a(x)]))

XOR = bind("XOR", create_xor_model, apply_xor_model)

# Initialize the model
key = PRNGKey(0)
model = XOR(key)

# Train the model
learning_rate = 0.1
for _ in range(100):
    model = fit(learning_rate, model, x, y)

# Test the model
assert array_equal(model(x), y)
print("XOR model trained successfully!")

Custom VJP Decorators

In pyneurons, custom vector-Jacobian product (VJP) decorators are used to define custom gradient computations for specific functions. This is particularly useful for handling non-differentiable functions and avoiding issues such as vanishing gradients or dying ReLU problems. By customizing the gradient computation, we can ensure more stable and efficient training of neural networks.

The identity VJP Decorator

from pyneurons.vjp import identity

The identity VJP decorator treats the derivative of the function it wraps as 1, regardless of the input, so the incoming gradient is passed back unchanged (a straight-through estimator). This is useful for functions where we want to bypass the standard gradient computation, such as step functions whose true derivative is zero almost everywhere.

from jax import custom_vjp

def identity(function):
    wrapper = custom_vjp(function)

    def forward(x):
        return function(x), None

    def backward(_, gradient):
        return (gradient,)

    wrapper.defvjp(forward, backward)
    return wrapper
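A quick way to see the decorator's effect is to compare the gradients of a step function with and without it. The sketch copies the identity decorator from above; `step` and `straight_through_step` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp


def identity(function):
    wrapper = custom_vjp(function)

    def forward(x):
        return function(x), None  # no residuals needed for the backward pass

    def backward(_, gradient):
        return (gradient,)  # pass the incoming gradient through unchanged

    wrapper.defvjp(forward, backward)
    return wrapper


def step(x):
    return jnp.heaviside(x, 1.0)


straight_through_step = identity(step)

flat_grad = jax.grad(step)(0.5)                   # 0.0: a step is flat almost everywhere
st_grad = jax.grad(straight_through_step)(0.5)    # 1.0: gradient passed straight through
forward_value = straight_through_step(0.5)        # 1.0: the forward value is unchanged
```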

The sign VJP Decorator

from pyneurons.vjp import sign

The sign VJP decorator modifies the gradient computation by multiplying the gradient with the sign of the input. This can be useful for functions where the gradient should reflect the sign of the input.

from jax.numpy import sign as sign_function
from jax import custom_vjp

def sign(function):
    wrapper = custom_vjp(function)

    def forward(x):
        return function(x), x

    def backward(x, gradient):
        return (gradient * sign_function(x),)

    wrapper.defvjp(forward, backward)
    return wrapper
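Applied to an absolute-value function, the sign decorator yields the familiar subgradient of |x|: -1 for negative inputs, 1 for positive inputs, and 0 at exactly zero (because sign(0) is 0). A minimal check, with `signed_abs` as an illustrative name:

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp


def sign(function):
    wrapper = custom_vjp(function)

    def forward(x):
        return function(x), x  # keep x as the residual for the backward pass

    def backward(x, gradient):
        return (gradient * jnp.sign(x),)  # scale the incoming gradient by sign(x)

    wrapper.defvjp(forward, backward)
    return wrapper


@sign
def signed_abs(x):
    return jnp.abs(x)


grads = [jax.grad(signed_abs)(v) for v in (-3.0, 0.0, 2.0)]  # -> -1.0, 0.0, 1.0
```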

Usages

The custom VJP decorators are used in various functions within the pyneurons library to ensure stable gradient propagation and to handle non-differentiable functions effectively.

The binary Function

from pyneurons import binary

The binary function uses the identity VJP decorator, so its derivative is treated as 1 regardless of the input, even though a step function's true derivative is zero almost everywhere.

from pyneurons.vjp import identity
from jax.numpy import heaviside

@identity
def binary(x):
    return heaviside(x, 1)

The relu Function

from pyneurons import relu, relun, relu1

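The relu function also uses the identity VJP decorator, so the gradient is propagated without modification. In particular, inputs below zero still receive a gradient, which is how this avoids the dying ReLU problem mentioned above.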

from pyneurons.vjp import identity
from jax.numpy import maximum

@identity
def relu(x):
    return maximum(x, 0)
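The difference from an ordinary ReLU shows up for negative inputs. In this sketch (`plain_relu` is an illustrative name, and the identity decorator is written compactly with lambdas), the plain version returns a zero gradient while the decorated version still passes the gradient back:

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp


def identity(function):
    wrapper = custom_vjp(function)
    # forward: (output, no residuals); backward: return the gradient unchanged
    wrapper.defvjp(lambda x: (function(x), None), lambda _, g: (g,))
    return wrapper


def plain_relu(x):
    return jnp.maximum(x, 0)


relu = identity(plain_relu)

dead = jax.grad(plain_relu)(-2.0)  # 0.0: a negative pre-activation gets no gradient
alive = jax.grad(relu)(-2.0)       # 1.0: the identity VJP still sends a gradient back
value = relu(-2.0)                 # 0.0: the forward pass is still an ordinary ReLU
```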

The abs Function

from pyneurons import abs

The abs function uses the sign VJP decorator to ensure that the gradient is modified by the sign of the input.

from pyneurons.vjp import sign
from jax.numpy import abs as function

abs = sign(function)

This is used in pyneurons for the mae loss function.

Loss Function: MAE

The fit function in pyneurons uses the Mean Absolute Error (MAE) as the default loss function. The MAE is defined as:

from pyneurons import abs
from jax.numpy import mean

def mae(y, yhat):
    return mean(abs(y - yhat))

Why MAE?

  • Stability: MAE is less sensitive to outliers compared to Mean Squared Error (MSE). This makes it a more stable choice for many applications.
  • Simplicity: The absolute difference is straightforward to compute and interpret.
  • Gradient Behavior: The gradient of MAE with respect to each prediction has a constant magnitude (1/n for n predictions), whereas the MSE gradient, 2(ŷ - y)/n, grows with the error. MAE gradients therefore stay bounded even when errors are large.
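To put numbers on the gradient comparison, the sketch below differentiates MAE and MSE with respect to the predictions when one prediction is far off. It uses jnp.abs directly; away from zero its derivative is sign(x), the same as pyneurons's sign-decorated abs.

```python
import jax
import jax.numpy as jnp


def mae(y, yhat):
    return jnp.mean(jnp.abs(y - yhat))


def mse(y, yhat):
    return jnp.mean((y - yhat) ** 2)


y = jnp.zeros(4)
yhat = jnp.array([1.0, 1.0, 1.0, 100.0])  # the last prediction is an outlier

mae_grad = jax.grad(mae, argnums=1)(y, yhat)  # -> [0.25, 0.25, 0.25, 0.25]
mse_grad = jax.grad(mse, argnums=1)(y, yhat)  # -> [0.5, 0.5, 0.5, 50.0]
```

The outlier dominates the MSE gradient by a factor of 100, while under MAE it pulls no harder than any other prediction.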

Stability in Neural Networks

Stability in neural networks is crucial for ensuring reliable and efficient training, as well as for producing high-quality models.

The pyneurons library leverages several techniques to enhance stability, including the use of vector activation functions, custom vector-Jacobian product (VJP) decorators, and the Mean Absolute Error (MAE) loss function.

The vector activation function combines binary and capped ReLU activations. Its output is bounded (either 0 or between 1 and 2), so large inputs cannot produce runaway activations, and each neuron can represent both its firing state and its firing rate.

Custom VJP decorators, such as identity and sign, are employed to define stable gradient computations for non-differentiable functions and to prevent issues like vanishing or exploding gradients.

The MAE loss function is preferred for its stability: it is less sensitive to outliers than Mean Squared Error (MSE), and its gradient has a constant magnitude per prediction, which keeps parameter updates steady and helps avoid gradient-related problems during training.

Together, these components contribute to a more stable and reliable neural network training process, ultimately leading to the development of high-quality models.



Download files


Source Distribution

pyneurons-0.3.0.tar.gz (10.9 kB)

Uploaded Source

Built Distribution


pyneurons-0.3.0-py3-none-any.whl (16.3 kB)

Uploaded Python 3

File details

Details for the file pyneurons-0.3.0.tar.gz.

File metadata

  • Download URL: pyneurons-0.3.0.tar.gz
  • Upload date:
  • Size: 10.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.9.13 Windows/10

File hashes

Hashes for pyneurons-0.3.0.tar.gz
Algorithm Hash digest
SHA256 1a2a81d3e7b41c1720bfe06277b14394ba7305f1e98be01e71f6f30151673dbc
MD5 471a5dcd6fb55e3f7945a4f57698e5e5
BLAKE2b-256 ae46fab391d9b09694fd57a2e7d66c96dc202ca8364cacaea24bd548616dfaee


File details

Details for the file pyneurons-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: pyneurons-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 16.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.9.13 Windows/10

File hashes

Hashes for pyneurons-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a58bb071228563171f3812ba3c92de8d3fbd567c74bc42cfa5c87df389a9e590
MD5 b4ae4d4d0702dbd398a8e2cbb6eec958
BLAKE2b-256 e34a953e46ec5d45059d6fe7b83b23ec9207cae9a0b7cbac2289eeb5e6ba3b3f

