A simple deep learning library for JAX.
SNAX
Hungry for a dead-simple functional deep learning library?
You came to the right place.
Creating a multi-layer perceptron
import jax
import jax.numpy as jnp
import snax
hidden_sizes = [10, 20, 30]
input_size = 3
key = jax.random.PRNGKey(0)
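# Build an MLP with the given layer sizes, using ReLU as the activation
mlp = snax.nn.MLP(key,
                  input_size,
                  hidden_sizes,
                  act_fn=jax.nn.relu)
# Apply it to a single input vector of length input_size
out = mlp(jnp.ones([input_size]))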
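The snippet above hides the layer bookkeeping inside snax.nn.MLP. As a rough mental model only, here is a plain-JAX sketch of a functional MLP with the same sizes. It does not use snax, and the helpers init_mlp and mlp_apply are hypothetical. It assumes that the last entry of hidden_sizes is the output width and that no activation follows the final layer; snax's actual behavior may differ.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_size, hidden_sizes):
    """Hypothetical helper (not snax's API): one (W, b) pair per layer."""
    sizes = [input_size] + list(hidden_sizes)
    keys = jax.random.split(key, len(hidden_sizes))
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style normal initialization, scaled by the fan-in of the layer.
        w = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros((n_out,))
        params.append((w, b))
    return params


def mlp_apply(params, x, act_fn=jax.nn.relu):
    """Hypothetical helper: apply act_fn after every layer except the last."""
    for w, b in params[:-1]:
        x = act_fn(w @ x + b)
    w, b = params[-1]
    return w @ x + b


key = jax.random.PRNGKey(0)
params = init_mlp(key, 3, [10, 20, 30])
x = jnp.ones([3])
out = mlp_apply(params, x)
# Parameters are a plain pytree (a list of tuples), so the pure function
# can be jit-compiled without any extra machinery.
out_jit = jax.jit(mlp_apply)(params, x)
print(out.shape)  # (30,)
```

Keeping parameters in an ordinary pytree and the forward pass in a pure function is what lets jax.jit and jax.grad work on it directly, which is the style the "functional" description above points to.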
Creating a deep LSTM
import jax
import jax.numpy as jnp
import snax
input_size = 3
num_steps = 40
hidden_layer_sizes = [32, 64, 32]
key = jax.random.PRNGKey(0)
lstm = snax.recurrent.LSTM(key,
                           input_size,
                           hidden_layer_sizes,
                           act_fn=jnp.tanh,
                           forget_gate_bias_init=1.)
# Run the LSTM on a sequence of num_steps inputs, starting from its initial state
inputs = jnp.zeros((num_steps, input_size))
new_state, outs = lstm(inputs, lstm.initial_state())
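To make the deep LSTM example more concrete, here is a minimal plain-JAX sketch of a stacked LSTM run over a sequence with jax.lax.scan. It does not use snax and is not its implementation: init_deep_lstm, initial_state and run_deep_lstm are hypothetical names, the gate ordering (input, forget, cell, output) is an assumption, and forget_gate_bias_init is interpreted here as the initial value of the forget-gate bias, a common trick that makes the cell start out retaining its memory.

```python
import jax
import jax.numpy as jnp


def init_lstm_layer(key, in_size, hidden_size, forget_gate_bias_init=1.0):
    # One matrix maps the concatenated [x, h] to all four gates at once.
    fan_in = in_size + hidden_size
    w = jax.random.normal(key, (4 * hidden_size, fan_in)) / jnp.sqrt(fan_in)
    b = jnp.zeros((4 * hidden_size,))
    # Assumed gate order (i, f, g, o): set the forget-gate slice of the bias.
    b = b.at[hidden_size:2 * hidden_size].set(forget_gate_bias_init)
    return w, b


def lstm_layer_step(layer_params, layer_state, x, act_fn=jnp.tanh):
    w, b = layer_params
    h, c = layer_state
    z = w @ jnp.concatenate([x, h]) + b
    i, f, g, o = jnp.split(z, 4)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * act_fn(g)
    h = jax.nn.sigmoid(o) * act_fn(c)
    return h, c


def init_deep_lstm(key, input_size, hidden_layer_sizes, forget_gate_bias_init=1.0):
    sizes = [input_size] + list(hidden_layer_sizes)
    keys = jax.random.split(key, len(hidden_layer_sizes))
    return [init_lstm_layer(k, n_in, n_out, forget_gate_bias_init)
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]


def initial_state(hidden_layer_sizes):
    # One (h, c) pair of zeros per layer.
    return [(jnp.zeros((n,)), jnp.zeros((n,))) for n in hidden_layer_sizes]


def run_deep_lstm(params, inputs, state):
    def step(state, x):
        new_state = []
        for layer_params, layer_state in zip(params, state):
            h, c = lstm_layer_step(layer_params, layer_state, x)
            new_state.append((h, c))
            x = h  # this layer's output is the next layer's input
        return new_state, x
    # scan threads the state through time and stacks the per-step outputs.
    return jax.lax.scan(step, state, inputs)


hidden_layer_sizes = [32, 64, 32]
lstm_params = init_deep_lstm(jax.random.PRNGKey(0), 3, hidden_layer_sizes)
inputs = jnp.zeros((40, 3))
new_state, outs = run_deep_lstm(lstm_params, inputs, initial_state(hidden_layer_sizes))
print(outs.shape)  # (40, 32)
```

Using jax.lax.scan rather than a Python loop over time keeps the trace size independent of num_steps, and returning (final state, stacked outputs) mirrors the new_state, outs pair in the example above.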
Download files
Source distribution
snaxlib-0.0.34.tar.gz (28.6 kB)
Built distribution
snaxlib-0.0.34-py3-none-any.whl (34.8 kB)