
JAX implementation for metalearning neuronal diversity

Project description

jaxDiversity

This is an updated implementation of the paper "Neural networks embrace diversity".

Authors

Anshul Choudhary, Anil Radhakrishnan, John F. Lindner, Sudeshna Sinha, and William L. Ditto

Link to paper

Key Results

  • We construct neural networks with learnable activation functions and see that they quickly diversify from each other under training.
  • These activations subsequently outperform their pure counterparts on classification tasks.
  • The neuronal sub-networks instantiate the neurons and meta-learning adjusts their weights and biases to find efficient spanning sets of nonlinear activations.
  • These improved neural networks provide quantitative examples of the emergence of diversity and insight into its advantages.
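To make the sub-network idea above concrete, here is a minimal, illustrative sketch of a learnable activation: a tiny parameterized network applied elementwise, whose parameters can be meta-learned alongside the main network. This is not the package's actual implementation; the function names, shapes, and the tanh nonlinearity are assumptions for illustration only.

import jax
import jax.numpy as jnp

def init_act_params(key, hidden=4):
    # A tiny 1 -> hidden -> 1 sub-network standing in for one activation function.
    k1, k2 = jax.random.split(key)
    return {"w1": jax.random.normal(k1, (1, hidden)),
            "w2": jax.random.normal(k2, (hidden, 1))}

def learned_activation(params, x):
    # Apply the sub-network to every scalar pre-activation independently.
    h = jnp.tanh(x[..., None] @ params["w1"])   # shape (..., hidden)
    return (h @ params["w2"])[..., 0]           # back to the shape of x

params = init_act_params(jax.random.PRNGKey(0))
y = learned_activation(params, jnp.linspace(-1.0, 1.0, 5))

In the package itself, the activation networks live in jaxDiversity.mlp and their parameters are adjusted by the outer metalearning loop in jaxDiversity.loops, as shown in the examples below.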

Install

pip install jaxDiversity

How to use

The codebase has 4 main components:

  • dataloading: Contains tools for loading the datasets mentioned in the manuscript. We use PyTorch dataloaders with a custom numpy collate function so the data can be consumed in JAX (a sketch of this collate pattern follows the list).

  • losses: We handle both traditional MLPs and Hamiltonian neural networks with minimal changes via our loss implementations.

  • mlp: Contains a custom MLP that takes in multiple activations and uses them intra-layer to create a diverse network. Also contains the activation neural networks.

  • loops: Contains the inner and outer loops for metalearning to optimize the activation functions in tandem with the supervised learning task.
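As mentioned in the dataloading bullet, the collate trick can be sketched as follows. This is only an illustration of the general PyTorch-to-JAX pattern, not the package's own NumpyLoader, and the helper name numpy_collate is assumed:

import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # Recursively stack samples into numpy arrays instead of torch tensors,
    # so each batch can be fed straight into jax.numpy code.
    if isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    return np.stack([np.asarray(sample) for sample in batch])

# loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)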

Minimum example

import jax
import optax

from jaxDiversity.utilclasses import InnerConfig, OuterConfig # simple utility classes for configuration consistency
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.mlp import mlp_afunc, MultiActMLP, init_linear_weight, xavier_normal_init, save
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn
from jaxDiversity.loops import inner_opt, outer_opt

Inner optimization, or standard training loop, with the baseline activations

dev_inner_config = InnerConfig(test_train_split=0.8,
                            input_dim=2,
                            output_dim=2,
                            hidden_layer_sizes=[18],
                            batch_size=64,
                            epochs=2,
                            lr=1e-3,
                            mu=0.9,
                            n_fns=2,
                            l2_reg=1e-1,
                            seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, init_key = jax.random.split(key)
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)

opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim, dev_inner_config.hidden_layer_sizes, model_key, bias=False)
baselineNN, opt_state, inner_results = inner_opt(model=model,
                                                 train_data=train_dataloader,
                                                 test_data=test_dataloader,
                                                 afuncs=afuncs,
                                                 opt=opt,
                                                 loss_fn=compute_loss_baseline,
                                                 config=dev_inner_config,
                                                 training=True,
                                                 verbose=True)

Metalearning with Hamiltonian Neural Networks

inner_config = InnerConfig(test_train_split=0.8,
                            input_dim=2,
                            output_dim=1,
                            hidden_layer_sizes=[32],
                            batch_size=64,
                            epochs=5,
                            lr=1e-3,
                            mu=0.9,
                            n_fns=2,
                            l2_reg=1e-1,
                            seed=42)
outer_config = OuterConfig(input_dim=1,
                            output_dim=1,
                            hidden_layer_sizes=[18],
                            batch_size=1,
                            steps=2,
                            print_every=1,
                            lr=1e-3,
                            mu=0.9,
                            seed=24)
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)

opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

HNN_acts, HNN_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_hnn,
                                inner_config, outer_config, opt, meta_opt, save_path=None)
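HNN_acts holds the meta-learned activation functions and HNN_stats the corresponding training statistics. Assuming the returned activations can be passed wherever the hand-written afuncs above are accepted (a plausible but unverified assumption about the API), they could be reused in an ordinary inner-loop training run:

# Hedged sketch: plug the meta-learned activations back into a standard training loop.
hnn_model = MultiActMLP(inner_config.input_dim, inner_config.output_dim,
                        inner_config.hidden_layer_sizes, model_key, bias=False)
trained_hnn, opt_state, hnn_results = inner_opt(model=hnn_model,
                                                train_data=train_dataloader,
                                                test_data=test_dataloader,
                                                afuncs=HNN_acts,
                                                opt=opt,
                                                loss_fn=compute_loss_hnn,
                                                config=inner_config,
                                                training=True,
                                                verbose=True)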

Link to the older PyTorch codebase with the classification problem: DiversityNN

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

jaxDiversity-0.0.2.tar.gz (16.1 kB)

Uploaded Source

Built Distribution

jaxDiversity-0.0.2-py3-none-any.whl (15.8 kB)

Uploaded Python 3

File details

Details for the file jaxDiversity-0.0.2.tar.gz.

File metadata

  • Download URL: jaxDiversity-0.0.2.tar.gz
  • Upload date:
  • Size: 16.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for jaxDiversity-0.0.2.tar.gz
Algorithm Hash digest
SHA256 40f72d1d3b7a464ea09cd8247e05deacf1d32a32ac9ac9a344487841d2a31121
MD5 0bd76e2a5a3e88dcb5455a82d84e5a75
BLAKE2b-256 36caf42365a330da1c16118a5c659ac6cada04e84e902a4fb0640ef5c37d7d2e

See more details on using hashes here.

File details

Details for the file jaxDiversity-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: jaxDiversity-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 15.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.4

File hashes

Hashes for jaxDiversity-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 79dcbea6fcf78b3876613aa47497677d43fb8db3447b5f8cb0dc630f8405ea44
MD5 21e1ddfbb07dbfa69644468271e84fb3
BLAKE2b-256 4f990cf1fbeb14f042fc0ae9145ba1456da709c5338ffb0a5603f599f15f104a

See more details on using hashes here.
