A Flax trainer

XTRAIN: a tiny library for training Flax models.

Design goals:

  • Help avoid boilerplate code
  • Minimal functionality and dependencies
  • Agnostic to hardware configuration (e.g., moving from GPU to TPU)

General workflow

Step 1: define your model

import flax.linen as nn

class MyFlaxModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    ...  # define the layers and return the model output

Step 2: define a loss function

def my_loss_func(batch, prediction):
    x, y_true = batch
    loss = ...  # compute a scalar loss from y_true and prediction
    return loss

Step 3: create an iterator that supplies training data

my_data = zip(sequence_of_inputs, sequence_of_labels)
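Step 3 accepts an iterable of batches. As one hedged alternative to `zip`, a generator function can slice in-memory NumPy arrays into `(x, y)` batches. The helper name `batch_generator` is hypothetical. It drops the final partial batch because JAX retraces and recompiles jitted functions whenever input shapes change, so fixed batch shapes are usually preferable.

```python
import numpy as np


def batch_generator(inputs, labels, batch_size):
    """Yield (x, y) NumPy batches of a fixed size; a trailing partial batch is dropped."""
    n_full = (len(inputs) // batch_size) * batch_size
    for start in range(0, n_full, batch_size):
        yield inputs[start:start + batch_size], labels[start:start + batch_size]


rng = np.random.default_rng(0)
inputs = rng.normal(size=(100, 8)).astype(np.float32)
labels = np.arange(100) % 10

batches = list(batch_generator(inputs, labels, batch_size=32))
print(len(batches), batches[0][0].shape)  # 3 (32, 8)
```

Keep in mind that a generator object, like the iterator returned by `zip`, is exhausted after a single pass.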

Step 4: train

import optax
import xtrain

# create and initialize a Trainer object
trainer = xtrain.Trainer(
  model=MyFlaxModule(),
  losses=my_loss_func,
  optimizer=optax.adam(1e-4),
)

train_iter = trainer.train(my_data)  # returns an iterable object

# iterating over train_iter trains the model
for epoch in range(3):
  for model_out in train_iter:
    pass
  print(train_iter.loss_logs)
  train_iter.reset_loss_logs()

Supported training data formats

  • TensorFlow Dataset
  • PyTorch DataLoader
  • Generator function
  • Other Python iterables that produce NumPy data
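The last option, a plain Python iterable, can also be a small class whose `__iter__` starts a fresh pass each time it is called. Unlike a generator object or a `zip` iterator, such an object can be iterated over repeatedly, for example once per epoch. The class `ArrayDataset` and its per-pass reshuffling are hypothetical illustrations, not part of XTRAIN.

```python
import numpy as np


class ArrayDataset:
    """Hypothetical re-iterable dataset over in-memory NumPy arrays."""

    def __init__(self, inputs, labels, batch_size, seed=0):
        assert len(inputs) == len(labels)
        self.inputs, self.labels = inputs, labels
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        # Each call reshuffles the data and yields full batches only.
        order = self.rng.permutation(len(self.inputs))
        n_full = (len(order) // self.batch_size) * self.batch_size
        for start in range(0, n_full, self.batch_size):
            idx = order[start:start + self.batch_size]
            yield self.inputs[idx], self.labels[idx]


ds = ArrayDataset(np.zeros((50, 3), np.float32), np.arange(50), batch_size=16)
first_epoch = list(ds)
second_epoch = list(ds)  # a second pass works because __iter__ starts over
print(len(first_epoch), len(second_epoch))  # 3 3
```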

Checkpointing

train_iter is Orbax-compatible, so it can be saved with a standard Orbax checkpointer:

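import orbax.checkpoint as ocp
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(cp_path, args=ocp.args.StandardSave(train_iter))  # cp_path: absolute checkpoint directory
checkpointer.wait_until_finished()  # recent Orbax versions save asynchronously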

Freezing a submodule

train_iter.freeze("submodule/Dense_0/kernel")

Simple batch parallelism on multiple devices

# add a new leading batch dimension to your dataset
ds = ds.batch(8)
# create the trainer with the Distributed strategy
train_iter = xtrain.Trainer(model, losses, optimizer, strategy=xtrain.Distributed).train(ds)
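Batch parallelism of this kind splits each batch along a leading device axis and combines results across devices. The sketch below illustrates the general idea with `jax.pmap` and `jax.lax.pmean`; it is a swapped-in illustration, not a description of how `xtrain.Distributed` works internally, and the helpers `shard` and `mean_square` are hypothetical. It also runs on a single CPU device, where the device axis simply has size 1.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()


def shard(batch, n):
    """Reshape (global_batch, ...) into (n, global_batch // n, ...)."""
    assert batch.shape[0] % n == 0, "global batch must divide evenly across devices"
    return batch.reshape(n, batch.shape[0] // n, *batch.shape[1:])


@partial(jax.pmap, axis_name="devices")
def mean_square(x):
    # Each device reduces its own shard, then pmean averages across devices.
    local = jnp.mean(x ** 2)
    return jax.lax.pmean(local, axis_name="devices")


global_batch = np.arange(8 * n_devices * 3, dtype=np.float32).reshape(8 * n_devices, 3)
per_device = mean_square(shard(global_batch, n_devices))
print(per_device)  # one identical value per device
```

Because every shard has the same size, the average of the per-device means equals the mean over the whole global batch.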

API documentation

https://jiyuuchc.github.io/xtrain/


Download files

Source Distribution

xtrain-0.4.2.tar.gz (12.9 kB)

Built Distribution

xtrain-0.4.2-py3-none-any.whl (14.4 kB)

File details

Details for the file xtrain-0.4.2.tar.gz.

File metadata

  • Download URL: xtrain-0.4.2.tar.gz
  • Size: 12.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.15 Linux/6.5.0-1025-azure

File hashes

Hashes for xtrain-0.4.2.tar.gz
Algorithm Hash digest
SHA256 fdbdb8b76ca97f7064b3257b3790a9a01ebb2b39ed40f0eb2a096254b2bbf7a4
MD5 f5fe37d9717db65d4ba3234d94104a7a
BLAKE2b-256 6b46e7a7b14e9d432996696420e2bc40fb9987cfa2ba0635c6d9b4baecbb16db

File details

Details for the file xtrain-0.4.2-py3-none-any.whl.

File metadata

  • Download URL: xtrain-0.4.2-py3-none-any.whl
  • Size: 14.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.4 CPython/3.10.15 Linux/6.5.0-1025-azure

File hashes

Hashes for xtrain-0.4.2-py3-none-any.whl
Algorithm Hash digest
SHA256 2aa9b9083863eb15cdec6cb2989d8b00d1a4be21bbc0f6af5eb70d8e43060fca
MD5 17bf4757c341885de7db5029a9db0c0b
BLAKE2b-256 6ac91d83e71ff8dedde9e43eede83892098e4c595fbd7df1b6b1f5544ede630d
