JAX/Flax implementation of rational neural nets
RationalNets
Original
- paper: Nicolas Boullé, Yuji Nakatsukasa, and Alex Townsend, Rational neural networks, arXiv preprint arXiv:2004.01902 (2020).
- github: https://github.com/NBoulle/RationalNets
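As a rough illustration of the idea in the paper cited above, the activation function is a rational function r(x) = P(x) / Q(x) whose coefficients are trained along with the network weights. The sketch below is not taken from this package: it assumes a degree-3 numerator over a degree-2 denominator, and the names `rational`, `loss`, `P_INIT`, and `Q_INIT` are hypothetical. The coefficient values are illustrative ones chosen so that r stays close to ReLU on [-1, 1].

```python
import jax
import jax.numpy as jnp

# Illustrative coefficients, highest degree first (an assumption of this sketch).
P_INIT = jnp.array([1.1915, 1.5957, 0.5, 0.0218])  # numerator, degree 3
Q_INIT = jnp.array([2.383, 0.0, 1.0])              # denominator, degree 2, no real roots


def rational(x, p, q):
    """Elementwise rational function r(x) = P(x) / Q(x)."""
    return jnp.polyval(p, x) / jnp.polyval(q, x)


def loss(params, x, y):
    """Mean squared error between r(x) and a target y."""
    p, q = params
    return jnp.mean((rational(x, p, q) - y) ** 2)


x = jnp.linspace(-1.0, 1.0, 5)
print(rational(x, P_INIT, Q_INIT))

# The coefficients are ordinary arrays, so they can be differentiated
# in the same way as any other trainable parameter.
grads = jax.grad(loss)((P_INIT, Q_INIT), x, jax.nn.relu(x))
print(grads)
```

Because the denominator used here has no real roots, r is defined for every real input; a trained denominator is not guaranteed to keep that property unless it is constrained.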
Installation
Install RationalNets with pip directly from GitHub:
pip install git+https://github.com/yonesuke/RationalNets.git
QuickStart
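import jax.numpy as jnp
from jax import random
from rationalnets import RationalMLP
# Construct the model; the list presumably gives the width of each layer.
model = RationalMLP([12, 8, 4])
# Dummy batch: 32 samples with 10 input features each.
batch = jnp.ones((32, 10))
# Initialize the variables from a PRNG key, then run a forward pass.
variables = model.init(random.PRNGKey(0), batch)
output = model.apply(variables, batch)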