
JAX/Flax implementation of rational neural nets

Project description

rationalnets

JAX/Flax implementation of rational neural nets.

Original

Installation

rationalnets can be installed from PyPI with pip:

python -m pip install rationalnets

Alternatively, you can install the latest development version directly from GitHub:

python -m pip install git+https://github.com/yonesuke/RationalNets.git

QuickStart

Rational activation function

import jax.numpy as jnp
from jax import random
from rationalnets import Rational

xs = jnp.arange(-2.0, 2.0, 0.01)
act = Rational()
params = act.init(random.PRNGKey(0), xs)  # initialize the activation's parameters
ys = act.apply(params, xs)  # values of the rational activation function on [-2.0, 2.0)
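Rational neural networks replace a fixed activation such as ReLU with a rational function r(x) = P(x) / Q(x), a ratio of two polynomials whose coefficients are learned. The sketch below evaluates such a function in plain JAX, independent of rationalnets, to show what the activation computes. The degrees (cubic numerator, quadratic denominator), the helper name `rational`, and the coefficient values are all assumptions for illustration, not the package's API or defaults.

```python
import jax.numpy as jnp


def rational(x, p_coeffs, q_coeffs):
    """Hypothetical helper: evaluate r(x) = P(x) / Q(x) elementwise.

    Coefficients are ordered from highest to lowest degree,
    following the jnp.polyval convention.
    """
    return jnp.polyval(p_coeffs, x) / jnp.polyval(q_coeffs, x)


# Illustrative coefficients (assumed, not taken from rationalnets).
p = jnp.array([0.0218, 0.5, 1.5957, 1.1915])  # P(x) = 0.0218x^3 + 0.5x^2 + 1.5957x + 1.1915
q = jnp.array([1.0, 0.0, 2.383])              # Q(x) = x^2 + 2.383, which is never zero

xs = jnp.arange(-2.0, 2.0, 0.01)
ys = rational(xs, p, q)  # one activation value per input point
```

Choosing a denominator with no real roots keeps the function finite everywhere; in a trained network the coefficients change, so implementations typically have to guard against Q(x) approaching zero.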

Rational MLP

import jax.numpy as jnp
from jax import random
from rationalnets import RationalMLP

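model = RationalMLP([12, 8, 4])  # layer widths; the last entry is presumably the output size
batch = jnp.ones((32, 10))  # a batch of 32 inputs with 10 features each
variables = model.init(random.PRNGKey(0), batch)  # initialize the model's parameters
output = model.apply(variables, batch)  # forward pass over the batch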


Download files

Source distribution: rationalnets-0.1.0.tar.gz (2.9 kB)

Built distribution: rationalnets-0.1.0-py3-none-any.whl (3.4 kB, Python 3)
