A Flax trainer
XTRAIN: a tiny library for training Flax models.
Design goals:
- Help avoid boilerplate code
- Minimal functionality and dependencies
- Agnostic to hardware configuration (e.g. moving from GPU to TPU)
General workflow
Step 1: define your model
import flax.linen as nn

class MyFlaxModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        ...  # define layers inline and return the model output
Step 2: define a loss function
def my_loss_func(batch, prediction):
    x, y_true = batch  # a batch is an (inputs, labels) pair
    loss = ...         # compute a scalar loss from y_true and prediction
    return loss
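A filled-in version of Step 2, assuming (as the signature suggests) that `prediction` is the model output for `batch`: mean squared error written with `jax.numpy`, so that it can be traced by `jax.jit` and differentiated by `jax.grad`. Gradient-based training needs the function to return a scalar; MSE is just one choice of loss.

```python
import jax.numpy as jnp


def my_loss_func(batch, prediction):
    x, y_true = batch  # x is not needed by this particular loss
    # Mean squared error between model output and targets (a scalar)
    return jnp.mean((prediction - y_true) ** 2)


# Usage: squared errors are 0 and 4, so the mean is 2.0
example_loss = my_loss_func((None, jnp.array([1.0, 2.0])), jnp.array([1.0, 4.0]))
```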
Step 3: create an iterable that supplies training data
# a list, unlike a bare zip object, can be iterated more than once (e.g. once per epoch)
my_data = list(zip(sequence_of_inputs, sequence_of_labels))
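When the data lives in NumPy arrays, a generator function is a simple way to produce mini-batches. This is a sketch: `batch_iterator` is a hypothetical helper, not part of xtrain. It drops a ragged final batch because `jax.jit` compiles separately for each new input shape, so a smaller last batch would trigger an extra compilation. Note that a generator object, like a `zip` object, can only be consumed once; call the function again to start a new pass.

```python
import numpy as np


def batch_iterator(inputs, labels, batch_size):
    """Yield (x, y) mini-batches of NumPy arrays; a ragged final batch is dropped."""
    n_full = (len(inputs) // batch_size) * batch_size
    for start in range(0, n_full, batch_size):
        stop = start + batch_size
        yield inputs[start:stop], labels[start:stop]


inputs = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 samples, 2 features
labels = np.arange(10, dtype=np.float32)
my_data = batch_iterator(inputs, labels, batch_size=4)    # 2 batches; last 2 samples dropped
```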
Step 4: train
import optax
import xtrain

# create and initialize a Trainer object
trainer = xtrain.Trainer(
    model=MyFlaxModule(),
    losses=my_loss_func,
    optimizer=optax.adam(1e-4),
)
train_iter = trainer.train(my_data)  # returns an iterable object

# iterating over train_iter trains the model
for epoch in range(3):
    for model_out in train_iter:
        pass
    print(train_iter.loss_logs)
    train_iter.reset_loss_logs()
Supported training data formats
- TensorFlow Dataset
- PyTorch DataLoader
- generator function
- any other Python iterable that produces NumPy data
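As an example of the last category, a class that defines `__iter__` can be iterated any number of times, and each pass can reshuffle the data. This is a sketch under two assumptions: that xtrain accepts any iterable of `(x, y)` NumPy tuples, and that it re-iterates its data source on each epoch (the README does not say). `ShuffledBatches` is a hypothetical name.

```python
import numpy as np


class ShuffledBatches:
    """Re-iterable source of (x, y) NumPy batches, reshuffled on every pass."""

    def __init__(self, inputs, labels, batch_size, seed=0):
        self.inputs = np.asarray(inputs)
        self.labels = np.asarray(labels)
        if len(self.inputs) != len(self.labels):
            raise ValueError("inputs and labels must have the same length")
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        order = self.rng.permutation(len(self.inputs))  # new order each pass
        n_full = len(order) // self.batch_size * self.batch_size
        for start in range(0, n_full, self.batch_size):
            idx = order[start:start + self.batch_size]
            yield self.inputs[idx], self.labels[idx]


ds = ShuffledBatches(np.arange(10), np.arange(10) * 2, batch_size=3)
epoch_1 = list(ds)  # 3 batches of 3; one sample left over per pass
epoch_2 = list(ds)  # a second full pass, in a different order
```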
Checkpointing
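train_iter is Orbax-compatible, so it can be saved directly with an Orbax checkpointer:
import orbax.checkpoint as ocp
# cp_path: the checkpoint directory (Orbax expects an absolute path)
ocp.StandardCheckpointer().save(cp_path, args=ocp.args.StandardSave(train_iter))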
Freezing a submodule
To stop training part of the model, pass the path of the parameters to freeze:
train_iter.freeze("submodule/Dense_0/kernel")
Simple batch parallelism on multiple devices
# add a new (leading) batch dimension to your dataset
ds = ds.batch(8)
# create the trainer with the Distributed strategy
train_iter = xtrain.Trainer(model, losses, optimizer, strategy=xtrain.Distributed).train(ds)
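`ds.batch(8)` is the TensorFlow way to add the extra leading dimension. For a plain Python iterable of `(x, y)` NumPy batches, the same reshaping can be done by stacking every `k` consecutive batches. This is a sketch: `stack_batches` is hypothetical, and it assumes (as the README example suggests) that the Distributed strategy splits this new leading axis across devices, in which case `k` would typically be the device count (e.g. `jax.local_device_count()`).

```python
import itertools

import numpy as np


def stack_batches(batches, k):
    """Group every k consecutive (x, y) batches into one with a new leading axis of size k."""
    it = iter(batches)
    while True:
        group = list(itertools.islice(it, k))
        if len(group) < k:
            return  # drop an incomplete final group so the leading axis is always k
        # zip(*group) turns [(x0, y0), (x1, y1), ...] into ((x0, x1, ...), (y0, y1, ...))
        yield tuple(np.stack(parts) for parts in zip(*group))


batches = [(np.full((2, 3), i), np.full((2,), i)) for i in range(10)]
stacked = list(stack_batches(batches, 4))  # 2 groups; the last 2 batches are dropped
```

Unlike `tf.data`'s `batch`, which by default keeps a smaller final batch, this version drops it, since a leading axis that does not match the device count cannot be split evenly.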
Download files
Source Distribution
xtrain-0.4.3.tar.gz (12.8 kB)
Built Distribution
xtrain-0.4.3-py3-none-any.whl (14.4 kB)
File details
Details for the file xtrain-0.4.3.tar.gz.
File metadata
- Download URL: xtrain-0.4.3.tar.gz
- Upload date:
- Size: 12.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.5 CPython/3.10.15 Linux/6.8.0-1017-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3793ae065a67dc85b54e557b121c2b323093b278bc120e150a0426440b654d31 |
| MD5 | a78dd354821186f65e3f7fa28ed4e557 |
| BLAKE2b-256 | 4e1c5e866f402e07463a2acbbecfa2c4f9e038ec2b0f39b83f2bd705d8918da2 |
File details
Details for the file xtrain-0.4.3-py3-none-any.whl.
File metadata
- Download URL: xtrain-0.4.3-py3-none-any.whl
- Upload date:
- Size: 14.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.5 CPython/3.10.15 Linux/6.8.0-1017-azure
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5dba9a035af9c7355246305b2ecf5b43602768ae4e8553b983cc5c22cef928a3 |
| MD5 | 88aac34fbbe72b12cafa9b7796246043 |
| BLAKE2b-256 | 486ae9e407f49be9b92d65429c81042a8eb5b3333ca4c5482c797ec37581ad24 |