A simplistic trainer for Flax
Project description
Flax-Pilot
Flax-Pilot aims to simplify writing training loops for Google's Flax framework. I am new to Flax and started this project to deepen my understanding of it, so this module is a beginner's exploration of building efficient training workflows, and it will need more expertise to refine and expand. Future plans include training with multiple optimizers, a wider range of metric modules, callbacks, and support for more complex training loops. Flax-Pilot supports distributed training across multiple devices.
As of 27-7-2024, the trainer is available as a package for GPU and CPU.
How to Use?
🛠️ Write a flax.linen Module
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.Dropout(rate=0.5, deterministic=deterministic)(x)
        x = nn.Dense(features=10)(x)
        return x
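To see what each layer does to the input, it can help to trace the array shapes through the module. The following is a minimal sketch in plain Python and is not part of Flax-Pilot. The helper names are hypothetical, and it assumes the Flax defaults of 'SAME' padding for nn.Conv and 'VALID' padding for nn.avg_pool, with a (1, 28, 28, 1) input.

```python
def conv_same(shape, features):
    """Shape after a stride-1 convolution with 'SAME' padding (assumed nn.Conv default)."""
    batch, height, width, _ = shape
    return (batch, height, width, features)


def avg_pool_valid(shape, window, strides):
    """Shape after pooling with 'VALID' padding (assumed nn.avg_pool default)."""
    batch, height, width, channels = shape
    out_h = (height - window[0]) // strides[0] + 1
    out_w = (width - window[1]) // strides[1] + 1
    return (batch, out_h, out_w, channels)


def flatten(shape):
    """Shape after x.reshape((x.shape[0], -1))."""
    batch, *rest = shape
    size = 1
    for dim in rest:
        size *= dim
    return (batch, size)


def dense(shape, features):
    """Shape after a Dense layer: only the last axis changes."""
    return shape[:-1] + (features,)


def trace_cnn_shapes(input_shape=(1, 28, 28, 1)):
    """Return the shape after each layer of the CNN above (Dropout keeps the shape)."""
    shapes = [input_shape]
    shapes.append(conv_same(shapes[-1], features=32))
    shapes.append(avg_pool_valid(shapes[-1], window=(2, 2), strides=(2, 2)))
    shapes.append(flatten(shapes[-1]))
    shapes.append(dense(shapes[-1], features=256))
    shapes.append(dense(shapes[-1], features=10))
    return shapes


for shape in trace_cnn_shapes():
    print(shape)
```

Under these assumptions the flattened vector has 14 * 14 * 32 = 6272 entries, and the module returns one row of 10 logits per input image.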
🔧 Define Optimizer, Input Shapes, and Dict of Loss & Metric Trackers
Loss trackers (lt) take a scalar loss value and average it throughout training.
Metric trackers (mt) take y_true and y_pred, compute a metric score, and average it throughout training.
import optax as tx
from fpilot import Trackers as tr

opt = tx.adam(0.0001)
input_shape = {'x': (1, 28, 28, 1)}

# Create tracker instances.
loss_metric_tracker_dict = {
    'lt': {'loss': tr.Mean()},
    'mt': {'F1': tr.F1Score(threshold=0.6, num_classes=10, average='macro')}
}
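The two kinds of tracker can be pictured with a small NumPy sketch. This is not how fpilot implements tr.Mean or tr.F1Score; the RunningMean class and macro_f1 function are hypothetical, and the sketch assumes one-hot y_true and that a class counts as predicted when its score exceeds the threshold.

```python
import numpy as np


class RunningMean:
    """Hypothetical loss tracker: averages the scalar values it has seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def macro_f1(y_true, y_score, threshold):
    """Macro F1 for one-hot y_true and per-class scores y_score.

    Assumption: a class counts as predicted when its score exceeds `threshold`.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_score) > threshold
    tp = np.sum(y_true & y_pred, axis=0)
    fp = np.sum(~y_true & y_pred, axis=0)
    fn = np.sum(y_true & ~y_pred, axis=0)
    denom = 2 * tp + fp + fn
    # Classes with no true and no predicted samples get an F1 of 0 here.
    f1 = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)
    return float(f1.mean())


# A loss tracker receives one scalar per step.
loss_tracker = RunningMean()
for step_loss in [0.9, 0.7, 0.5]:
    loss_tracker.update(step_loss)

# A metric tracker receives (y_true, y_pred) and computes a score from them.
y_true = [[1, 0], [0, 1], [1, 0]]
y_score = [[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]]
batch_f1 = macro_f1(y_true, y_score, threshold=0.6)
```

Per-batch scores such as batch_f1 could themselves be fed into a RunningMean to obtain an average over training.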
🧮 Create loss_fn
A function that takes the parameters shown in the code below and returns the scalar loss, the mutable variables, and a dict of loss and metric values used for tracker updates.
The key names lt and mt must not be changed anywhere, because the training loops depend on them. The subkey names (here loss and F1) can be changed freely,
but they must match across loss_metric_tracker_dict and loss_metric_value_dict.
import optax as tx

# The first return value of this function is differentiated with respect to its first parameter.
def loss_fn(params, mut_variables, apply, sample, deterministic, det_key, step, objective):
    x, y = sample
    yp = apply(params, x, deterministic=deterministic, rngs={'dropout': det_key})
    # This model has no mutable variables, so return an empty dict.
    mut_variables = {}
    # optax expects (logits, labels); y is assumed to be one-hot encoded.
    loss = tx.softmax_cross_entropy(logits=yp, labels=y).mean()
    loss_metric_value_dict = {'lt': {'loss': loss}, 'mt': {'F1': (y, yp)}}
    return loss, mut_variables, loss_metric_value_dict
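Because the subkeys have to line up between the tracker dict and the value dict, a quick consistency check before training can catch a renamed key early. The helper below is a hypothetical sketch and is not provided by fpilot; it only compares key names.

```python
def check_tracker_keys(tracker_dict, value_dict):
    """Raise ValueError unless both dicts share the 'lt'/'mt' layout and subkeys."""
    for group in ("lt", "mt"):
        if group not in tracker_dict or group not in value_dict:
            raise ValueError(f"missing required key {group!r}")
        tracker_keys = set(tracker_dict[group])
        value_keys = set(value_dict[group])
        if tracker_keys != value_keys:
            raise ValueError(
                f"subkeys under {group!r} differ: "
                f"{sorted(tracker_keys)} vs {sorted(value_keys)}"
            )


# Placeholder objects stand in for tracker instances and loss/metric values.
trackers = {"lt": {"loss": object()}, "mt": {"F1": object()}}
values = {"lt": {"loss": 0.42}, "mt": {"F1": ([1, 0], [0.9, 0.1])}}
check_tracker_keys(trackers, values)  # matching keys: no exception

# Renaming a subkey on one side only is reported as a mismatch.
renamed = {"lt": {"cross_entropy": 0.42}, "mt": {"F1": ([1, 0], [0.9, 0.1])}}
try:
    check_tracker_keys(trackers, renamed)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```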
🏋️ Create Trainer Instance
from fpilot import Trainer
trainer = Trainer(CNN(), input_shape, opt, loss_fn, loss_metric_tracker_dict)
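The comment on loss_fn says that its first return value is differentiated with respect to its first parameter. One common way to do this in JAX is jax.value_and_grad with has_aux=True, sketched below on a toy linear model. Whether the Trainer does exactly this internally is an assumption; toy_loss_fn and with_aux are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params, mut_variables, sample):
    """Toy stand-in for loss_fn: linear model with a mean-squared-error loss."""
    x, y = sample
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    value_dict = {"lt": {"loss": loss}, "mt": {}}
    return loss, mut_variables, value_dict


def with_aux(fn):
    """Pack the extra return values into one auxiliary tuple, as has_aux expects."""
    def wrapped(params, *args):
        loss, mut_variables, value_dict = fn(params, *args)
        return loss, (mut_variables, value_dict)
    return wrapped


params = {"w": jnp.zeros((2,)), "b": jnp.array(0.0)}
sample = (jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([1.0, 2.0]))

# has_aux=True tells JAX to differentiate only the first return value
# with respect to the first argument (argnums=0).
grad_fn = jax.value_and_grad(with_aux(toy_loss_fn), argnums=0, has_aux=True)
(loss, (mut_variables, value_dict)), grads = grad_fn(params, {}, sample)
```

The gradients come back with the same structure as params, so they can be handed to an optax optimizer, while the mutable variables and the value dict pass through untouched.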
📈 Train the Model & Evaluate
train_ds = ...  # tf.data.Dataset as numpy iterator
val_ds = ...  # tf.data.Dataset as numpy iterator
epochs = ...  # number of epochs to train for
train_steps, val_steps = 10000, 1000  # steps per epoch
ckpt_path = "/saved/model/model_1"  # If set to None, no checkpoints will be saved during training.
trainer.train(epochs, train_ds, val_ds, train_steps, val_steps, ckpt_path)
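The steps-per-epoch arguments suggest that each epoch draws a fixed number of batches from the iterators. The sketch below shows that pattern in plain Python. It is not the Trainer's actual loop; run_epochs, step_fn and eval_fn are hypothetical, and it assumes the iterators keep yielding batches across epochs (for example, a repeating dataset).

```python
from itertools import count, islice


def run_epochs(epochs, train_iter, val_iter, train_steps, val_steps, step_fn, eval_fn):
    """Hypothetical loop: draw a fixed number of batches per epoch from each iterator."""
    history = []
    for epoch in range(epochs):
        # islice stops after the requested number of batches and leaves
        # the iterator positioned for the next epoch.
        train_losses = [step_fn(batch) for batch in islice(train_iter, train_steps)]
        val_losses = [eval_fn(batch) for batch in islice(val_iter, val_steps)]
        history.append({
            "epoch": epoch + 1,
            "train_loss": sum(train_losses) / len(train_losses),
            "val_loss": sum(val_losses) / len(val_losses),
        })
    return history


# Endless counters stand in for repeating dataset iterators.
train_iter = count(start=0)
val_iter = count(start=100)
history = run_epochs(
    epochs=2,
    train_iter=train_iter,
    val_iter=val_iter,
    train_steps=3,
    val_steps=2,
    step_fn=lambda batch: float(batch),  # pretend the "loss" is the batch value
    eval_fn=lambda batch: float(batch),
)
```

With this pattern the per-epoch averages start fresh each epoch, and the step counts, not the dataset length, decide where an epoch ends.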
What's next?
- Separate package for TPU.
- Callbacks.
- TensorBoard logging.
Demo
Demo notebooks will be available in the code section here.
Download files
File details
Details for the file flax_pilot-0.2.2.tar.gz.
File metadata
- Download URL: flax_pilot-0.2.2.tar.gz
- Upload date:
- Size: 11.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e6ebc416217b277c68fff2aac12695591ab15dac34b5ae0f4adc81be3882ca9b |
| MD5 | 8bd1d8578569d3bf4fe2ee167dd2e4a2 |
| BLAKE2b-256 | c28c255ca3533f3d8c54da51ac1a2c255bfe73705637d12b43a73a91319bf7de |
File details
Details for the file flax_pilot-0.2.2-py3-none-any.whl.
File metadata
- Download URL: flax_pilot-0.2.2-py3-none-any.whl
- Upload date:
- Size: 12.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.11.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a8508d828395b58de22987cbf9689c4db9f2f823ce066f9774368efef44be4bb |
| MD5 | 70262a73e3d20eb0d0a5df1cfcd22d02 |
| BLAKE2b-256 | bc4b008b2c6b8a83b08963fe80c8d408a8d934dc641c9a45c0dadf22689409f0 |