
A library to train PyTorch models as a stream of events


FitStream

A tiny library that makes PyTorch experiments easy for small models and in-memory datasets.


Getting started

Using uv:

uv add fitstream

Using pip:

pip install fitstream

Training a model:

from torch.optim import Adam

from fitstream import epoch_stream, take  # epoch_stream is the main entry point

# get_data, get_model, and get_loss are placeholders for your own code
X, y = get_data()
model = get_model()
loss = get_loss()
optimizer = Adam(model.parameters())

# an infinite stream of training epochs (limit it with `take` or `early_stop`)
events = epoch_stream((X, y), model, optimizer, loss, batch_size=32, shuffle=True)
for event in take(10)(events):
    print(f"step={event['step']}, loss={event['train_loss']}")
# step=1, loss=...
# step=2, loss=...
# ...
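To make the "infinite stream of epochs" idea concrete, here is a minimal, dependency-free sketch. It is not FitStream's implementation: toy_epoch_stream is a hypothetical generator that fits a one-weight linear model with plain gradient descent instead of PyTorch, and it yields event dictionaries with the same step, model, and train_loss keys as the example above. Because it never returns, the caller bounds it with itertools.islice.

```python
import itertools
from typing import Iterator


def toy_epoch_stream(xs: list[float], ys: list[float], lr: float = 0.1) -> Iterator[dict]:
    """Hypothetical stand-in for epoch_stream: an endless generator of epoch events.

    The "model" is a single weight w fitted to y = w * x with full-batch
    gradient descent on the mean squared error.
    """
    w = 0.0
    n = len(xs)
    for step in itertools.count(1):  # count(1) never ends, so the stream is infinite
        # Gradient of mean((w*x - y)^2) with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad
        loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        yield {"step": step, "model": w, "train_loss": loss}


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # exactly y = 2x
for event in itertools.islice(toy_epoch_stream(xs, ys), 3):
    print(f"step={event['step']}, loss={event['train_loss']:.4f}")
```

Each pass through the loop produces one event lazily, so the consumer alone decides how many epochs run.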

Basics

The core idea of the library is "training loop as a stream of events". epoch_stream is just an iterable over dictionaries, one per epoch, each containing the epoch number (step), the model (model), and the training loss (train_loss). Everything else transforms or enriches these events. FitStream provides a small pipe(...) helper to compose these transformations left to right.
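Composing stages left to right amounts to a fold: each stage is a function that takes an iterable of events and returns a new one. The sketch below uses a hypothetical pipe_sketch and two toy stages purely to illustrate that idea; FitStream's own pipe may differ in its details.

```python
from functools import reduce
from typing import Callable, Iterable

Stage = Callable[[Iterable[dict]], Iterable[dict]]


def pipe_sketch(source: Iterable[dict], *stages: Stage) -> Iterable[dict]:
    """Apply stages left to right: pipe_sketch(s, f, g) == g(f(s))."""
    return reduce(lambda stream, stage: stage(stream), stages, source)


def double_loss(events: Iterable[dict]) -> Iterable[dict]:
    # Example stage: lazily rewrite one key of every event.
    for ev in events:
        yield {**ev, "train_loss": 2 * ev["train_loss"]}


def first_two(events: Iterable[dict]) -> Iterable[dict]:
    # Example stage: keep only the first two events.
    for i, ev in enumerate(events):
        if i == 2:
            return
        yield ev


events = [{"step": s, "train_loss": 1.0 / s} for s in range(1, 5)]
print(list(pipe_sketch(events, double_loss, first_two)))
# [{'step': 1, 'train_loss': 2.0}, {'step': 2, 'train_loss': 1.0}]
```

Reading pipe_sketch(events, f, g) top to bottom matches the order in which data flows, which is why a pipe reads more naturally than the nested g(f(events)).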

Augmentation

The augment function turns an "augmenter" (a function that looks at an event and returns extra keys) into a stream transform stage. We typically compose stages with pipe(...).

Here is an example that adds the norm of the model parameters to each event:

from torch import nn, linalg
from fitstream import epoch_stream, augment, pipe, take

def model_param_norm(ev: dict) -> dict:
    # flatten all parameters into one vector and compute its L2 norm
    model_params = nn.utils.parameters_to_vector(ev['model'].parameters())
    return {'model_param_norm': linalg.norm(model_params).item()}


events = pipe(
    epoch_stream(...),
    augment(model_param_norm),
    take(10),  # epoch_stream is infinite, so cap it
)
for event in events:
    print(f"step={event['step']}",
          f"model_param_norm={event['model_param_norm']}")
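One plausible shape for augment, as a dependency-free sketch: the hypothetical augment_sketch wraps the augmenter in a generator stage and merges the returned keys into a shallow copy of each event. Whether FitStream copies events or updates them in place is an assumption here.

```python
from typing import Callable, Iterable, Iterator


def augment_sketch(augmenter: Callable[[dict], dict]) -> Callable[[Iterable[dict]], Iterator[dict]]:
    """Turn an event -> extra-keys function into a stream stage."""
    def stage(events: Iterable[dict]) -> Iterator[dict]:
        for ev in events:
            # Merge the augmenter's keys over a shallow copy of the event.
            yield {**ev, **augmenter(ev)}
    return stage


def loss_squared(ev: dict) -> dict:
    # Toy augmenter: derive a new key from an existing one.
    return {"loss_squared": ev["train_loss"] ** 2}


stream = [{"step": 1, "train_loss": 3.0}, {"step": 2, "train_loss": 2.0}]
for ev in augment_sketch(loss_squared)(stream):
    print(ev)
# {'step': 1, 'train_loss': 3.0, 'loss_squared': 9.0}
# {'step': 2, 'train_loss': 2.0, 'loss_squared': 4.0}
```

Copying keeps the input events untouched, so the same source list can be replayed through different pipelines.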

We also have some built-in augmentation functions. Here is an example of adding validation loss to each event:

from torch import nn
from fitstream import epoch_stream, augment, pipe, take, validation_loss

validation_set = get_validation_set()  # placeholder for your own validation data
events = pipe(
    epoch_stream(...),
    augment(validation_loss(validation_set, nn.CrossEntropyLoss())),
    take(10),  # cap the infinite stream
)
for event in events:
    print(f"step={event['step']}, val_loss={event['val_loss']}")
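Built-in augmenters such as validation_loss are parameterized: they take configuration (a dataset and a loss) and return the event-to-extra-keys function that augment expects. The dependency-free sketch below shows that closure pattern with a hypothetical make_val_loss. Treating the event's model as a plain Python callable and the loss as a per-example function are assumptions for illustration, not how FitStream evaluates PyTorch models.

```python
from typing import Callable, Sequence


def make_val_loss(
    val_set: Sequence[tuple[float, float]],
    loss_fn: Callable[[float, float], float],
) -> Callable[[dict], dict]:
    """Return an augmenter that adds the mean validation loss under 'val_loss'."""
    def augmenter(ev: dict) -> dict:
        model = ev["model"]  # assumed: a callable mapping an input to a prediction
        losses = [loss_fn(model(x), y) for x, y in val_set]
        return {"val_loss": sum(losses) / len(losses)}
    return augmenter


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


val_set = [(1.0, 2.0), (2.0, 4.0)]
augmenter = make_val_loss(val_set, squared_error)
print(augmenter({"step": 1, "model": lambda x: 1.5 * x}))
# {'val_loss': 0.625}
```

The factory captures the dataset and loss once, so the returned augmenter only needs the event at each epoch.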

We can, of course, augment the stream more than once:

events = pipe(
    epoch_stream(...),
    augment(validation_loss(...)),
    augment(model_param_norm),
    take(10),
)
for event in events:
    print(f"step={event['step']}",
          f"val_loss={event['val_loss']}",
          f"model_param_norm={event['model_param_norm']}")

Selecting events

Since the event stream is a standard Python iterable, you can use any Python selection logic (for example, the itertools module). FitStream includes a small helper, take(...), to limit the number of epochs:

from fitstream import epoch_stream, take

for event in take(100)(epoch_stream(...)):
    print(event)
# {'step': 1, ...}
# {'step': 2, ...}
# ...
# {'step': 100, ...}
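Because take(n) returns a stage that is then applied to the stream, it also slots into pipe(...). A stage like that can be sketched with itertools.islice; take_sketch below is hypothetical. The same block illustrates the "any Python selection logic" point with the standard library's itertools.takewhile, which keeps epochs only while the training loss stays above a threshold.

```python
import itertools
from typing import Callable, Iterable, Iterator


def take_sketch(n: int) -> Callable[[Iterable[dict]], Iterator[dict]]:
    """Return a stage that yields at most n events, then stops."""
    def stage(events: Iterable[dict]) -> Iterator[dict]:
        return itertools.islice(events, n)
    return stage


def fake_events() -> Iterator[dict]:
    # An infinite stream of fake events whose loss halves every epoch.
    for step in itertools.count(1):
        yield {"step": step, "train_loss": 1.0 / 2 ** step}


print([ev["step"] for ev in take_sketch(3)(fake_events())])
# [1, 2, 3]

# Plain itertools works too: keep epochs while the loss is still >= 0.01.
still_learning = itertools.takewhile(lambda ev: ev["train_loss"] >= 0.01, fake_events())
print([ev["step"] for ev in still_learning])
# [1, 2, 3, 4, 5, 6]
```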

FitStream also has its own selection primitives, such as early stopping:

from fitstream import augment, early_stop, epoch_stream, pipe, take, validation_loss

events = pipe(
    epoch_stream(...),
    augment(validation_loss(...)),
    take(500),  # safety cap
    early_stop(key="val_loss", patience=10, mode="min", min_delta=1e-4),
)
for event in events:
    print(event)

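mode="min" is the default and treats lower values as better; use mode="max" for metrics such as accuracy. patience is how many epochs without improvement to tolerate before stopping, and min_delta lets you ignore tiny, noisy changes: an improvement smaller than min_delta does not count.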
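The rule above can be sketched as a stream stage. This is a minimal illustration under stated assumptions, not FitStream's code: the hypothetical early_stop_sketch counts an event as an improvement only if it beats the best value so far by more than min_delta, stops after patience consecutive non-improving events, and still yields the event that exhausts patience. The library's exact boundary behavior may differ.

```python
import math
from typing import Callable, Iterable, Iterator


def early_stop_sketch(
    key: str, patience: int, mode: str = "min", min_delta: float = 0.0
) -> Callable[[Iterable[dict]], Iterator[dict]]:
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    sign = 1.0 if mode == "min" else -1.0  # flip "max" into a minimization problem

    def stage(events: Iterable[dict]) -> Iterator[dict]:
        best = math.inf
        bad_epochs = 0
        for ev in events:
            value = sign * ev[key]
            if value < best - min_delta:
                best = value  # meaningful improvement: reset the counter
                bad_epochs = 0
            else:
                bad_epochs += 1
            yield ev
            if bad_epochs >= patience:
                return  # stop pulling from the upstream stream
    return stage


losses = [1.0, 0.5, 0.49995, 0.6, 0.7, 0.3]
events = [{"step": i + 1, "val_loss": v} for i, v in enumerate(losses)]
stopped = early_stop_sketch("val_loss", patience=3, min_delta=1e-4)(events)
print([ev["step"] for ev in stopped])
# [1, 2, 3, 4, 5]
```

Step 3 improves on step 2 by only 0.00005, below min_delta, so it counts as a non-improving epoch; the later 0.3 at step 6 is never reached.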

Side effects

Sometimes you want to log metrics (or write to an external system) without changing the stream. Use tap(fn, every=...) and the built-in print_keys(...) helper:

from fitstream import epoch_stream, pipe, print_keys, tap, take

events = pipe(
    epoch_stream(...),
    tap(print_keys("train_loss"), every=5),  # print train_loss every 5 epochs
    take(20),
)
list(events)  # consume the stream; nothing runs until it is iterated
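tap runs a side effect and passes every event through unchanged. Below is a hypothetical tap_sketch, assuming every=k means the function runs on the k-th, 2k-th, ... event counted from the first one (the library may count by step instead). It also shows why list(events) is needed above: generator stages are lazy, so nothing runs until the stream is consumed.

```python
from typing import Callable, Iterable, Iterator


def tap_sketch(fn: Callable[[dict], None], every: int = 1) -> Callable[[Iterable[dict]], Iterator[dict]]:
    """Return a stage that calls fn on every `every`-th event and yields all events unchanged."""
    def stage(events: Iterable[dict]) -> Iterator[dict]:
        for i, ev in enumerate(events, start=1):
            if i % every == 0:
                fn(ev)  # side effect only; the return value is ignored
            yield ev
    return stage


seen = []
events = [{"step": s, "train_loss": 1.0 / s} for s in range(1, 11)]
out = tap_sketch(lambda ev: seen.append(ev["step"]), every=5)(events)
print(seen)  # [] : the generator has not run yet
result = list(out)
print(seen)  # [5, 10]
```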

Sinks

Iterating over events and handling them yourself can be tedious, so FitStream provides sinks: utilities that consume the event stream.

A common need is to collect all events into a list, keeping the metrics and dropping the model. The collect sink does exactly that:

from fitstream import collect, epoch_stream, take

# collect 100 epochs to a list
history = collect(take(100)(epoch_stream(...)))
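A collect-style sink amounts to consuming the stream and copying each event without its model entry. The sketch below assumes the model lives under the 'model' key, as in the augmentation example, and that every other key is a metric worth keeping; collect_sketch and its exclude parameter are hypothetical.

```python
from typing import Iterable


def collect_sketch(events: Iterable[dict], exclude: tuple[str, ...] = ("model",)) -> list[dict]:
    """Consume the stream and return metric-only copies of each event."""
    return [{k: v for k, v in ev.items() if k not in exclude} for ev in events]


events = ({"step": s, "model": object(), "train_loss": 1.0 / 2 ** s} for s in range(1, 4))
history = collect_sketch(events)
print(history)
# [{'step': 1, 'train_loss': 0.5}, {'step': 2, 'train_loss': 0.25}, {'step': 3, 'train_loss': 0.125}]
```

Dropping the model matters because every event references it; keeping it would retain a model object per epoch in the history list.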

We can also write them to a JSON Lines (.jsonl) file:

from fitstream import collect_jsonl, epoch_stream, take

# write 100 epochs to a JSON Lines file
collect_jsonl(take(100)(epoch_stream(...)), 'runs/my_experiment.jsonl')
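JSON Lines stores one JSON object per line, which makes a run log easy to append to and to read back incrementally. A minimal sketch of such a sink follows, assuming it drops the model and writes the remaining values with the standard json module, so they must be JSON-serializable (PyTorch tensors are not, and would need converting with .item() first). collect_jsonl_sketch and read_jsonl are hypothetical; the sketch also creates the parent directory, which the real helper may or may not do.

```python
import json
import tempfile
from pathlib import Path
from typing import Iterable


def collect_jsonl_sketch(events: Iterable[dict], path: str, exclude: tuple[str, ...] = ("model",)) -> None:
    """Write each event, minus excluded keys, as one JSON object per line."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)  # e.g. create runs/ if missing
    with out.open("w", encoding="utf-8") as f:
        for ev in events:
            record = {k: v for k, v in ev.items() if k not in exclude}
            f.write(json.dumps(record) + "\n")


def read_jsonl(path: str) -> list[dict]:
    """Load a JSON Lines file back into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    p = f"{tmp}/runs/demo.jsonl"
    events = ({"step": s, "model": object(), "train_loss": 1.0 / 2 ** s} for s in range(1, 3))
    collect_jsonl_sketch(events, p)
    print(read_jsonl(p))
# [{'step': 1, 'train_loss': 0.5}, {'step': 2, 'train_loss': 0.25}]
```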

Documentation

Full documentation is available at https://fitstream.readthedocs.io/.

Development

  • After cloning this repo, run make setup to create a virtual environment and install all dependencies.
  • Build via uv build.
  • Run tests via make test.
  • Build documentation via make doc.
  • Serve documentation locally (and open it) via make doc-open.
  • Lint via make lint.

Download files


Source Distribution

fitstream-0.3.0.tar.gz (12.7 kB)

Built Distribution


fitstream-0.3.0-py3-none-any.whl (14.2 kB)

File details

Details for the file fitstream-0.3.0.tar.gz.

File metadata

  • Download URL: fitstream-0.3.0.tar.gz
  • Size: 12.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.6

File hashes

Hashes for fitstream-0.3.0.tar.gz
Algorithm Hash digest
SHA256 d27f4abcc5bee300e3c275d9c49e5efc0aca58b18b905cdd95a411a8cfc71aa4
MD5 44a47447177b3c857c277c05ef5a2352
BLAKE2b-256 a84a183f10316af8b1fa65b8d94cb4332b432eadf6d63cfc93f5bc354d58b43d


File details

Details for the file fitstream-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: fitstream-0.3.0-py3-none-any.whl
  • Size: 14.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.9.6

File hashes

Hashes for fitstream-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d36de2d2ee64d7b4da3fbf981816b2df3a88e5d997e820c4ac1e16b85354bc19
MD5 8f2cadb5e865e336579ba1a3ff63f266
BLAKE2b-256 fedb01fac7d5e7ebe372f5e4ea1eaae081d7b63c144cfaf39981033eec6af346
