A simplistic trainer for Flax

Flax-Pilot

Flax-Pilot aims to simplify writing training loops for Google's Flax framework. I started this project as a newcomer to Flax to deepen my understanding of it, so the module is a beginner's exploration of efficient training workflows and will need further expertise to refine and expand. Planned additions include training with multiple optimizers, more metric modules, callbacks, and more complex training loops. Flax-Pilot supports distributed training across multiple devices.

As of 27-7-2024, the trainer is available as a PyPI package for GPU and CPU.

How to Use?

🛠️ Write a flax.linen Module

import flax.linen as nn
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  
        x = nn.Dense(features=256)(x)
        x = nn.Dropout(rate=0.5, deterministic=deterministic)(x)
        x = nn.Dense(features=10)(x)
        return x
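To see why input_shape is given as (1, 28, 28, 1) in the next step, it can help to trace the tensor shapes through the module. The following is a minimal pure-Python sketch, not part of Flax-Pilot. It assumes the Flax defaults of 'SAME' padding with stride 1 for nn.Conv and 'VALID' padding for nn.avg_pool; the helper name trace_cnn_shapes is hypothetical.

```python
def trace_cnn_shapes(input_shape):
    """Trace (batch, height, width, channels) shapes through the CNN above.

    Assumes nn.Conv uses 'SAME' padding with stride 1 and nn.avg_pool uses
    'VALID' padding, which are the Flax defaults.
    """
    n, h, w, _ = input_shape
    shapes = {"input": tuple(input_shape)}

    # Conv with 'SAME' padding keeps height and width; channels become 32.
    shapes["conv"] = (n, h, w, 32)

    # 2x2 average pooling with stride 2 and 'VALID' padding.
    h, w = (h - 2) // 2 + 1, (w - 2) // 2 + 1
    shapes["pool"] = (n, h, w, 32)

    # x.reshape((x.shape[0], -1)) flattens everything except the batch axis.
    shapes["flatten"] = (n, h * w * 32)

    # Dropout does not change the shape, so only the Dense layers are listed.
    shapes["dense"] = (n, 256)
    shapes["logits"] = (n, 10)
    return shapes


for name, shape in trace_cnn_shapes((1, 28, 28, 1)).items():
    print(f"{name:8s} {shape}")
```

For a 28x28 single-channel input, the flattened feature vector has 14 * 14 * 32 = 6272 entries and the output holds 10 logits per example.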

🔧 Define Optimizer, Input Shapes, and Dict of Loss & Metric Trackers

Loss trackers (lt) take a scalar loss value and average it throughout training.
Metric trackers (mt) take y_true and y_pred, compute a metric score, and average it throughout training.

import optax as tx

opt = tx.adam(0.0001)
input_shape = {'x': (1, 28, 28, 1)}

from fpilot import Trackers as tr

# Create tracker instances.
loss_metric_tracker_dict = {
    'lt': {'loss': tr.Mean()},
    'mt': {'F1': tr.F1Score(threshold=0.6, num_classes=10, average='macro')}
}
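The averaging behaviour described for loss trackers can be pictured as a running mean. The class below is only an illustrative sketch of that idea in plain Python. It is not the implementation of tr.Mean, and the names RunningMean, update, compute and reset are assumptions made for this example.

```python
class RunningMean:
    """Hypothetical sketch of a tracker that averages scalar values."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Accumulate one scalar, e.g. the loss of a single training step.
        self.total += float(value)
        self.count += 1

    def compute(self):
        # Mean of all values seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset(self):
        # Typically called at the start of a new epoch.
        self.total = 0.0
        self.count = 0


tracker = RunningMean()
for step_loss in (1.0, 2.0, 3.0):
    tracker.update(step_loss)
print(tracker.compute())  # 2.0
```

A metric tracker follows the same pattern, except that each update first computes a score from y_true and y_pred before accumulating it.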

🧮 Create loss_fn

loss_fn must accept the parameters shown in the code below and return three values: the scalar loss, the mutable variables, and a dict of loss and metric values used to update the trackers.

The key names lt and mt must not be changed anywhere, because the training loops depend on them. The subkey names (loss and F1 here) can be chosen freely, but they must match between loss_metric_tracker_dict and loss_metric_value_dict.

import optax as tx

# This fn's 1st return value is differentiated wrt the fn's first param.
def loss_fn(params, mut_variables, apply, sample, deterministic, det_key, step, objective):
    x, y = sample
    yp = apply(params, x, deterministic=deterministic, rngs={'dropout': det_key})
    
    # No mutable vars in this model so,
    mut_variables = {}
    
    # optax expects logits first and labels second; y must be one-hot encoded.
    loss = tx.softmax_cross_entropy(logits=yp, labels=y).mean()
    loss_metric_value_dict = {'lt': {'loss': loss}, 'mt': {'F1': (y, yp)}}
    return loss, mut_variables, loss_metric_value_dict
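Because the subkeys have to agree between loss_metric_tracker_dict and loss_metric_value_dict, a quick consistency check before training can catch typos early. The helper below is a hypothetical sketch and not part of Flax-Pilot. It only compares key names, so plain placeholder values stand in for the real trackers and arrays.

```python
def check_tracker_keys(tracker_dict, value_dict):
    """Raise KeyError if the 'lt'/'mt' subkeys of the two dicts differ."""
    for group in ("lt", "mt"):
        tracker_keys = set(tracker_dict.get(group, {}))
        value_keys = set(value_dict.get(group, {}))
        if tracker_keys != value_keys:
            raise KeyError(
                f"Subkeys under '{group}' differ: "
                f"trackers={sorted(tracker_keys)}, values={sorted(value_keys)}"
            )


# Placeholder objects stand in for tracker instances and loss/metric values.
trackers = {"lt": {"loss": object()}, "mt": {"F1": object()}}
values = {"lt": {"loss": 0.42}, "mt": {"F1": ("y_true", "y_pred")}}

check_tracker_keys(trackers, values)  # passes silently when the keys match
```

Calling this once with the dict returned by a dry run of loss_fn is enough, since the key names do not change between steps.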

🏋️ Create Trainer Instance

from fpilot import Trainer

# `opt` is the optax optimizer defined earlier.
trainer = Trainer(CNN(), input_shape, opt, loss_fn, loss_metric_tracker_dict)

📈 Train the Model & Evaluate

train_ds = ... # tf.data.Dataset as numpy iterator
val_ds = ... # tf.data.Dataset as numpy iterator
epochs = ... # number of epochs to train for
train_steps, val_steps = 10000, 1000 # steps per epoch
ckpt_path = "/saved/model/model_1"  # If set to None, no checkpoints will be saved during training.

trainer.train(epochs, train_ds, val_ds, train_steps, val_steps, ckpt_path)
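Since train_steps and val_steps are steps per epoch, they are usually derived from the dataset size and the batch size. The helper below is a small sketch of that arithmetic. The function name, the example sizes, and the drop_remainder flag are assumptions for illustration; whether a partial final batch is dropped depends on how the datasets are batched.

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_remainder=True):
    """Number of batches needed to see the dataset once."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        # A partial final batch is discarded.
        return num_examples // batch_size
    # A partial final batch counts as one extra step.
    return math.ceil(num_examples / batch_size)


# Hypothetical dataset sizes, chosen only for illustration.
train_steps = steps_per_epoch(1000, 32)                        # 31
val_steps = steps_per_epoch(1000, 32, drop_remainder=False)    # 32
print(train_steps, val_steps)
```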

What's next?

  • Separate package for TPU.
  • Callbacks.
  • TensorBoard logging.

Demo

Demo notebooks will be available in the code section here.
