Skip to main content

Fine-grained, dynamic control of neural network topology in JAX.

Project description

Connex

Connex is a JAX library built on Equinox that allows for fine-grained, dynamic control of neural network topology. With Connex, you can

  • Turn any directed acyclic graph (DAG) into a trainable neural network.
  • Add and remove both connections and neurons at the individual level.
  • Set and modify dropout probabilities for all neurons individually.
  • Easily toggle techniques such as normalization, adaptive activations, and self-attention.
  • Export a trained network to a NetworkX weighted digraph for network analysis.

Installation

pip install connex

Usage

As a tiny pedagogical example, let's create a trainable neural network from the following DAG

dag

with input neuron 0 and output neurons 3 and 11 (in that order) and ReLU activation for the hidden neurons:

import connex as cnx
import jax.nn as jnn


# Create the graph data
adjacency_dict = {
    0: [1, 2, 3],
    1: [4],
    2: [4, 5],
    4: [6],
    5: [7],
    6: [8, 9],
    7: [10],
    8: [11],
    9: [11],
    10: [11]
}

# Specify the input and output neurons
input_neurons = [0]
output_neurons = [3, 11]

# Create the network
network = cnx.NeuralNetwork(
    adjacency_dict,
    input_neurons, 
    output_neurons,
    jnn.relu
)

That's it! A connex.NeuralNetwork is a subclass of equinox.Module, so it can be trained as such:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# Initialize the optimizer
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(network, eqx.is_array))

# Define the loss function
@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    preds = jax.vmap(model)(x)
    return jnp.mean((preds - y) ** 2)

# Define a single training step
@eqx.filter_jit
def step(model, opt_state, x, y):
    loss, grads = loss_fn(model, x, y)
    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# Toy data
x = jnp.expand_dims(jnp.linspace(0, 2 * jnp.pi, 250), 1)
y = jnp.hstack((jnp.cos(x), jnp.sin(x)))

# Training loop
n_epochs = 500
for epoch in range(n_epochs):
    network, opt_state, loss = step(network, opt_state, x, y)
    print(f"Epoch: {epoch}   Loss: {loss}")

Now suppose we wish to add connections 1 → 6 and 2 → 11, remove neuron 9, and set the dropout probability of all hidden neurons to 0.1:

# Add connections
network = cnx.add_connections(network, [(1, 6), (2, 11)])

# Remove neuron
network = cnx.remove_neurons(network, [9])

# Set dropout probability
network = cnx.set_dropout_p(network, 0.1)

That's all there is to it. The new connections have been initialized with untrained parameters, and the neurons in the original network that have not been removed (along with their respective incoming and outgoing connections) have retained their trained parameters.

For more information about manipulating connectivity structure and the NeuralNetwork base class, please see the API section of the documentation. For examples of subclassing NeuralNetwork, please see connex.nn.

Citation

@software{gleyzer2023connex,
  author = {Leonard Gleyzer},
  title = {{C}onnex: Fine-grained Control over Neural Network Topology in {JAX}},
  url = {http://github.com/leonard-gleyzer/connex},
  year = {2023},
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

connex-0.3.3.tar.gz (23.4 kB view details)

Uploaded Source

Built Distribution

connex-0.3.3-py3-none-any.whl (31.0 kB view details)

Uploaded Python 3

File details

Details for the file connex-0.3.3.tar.gz.

File metadata

  • Download URL: connex-0.3.3.tar.gz
  • Upload date:
  • Size: 23.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.24.1

File hashes

Hashes for connex-0.3.3.tar.gz
Algorithm Hash digest
SHA256 eccd5b4c71297ab474c50a8fe95a5153aa531d6f12966cbdd90c023f7aa47a4a
MD5 3dedacd79afc9931410cb4759dd7b773
BLAKE2b-256 aff3ab7fa02f94f3614091b73a5cb089acf28524326a6fcb3673f7da7e4e6aea

See more details on using hashes here.

File details

Details for the file connex-0.3.3-py3-none-any.whl.

File metadata

  • Download URL: connex-0.3.3-py3-none-any.whl
  • Upload date:
  • Size: 31.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-httpx/0.24.1

File hashes

Hashes for connex-0.3.3-py3-none-any.whl
Algorithm Hash digest
SHA256 8eecb45c74c60041c186e302a2ab1063931e3ed6de16bceb5b6003605473be75
MD5 c3d0df5814e807406beb311785cff67a
BLAKE2b-256 152aebfad8728f109d9f014b795d59e5ccfe0130a758065db59bada88cbe0fec

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page