JAX/Flax implementation of finite-size scaling
jaxfss
Installation
jaxfss can be installed with pip directly from GitHub, using the following command:
pip install git+https://github.com/yonesuke/jaxfss.git
Quickstart
jaxfss is a tool for finite-size scaling.
For a worked case, see the finite-size scaling of the Binder ratio of the 2d Ising model in examples.
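As background, a minimal sketch of the rescaling behind a finite-size scaling analysis, written in plain Python so it runs without dependencies. It assumes the commonly used ansatz A(T, L) = L^c2 · F((T − Tc) · L^(1/ν)); the function `rescale`, the exponent name `c2`, and the sample points are illustrative assumptions and not part of the jaxfss API. The critical temperature and ν are the exact values for the 2d Ising model on the square lattice.

```python
import math


def rescale(L, T, A, Tc, nu, c2=0.0):
    """Map one data point (L, T, A) to scaling coordinates (x, y).

    Assumed ansatz: A(T, L) = L**c2 * F((T - Tc) * L**(1 / nu)).
    For a dimensionless quantity such as the Binder ratio, c2 = 0.
    """
    x = (T - Tc) * L ** (1.0 / nu)
    y = A / L ** c2
    return x, y


# Exact critical temperature and exponent of the 2d Ising model (square lattice).
TC_ISING = 2.0 / math.log(1.0 + math.sqrt(2.0))
NU_ISING = 1.0

# Placeholder points (A=1.0 is not real data): the temperatures are chosen so
# that both system sizes land on the same scaling variable x.
x_small, _ = rescale(L=16, T=TC_ISING + 0.10, A=1.0, Tc=TC_ISING, nu=NU_ISING)
x_large, _ = rescale(L=32, T=TC_ISING + 0.05, A=1.0, Tc=TC_ISING, nu=NU_ISING)
```

With the correct Tc and ν, data from different system sizes collapse onto a single curve y = F(x); the scaling function F is what the neural network is meant to represent.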
Components
Here we highlight the main building blocks for performing finite-size scaling.
MLP and RationalMLP
In jaxfss, we approximate the scaling function by a neural network. We provide a simple multilayer perceptron (MLP) module for constructing neural networks, which is built upon Flax.
import jaxfss
import jax
import jax.numpy as jnp
mlp = jaxfss.MLP(features=[20,20,1])
mlp_params = mlp.init(jax.random.PRNGKey(0), jnp.array([[1]]))
The default activation function of MLP is the sigmoid function. You can change this via the act argument.
We also provide an MLP with a rational activation function.
import jaxfss
import jax
import jax.numpy as jnp
mlp = jaxfss.RationalMLP(features=[20,20,1])
mlp_params = mlp.init(jax.random.PRNGKey(0), jnp.array([[1]]))
In RationalMLP, the activation function is approximated by a rational function, and the parameters of the activation function are also optimized during the learning process.
See arXiv:2004.01902 for the original paper.
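As a rough illustration of what a rational activation computes, the sketch below evaluates P(x)/Q(x) with Horner's scheme in plain Python. The function `rational`, the polynomial degrees, and the coefficients are hypothetical choices for illustration; they are not the initial values used by RationalMLP.

```python
def rational(x, p, q):
    """Evaluate P(x) / Q(x); coefficients are ordered from highest degree to lowest."""
    num = 0.0
    for c in p:
        num = num * x + c  # Horner's scheme for the numerator
    den = 0.0
    for c in q:
        den = den * x + c  # Horner's scheme for the denominator
    return num / den


# Hypothetical coefficients: P(x) = x^3 + x and Q(x) = x^2 + 1, so P/Q reduces to x.
p = [1.0, 0.0, 1.0, 0.0]
q = [1.0, 0.0, 1.0]
value = rational(2.0, p, q)
```

In a trainable version, `p` and `q` would be stored as network parameters and updated by the optimizer together with the weights.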
CriticalData
When training neural networks, it is important to normalize the training data beforehand.
CriticalData is a module that does this pre-processing for you.
import jaxfss
Ls = ... # system size arrays
Ts = ... # temperature arrays
As = ... # observable arrays
As_err = ... # observable error arrays
dataset = jaxfss.CriticalData(Ls, Ts, As, As_err)
train_data = dataset.training_data # dict with "system_size", "temperature", "observable", "observable_var"
You can also initialize the module from file:
import jaxfss
dataset = jaxfss.CriticalData.from_file(fname="filename.txt")
train_data = dataset.training_data
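For readers unfamiliar with pre-normalization, here is a generic sketch of one common choice, min-max scaling. It is not necessarily what CriticalData does internally; the helper `min_max_scale` and the sample temperatures are made up for illustration.

```python
def min_max_scale(values):
    """Linearly map a list of numbers onto the interval [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # All values identical: avoid dividing by zero.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


temperatures = [2.0, 2.25, 2.5]  # made-up values, not real data
scaled = min_max_scale(temperatures)
```

Keeping inputs of order one in this way helps the sigmoid-type activations stay away from their saturated regions at the start of training.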
MSELoss and NLLLoss
We provide helper functions for constructing loss functions.
MSELoss is the mean squared error, and NLLLoss is the negative log likelihood loss.
The function names are taken from PyTorch.
import jaxfss
def loss_fn(params):
    ...
    y_true = ...
    y_pred = ...
    return jaxfss.MSELoss(y_true, y_pred)
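The two losses can be written out as follows. This is a plain-Python sketch of the formulas, assuming a Gaussian likelihood for the negative log likelihood (analogous to PyTorch's GaussianNLLLoss); the function names `mse` and `gaussian_nll` and their argument order are hypothetical and may differ from the jaxfss signatures.

```python
import math


def mse(y_true, y_pred):
    """Mean squared error between two equally long sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def gaussian_nll(y_true, y_pred, var):
    """Mean Gaussian negative log likelihood.

    The constant 0.5 * log(2 * pi) is dropped because it does not affect
    the optimum.
    """
    terms = [
        0.5 * (math.log(v) + (t - p) ** 2 / v)
        for t, p, v in zip(y_true, y_pred, var)
    ]
    return sum(terms) / len(terms)


mse_value = mse([1.0, 2.0], [1.0, 4.0])
nll_value = gaussian_nll([0.0], [1.0], [1.0])
```

Unlike the mean squared error, the Gaussian form weights each residual by the inverse of its variance, so noisy data points contribute less to the fit.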
fit
Once everything is ready, all that remains is to train the model.
The fit function is a wrapper that runs the optimization with optax.
Note that init_params should be a dict with the keys mlp and fss.
import jaxfss
import optax
init_params = {
    "mlp": mlp_params,
    "fss": ...
}
def loss_fn(params):
    ...
    return ...
optimizer = optax.adam(learning_rate=10**-3)
steps = 10**4
params, losses, fsses = jaxfss.fit(loss_fn, optimizer, init_params, steps)
The fit function returns the final params together with the values of the loss and fss recorded during the learning process.
Further reading
Citation
A paper is currently in preparation.