JAX implementation for metalearning neuronal diversity
jaxDiversity
This is an updated implementation of the paper "Neural networks embrace diversity".
Authors
Anshul Choudhary, Anil Radhakrishnan, John F. Lindner, Sudeshna Sinha, and William L. Ditto
Link to paper
Key Results
- We construct neural networks with learnable activation functions and see that they quickly diversify from each other under training.
- These activations subsequently outperform their pure counterparts on classification tasks.
- The neuronal sub-networks instantiate the neurons and meta-learning adjusts their weights and biases to find efficient spanning sets of nonlinear activations.
- These improved neural networks provide quantitative examples of the emergence of diversity and insight into its advantages.
Install
pip install jaxDiversity
How to use
The codebase has 4 main components:
- dataloading: Contains tools for loading the datasets mentioned in the manuscript. We use PyTorch dataloaders with a custom numpy collate function so the data can be consumed in JAX.
- losses: Our loss implementations handle both traditional MLPs and Hamiltonian neural networks with minimal changes.
- mlp: Contains a custom MLP that takes multiple activations and applies them intralayer to create a diverse network (see the sketch after this list). Also contains the activation neural networks.
- loops: Contains the inner and outer loops for metalearning, which optimize the activation functions in tandem with the supervised learning task.
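For intuition, here is a minimal, self-contained sketch of the intralayer idea: a single hidden layer splits its units among several activation functions. This is illustrative only and not the library's actual MultiActMLP implementation; the names multi_act_layer, W, and b are hypothetical.

# Illustrative sketch (not the library implementation): apply multiple
# activation functions within one hidden layer by splitting its units
# among the candidate activations.
import jax
import jax.numpy as jnp

def multi_act_layer(x, W, b, afuncs):
    """Linear layer whose units are split evenly among several activations."""
    h = x @ W + b                              # pre-activations, shape (hidden,)
    splits = jnp.array_split(h, len(afuncs))   # one chunk of units per activation
    return jnp.concatenate([f(chunk) for f, chunk in zip(afuncs, splits)])

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (2, 18)) * 0.1
b = jnp.zeros(18)
afuncs = [lambda x: x**2, lambda x: x]         # same toy pair as in the example below
x = jnp.ones(2)
print(multi_act_layer(x, W, b, afuncs).shape)  # (18,)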
Minimal example
import jax
import optax
from jaxDiversity.utilclasses import InnerConfig, OuterConfig # simple utility classes for configuration consistency
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.mlp import mlp_afunc, MultiActMLP, init_linear_weight, xavier_normal_init, save
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn
from jaxDiversity.loops import inner_opt, outer_opt
Inner optimization (standard training loop) with the baseline activations
dev_inner_config = InnerConfig(test_train_split=0.8,
input_dim=2,
output_dim=2,
hidden_layer_sizes=[18],
batch_size=64,
epochs=2,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim, dev_inner_config.hidden_layer_sizes, model_key, bias=False)
baselineNN, opt_state, inner_results = inner_opt(model=model,
                                                 train_data=train_dataloader,
                                                 test_data=test_dataloader,
                                                 afuncs=afuncs,
                                                 opt=opt,
                                                 loss_fn=compute_loss_baseline,
                                                 config=dev_inner_config,
                                                 training=True,
                                                 verbose=True)
Metalearning with Hamiltonian Neural Networks
inner_config = InnerConfig(test_train_split=0.8,
input_dim=2,
output_dim=1,
hidden_layer_sizes=[32],
batch_size=64,
epochs=5,
lr=1e-3,
mu=0.9,
n_fns=2,
l2_reg=1e-1,
seed=42)
outer_config = OuterConfig(input_dim=1,
output_dim=1,
hidden_layer_sizes=[18],
batch_size=1,
steps=2,
print_every=1,
lr=1e-3,
mu=0.9,
seed=24)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)
opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)
HNN_acts, HNN_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_hnn, inner_config, outer_config, opt, meta_opt, save_path=None)
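The meta-learned activations returned by outer_opt can then be reused in a standard inner training run. The snippet below is a minimal sketch, assuming HNN_acts is a list of callables compatible with the afuncs argument of inner_opt (as in the baseline example above); the names hnn_model and hnn_trained are illustrative.

# Hedged follow-up sketch: reuse the meta-learned activations in a standard
# inner optimization run, mirroring the baseline example above.
hnn_model = MultiActMLP(inner_config.input_dim,
                        inner_config.output_dim,
                        inner_config.hidden_layer_sizes,
                        model_key,
                        bias=False)
hnn_trained, opt_state, results = inner_opt(model=hnn_model,
                                            train_data=train_dataloader,
                                            test_data=test_dataloader,
                                            afuncs=HNN_acts,
                                            opt=opt,
                                            loss_fn=compute_loss_hnn,
                                            config=inner_config,
                                            training=True,
                                            verbose=True)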
Link to the older PyTorch codebase with the classification problem: DiversityNN