
A simple trainer for Flax


Flax-Pilot

Flax-Pilot aims to simplify writing training loops for Google's Flax framework. I started the project as a newcomer to Flax to deepen my own understanding, so it is a beginner's exploration of building efficient training workflows and will need further expertise to refine and extend it. Planned features include training with multiple optimizers, more metric modules, callbacks, and support for more complex training loops. Flax-Pilot already supports distributed training across multiple devices.

As of 27-7-2024, the trainer is available as a package on PyPI.

How to Use?

🛠️ Write a flax.linen Module

import flax.linen as nn
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  
        x = nn.Dense(features=256)(x)
        x = nn.Dropout(rate=0.5, deterministic=deterministic)(x)
        x = nn.Dense(features=10)(x)
        return x

🔧 Define Optimizer, Input Shapes, and Dict of Loss & Metric Trackers

Loss trackers (lt) take a scalar loss value and average it throughout training.
Metric trackers (mt) take y_true and y_pred, compute the metric score, and average it throughout training.
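Conceptually, a loss tracker only needs to keep a running average of the scalars it receives. The class below is a hypothetical illustration of that idea; the name RunningMean and its update, result, and reset methods are assumptions made for this sketch and are not the fpilot BasicTrackers API.

```python
class RunningMean:
    """Hypothetical stand-in for a loss tracker: averages scalar values."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Accumulate one scalar, e.g. the loss of one batch.
        self.total += float(value)
        self.count += 1

    def result(self):
        # Average of all values seen so far (0.0 before any update).
        return self.total / self.count if self.count else 0.0

    def reset(self):
        # Clear the accumulated state, e.g. at the start of a new epoch.
        self.total = 0.0
        self.count = 0


tracker = RunningMean()
for batch_loss in [0.5, 0.25, 0.75]:
    tracker.update(batch_loss)
print(tracker.result())  # 0.5
```

A metric tracker follows the same pattern, except that update would receive (y_true, y_pred), compute the metric for the batch, and accumulate that score instead of a raw loss.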

import optax as tx

opt = tx.adam(0.0001)
input_shape = {'x': (1, 28, 28, 1)}

from fpilot import BasicTrackers as tr

# Create tracker instances.
loss_metric_tracker_dict = {
    'lt': {'loss': tr.Mean()},
    'mt': {'F1': tr.F1Score(threshold=0.6, num_classes=10, average='macro')}
}

🧮 Create loss_fn

loss_fn must accept exactly the parameters shown in the code below and return two values: the scalar loss and a dict of loss and metric values.

The top-level keys lt and mt must not be renamed anywhere, because the training loops depend on them. Subkey names such as loss and F1 can be changed freely, but they must match between loss_metric_tracker_dict and loss_metric_value_dict.
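Because a mismatch between the two dicts is easy to introduce when renaming a subkey, it can help to check them before training. The helper below is hypothetical (check_keys is not part of fpilot); it is a sketch that assumes the only allowed top-level keys are lt and mt, that both dicts use the same top-level keys, and that subkeys must match within each group.

```python
ALLOWED_GROUPS = {'lt', 'mt'}


def check_keys(tracker_dict, value_dict):
    """Hypothetical helper: raise ValueError if the two dicts do not line up."""
    if set(tracker_dict) != set(value_dict):
        raise ValueError(
            f"top-level keys differ: {sorted(tracker_dict)} vs {sorted(value_dict)}")
    unknown = set(tracker_dict) - ALLOWED_GROUPS
    if unknown:
        raise ValueError(f"unknown top-level keys: {sorted(unknown)}")
    for group in tracker_dict:
        t_keys, v_keys = set(tracker_dict[group]), set(value_dict[group])
        if t_keys != v_keys:
            raise ValueError(
                f"subkeys under '{group}' differ: {sorted(t_keys)} vs {sorted(v_keys)}")


trackers = {'lt': {'loss': object()}, 'mt': {'F1': object()}}
values_ok = {'lt': {'loss': 0.3}, 'mt': {'F1': ([1], [1])}}
values_bad = {'lt': {'total_loss': 0.3}, 'mt': {'F1': ([1], [1])}}

check_keys(trackers, values_ok)  # passes silently
try:
    check_keys(trackers, values_bad)
except ValueError as err:
    print(err)  # subkeys under 'lt' differ: ['loss'] vs ['total_loss']
```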

import optax as tx

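# Only the first return value (the scalar loss) is differentiated, with respect to the first parameter (params).
def loss_fn(params, apply, sample, deterministic, det_key, step):
    x, y = sample  # y: one-hot labels with the same shape as the logits
    yp = apply(params, x, deterministic=deterministic, rngs={'dropout': det_key})
    loss = tx.softmax_cross_entropy(logits=yp, labels=y).mean()  # optax expects (logits, labels)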
    loss_metric_value_dict = {'lt': {'loss': loss}, 'mt': {'F1': (y, yp)}}
    return loss, loss_metric_value_dict

🏋️ Create Trainer Instance

from fpilot import Trainer

trainer = Trainer(CNN(), input_shape, opt, loss_fn, loss_metric_tracker_dict)

📈 Train the Model & Evaluate

epochs = 10  # number of training epochs (example value)
train_ds = ...  # training data: a tf.data.Dataset converted with .as_numpy_iterator()
val_ds = ...  # validation data: a tf.data.Dataset converted with .as_numpy_iterator()
train_steps, val_steps = 10000, 1000  # steps per epoch
ckpt_path = "/saved/model/model_1"  # If set to None, no checkpoints are saved during training.

trainer.train(epochs, train_ds, val_ds, train_steps, val_steps, ckpt_path)
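For a quick smoke test without TensorFlow, a plain Python generator can stand in for the numpy iterator. The sketch below is hypothetical (make_numpy_batches is not part of fpilot): it yields (x, y) tuples that match the x, y = sample unpacking in loss_fn, with one-hot labels. It assumes the trainer pulls batches with next() for train_steps steps per epoch, so it repeats forever; with tf.data, a similar effect is usually obtained by calling .repeat() and .batch(...) before .as_numpy_iterator().

```python
import numpy as np


def make_numpy_batches(num_examples, batch_size, num_classes=10, seed=0):
    """Hypothetical stand-in for a repeating tf.data numpy iterator.

    Yields (x, y) forever: x has shape (batch_size, 28, 28, 1) and y is one-hot
    with shape (batch_size, num_classes).
    """
    rng = np.random.default_rng(seed)
    images = rng.random((num_examples, 28, 28, 1), dtype=np.float32)
    labels = rng.integers(0, num_classes, size=num_examples)
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    while True:
        order = rng.permutation(num_examples)  # reshuffle on every pass
        # Drop the last incomplete batch so every batch has the same shape.
        for start in range(0, num_examples - batch_size + 1, batch_size):
            idx = order[start:start + batch_size]
            yield images[idx], one_hot[idx]


train_ds = make_numpy_batches(num_examples=1000, batch_size=32)
x, y = next(train_ds)
print(x.shape, y.shape)  # (32, 28, 28, 1) (32, 10)
```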

Demo

See the 'examples' folder for training tutorials. The vae-gan-cfg-using-pretrained notebook shows how to use the trainer as an installed Python package, while the other notebooks use it from a git clone of the repository. For the simplest setup, start with vae-gan-cfg-using-pretrained.



Download files


Source Distribution

flax_pilot-0.1.6.tar.gz (10.9 kB)

Uploaded Source

Built Distribution

flax_pilot-0.1.6-py3-none-any.whl (12.6 kB)

Uploaded Python 3

File details

Details for the file flax_pilot-0.1.6.tar.gz.

File metadata

  • Download URL: flax_pilot-0.1.6.tar.gz
  • Upload date:
  • Size: 10.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.13

File hashes

Hashes for flax_pilot-0.1.6.tar.gz
Algorithm Hash digest
SHA256 57b3c586ba75af4ffe22b49a5748a2f4b1545dc99cef69772fe2af7fbe86e262
MD5 18ca9c2ef73be8dfa7081b72bad25168
BLAKE2b-256 09e6a2ba522be85cdf68aee336f1410eb3b5cc5e83dff816d5b93c40921564dd


File details

Details for the file flax_pilot-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: flax_pilot-0.1.6-py3-none-any.whl
  • Upload date:
  • Size: 12.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.13

File hashes

Hashes for flax_pilot-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 732a9dd0e8841fb6cf6f641e779e3dfeb6361f090626f2d7236f4876f1e4e25a
MD5 f6b053057bc2ee64c5fde4ddb60cf5b5
BLAKE2b-256 3d40f0998800f68bccc0b6bada2066fcec6c3354300b824aebd68ca924cb7f1e
