A Flax trainer

Project description

XTRAIN: a tiny library for training Flax models.

Design goals:

  • Help avoid boilerplate code
  • Minimal functionality and dependencies
  • Agnostic to hardware configuration (e.g., the same code runs on GPU or TPU)

General workflow

Step 1: define your model

import flax.linen as nn

class MyFlaxModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    ...

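For concreteness, a minimal module might look like the sketch below; the MLP architecture, layer sizes, and names are purely illustrative and not part of xtrain.

import flax.linen as nn

class MLP(nn.Module):
  hidden: int = 64

  @nn.compact
  def __call__(self, x):
    # a small two-layer MLP with a scalar output
    x = nn.relu(nn.Dense(self.hidden)(x))
    return nn.Dense(1)(x)
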
Step 2: define the loss function

def my_loss_func(batch, prediction):
    x, y_true = batch
    loss = ...
    return loss

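As a sketch, a mean-squared-error loss, assuming each batch is an (inputs, labels) pair as produced by the iterator in step 3 (the name mse_loss is illustrative):

import jax.numpy as jnp

def mse_loss(batch, prediction):
    # batch is the (inputs, labels) pair yielded by the data iterator;
    # prediction is the model output for those inputs
    _, y_true = batch
    return jnp.mean((prediction - y_true) ** 2)
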
Step 3: create an iterator that supplies training data

my_data = zip(sequence_of_inputs, sequence_of_labels)

Step 4: train

# create and initialize a Trainer object
trainer = xtrain.Trainer(
  model = MyFlaxModule(),
  losses = my_loss_func,
  optimizer = optax.adam(1e-4),
)

train_iter = trainer.train(my_data) # returns an iterable object

# iterating over train_iter trains the model
for epoch in range(3):
  for model_out in train_iter:
    pass
  print(train_iter.loss_logs)
  train_iter.reset_loss_logs()

Training data format

  • TensorFlow Dataset
  • PyTorch DataLoader
  • generator function
  • any other Python iterable that produces NumPy data (see the sketch below)
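
For example, a plain generator function that yields NumPy batches can serve as the data source; the array shapes and batch count below are illustrative:

import numpy as np

def my_generator():
    # yield (inputs, labels) pairs as NumPy arrays
    for _ in range(100):
        x = np.random.normal(size=(32, 8)).astype("float32")
        y = np.random.normal(size=(32, 1)).astype("float32")
        yield x, y

# per the list above, the generator function itself is passed as the data source
train_iter = trainer.train(my_generator)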

Checkpointing

train_iter is Orbax-compatible and can be saved as a checkpoint.

import orbax.checkpoint as ocp

# cp_path: directory in which to write the checkpoint
ocp.StandardCheckpointer().save(cp_path, args=ocp.args.StandardSave(train_iter))
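
Restoring follows the same Orbax pattern; a sketch, assuming the saved train_iter serves as the restore target just as it served as the save target above:

# restore the training state saved above
restored = ocp.StandardCheckpointer().restore(
    cp_path, args=ocp.args.StandardRestore(train_iter)
)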

Freeze submodule

train_iter.freeze("submodule/Dense_0/kernel")
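
The argument is a '/'-separated path into the model's parameter tree, as in the example above; frozen parameters are excluded from gradient updates. A minimal usage sketch combining the calls shown earlier:

# freeze one kernel, then keep iterating as usual
train_iter.freeze("submodule/Dense_0/kernel")
for model_out in train_iter:
  pass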

Simple batch parallelism on multiple devices

# Add a new batch dim to your dataset
ds = ds.batch(8)
# create the trainer with the Distributed strategy
train_iter = xtrain.Trainer(model, losses, optimizer, strategy=xtrain.Distributed).train(ds)
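
A common choice for the added batch dimension is the number of local accelerator devices; a sketch, assuming the Distributed strategy splits the leading dimension across those devices:

import jax

# choose the leading batch dimension to match the number of local devices
n_dev = jax.local_device_count()
ds = ds.batch(n_dev)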

API documentation

https://jiyuuchc.github.io/xtrain/

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

xtrain-0.4.0.tar.gz (12.8 kB)

Uploaded: Source

Built Distribution

xtrain-0.4.0-py3-none-any.whl (14.4 kB)

Uploaded: Python 3

File details

Details for the file xtrain-0.4.0.tar.gz.

File metadata

  • Download URL: xtrain-0.4.0.tar.gz
  • Upload date:
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.15 Linux/6.8.0-1014-azure

File hashes

Hashes for xtrain-0.4.0.tar.gz:

  • SHA256: 10fa431c47285e7916b410102c9e7ba79874b4e6a36099e726f13d35b29b72c8
  • MD5: ffbb7c34d020cc1f2b391b3cb6e1312e
  • BLAKE2b-256: 20f6f811f71e311d0a73f357d4b080f8c5b5261e294f50933dddfb7a96f542bf


File details

Details for the file xtrain-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: xtrain-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 14.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.15 Linux/6.8.0-1014-azure

File hashes

Hashes for xtrain-0.4.0-py3-none-any.whl:

  • SHA256: 6bbf4885b27e3445a943870353e3f088510328844639a0980f20abfbc2558e13
  • MD5: e96b4f4763b9203d45a051e9948d6df1
  • BLAKE2b-256: 40547af2e4a2ab4618c309a47e6dc8d30adf6d3fdb17ff373ad3d89ec2a70469

