A library to train PyTorch models as a stream of events
FitStream
A tiny library that makes PyTorch experiments easy for small models and in-memory datasets.
Getting started
Using uv:
uv add fitstream
Using pip:
pip install fitstream
Training a model:
from torch.optim import Adam
from fitstream import epoch_stream, take # epoch_stream is the main entry point
X, y = get_data()
model = get_model()
loss = get_loss()
optimizer = Adam(model.parameters())
# an infinite stream of training epochs (limit it with `take` or `early_stop`)
events = epoch_stream((X, y), model, optimizer, loss, batch_size=32, shuffle=True)
for event in take(10)(events):
    print(f"step={event['step']}, loss={event['train_loss']}")
# step=1, loss=...
# step=2, loss=...
# ...
Basics
The core idea of the library is "training loop as a stream of events". The stream returned by epoch_stream is just an
iterable of dictionaries containing the epoch number (the step key), the model, and the training loss (train_loss).
Everything else transforms or enriches these events. FitStream provides a small pipe(...) helper to compose
transformations left to right.
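To make the idea concrete, here is a minimal pure-Python sketch of a stream of events and left-to-right composition. It is not FitStream's implementation: toy_epoch_stream, toy_take, toy_pipe, and add_half_loss are hypothetical stand-ins, and the loss values are made up.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator

Event = dict
Stage = Callable[[Iterable[Event]], Iterator[Event]]


def toy_epoch_stream() -> Iterator[Event]:
    """Hypothetical stand-in for epoch_stream: one event per epoch, forever."""
    step = 0
    while True:
        step += 1
        # Made-up, steadily decreasing loss; a real stream would train a model here.
        yield {"step": step, "train_loss": 1.0 / step}


def toy_take(n: int) -> Stage:
    """Stage that passes through at most n events and then stops."""
    def stage(events: Iterable[Event]) -> Iterator[Event]:
        # islice stops after n items without pulling an extra one from upstream.
        yield from islice(events, n)
    return stage


def add_half_loss(events: Iterable[Event]) -> Iterator[Event]:
    """Stage that enriches each event with one extra (illustrative) key."""
    for event in events:
        yield {**event, "half_loss": event["train_loss"] / 2}


def toy_pipe(events: Iterable[Event], *stages: Stage) -> Iterable[Event]:
    """Apply stages left to right: toy_pipe(s, f, g) is g(f(s))."""
    for stage in stages:
        events = stage(events)
    return events


first_three = list(toy_pipe(toy_epoch_stream(), add_half_loss, toy_take(3)))
for event in first_three:
    print(event)
```

Because every stage is lazy, nothing runs until the final list(...) pulls events through the pipeline, and the infinite source is safe as long as some stage limits it.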
Augmentation
The augment function turns an "augmenter" (a function that looks at an event and returns extra keys) into a stream
transform stage. We typically compose stages with pipe(...).
Here is an example that adds the norm of the model parameters to each event:
from torch import nn, linalg
from fitstream import epoch_stream, augment, pipe
def model_param_norm(ev: dict) -> dict:
    model_params = nn.utils.parameters_to_vector(ev['model'].parameters())
    return {'model_param_norm': linalg.norm(model_params)}
events = pipe(
    epoch_stream(...),
    augment(model_param_norm),
)
for event in events:
    print(f"step={event['step']}",
          f"model_param_norm={event['model_param_norm']}"
    )
We also have some built-in augmentation functions. Here is an example of adding validation loss to each event:
from torch import nn
from fitstream import epoch_stream, augment, pipe, validation_loss
validation_set = get_validation_set()
events = pipe(
    epoch_stream(...),
    augment(validation_loss(validation_set, nn.CrossEntropyLoss())),
)
for event in events:
    print(f"step={event['step']}, val_loss={event['val_loss']}")
We can, of course, augment the stream more than once:
events = pipe(
    epoch_stream(...),
    augment(validation_loss(...)),
    augment(model_param_norm),
)
for event in events:
    print(f"step={event['step']}",
          f"val_loss={event['val_loss']}",
          f"model_param_norm={event['model_param_norm']}"
    )
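The mechanics of an augmentation stage can be sketched in a few lines of plain Python. This is an illustration only: toy_augment and the two augmenters are hypothetical names, and the choice to copy each event rather than mutate it is an assumption, not something the FitStream documentation states.

```python
from typing import Callable, Iterable, Iterator


def toy_augment(augmenter: Callable[[dict], dict]):
    """Turn an augmenter (event -> extra keys) into a stream stage."""
    def stage(events: Iterable[dict]) -> Iterator[dict]:
        for event in events:
            # Merge the extra keys into a copy; the incoming event is left untouched.
            yield {**event, **augmenter(event)}
    return stage


def loss_squared(ev: dict) -> dict:
    """Illustrative augmenter: derives a new value from an existing key."""
    return {"loss_squared": ev["train_loss"] ** 2}


def is_first(ev: dict) -> dict:
    """Illustrative augmenter: flags the first epoch."""
    return {"is_first": ev["step"] == 1}


events = [{"step": 1, "train_loss": 0.5}, {"step": 2, "train_loss": 0.25}]
# Two stages applied one after the other, as when augmenting more than once.
enriched = list(toy_augment(is_first)(toy_augment(loss_squared)(events)))
print(enriched)
```

Since each stage sees the keys added by the stages before it, a later augmenter can build on values computed by an earlier one.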
Selecting events
Since the training loop is a standard Python iterable, you can use any Python selection logic. FitStream includes a
small helper, take(...), to limit the number of epochs:
from fitstream import epoch_stream, take
for event in take(100)(epoch_stream(...)):
    print(event)
# {'step': 1, ...}
# {'step': 2, ...}
# ...
# {'step': 100, ...}
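As an illustration of "any Python selection logic", the standard itertools functions work on any iterator of event dictionaries. The sketch below uses a hypothetical toy_events generator with made-up losses in place of a real epoch_stream.

```python
from itertools import islice, takewhile


def toy_events():
    """Hypothetical stand-in for an event stream; losses are made up."""
    step = 0
    while True:
        step += 1
        yield {"step": step, "train_loss": 1.0 / step}


# Keep every tenth event among the first hundred: steps 10, 20, ..., 100.
every_tenth = islice(toy_events(), 9, 100, 10)
steps = [event["step"] for event in every_tenth]

# Keep events while the loss is above a threshold.
# Note: takewhile has to pull the first failing event to know when to stop,
# so with a real training stream one extra epoch would be computed and dropped.
until_good = list(takewhile(lambda e: e["train_loss"] > 0.2, toy_events()))

print(steps)
print([event["step"] for event in until_good])
```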
FitStream also provides its own selection primitives, such as early stopping:
from fitstream import augment, early_stop, epoch_stream, pipe, take, validation_loss
events = pipe(
    epoch_stream(...),
    augment(validation_loss(...)),
    take(500),  # safety cap
    early_stop(key="val_loss", patience=10, mode="min", min_delta=1e-4),
)
for event in events:
    print(event)
mode="min" is the default. Use mode="max" for metrics such as accuracy, and set
min_delta to ignore tiny noisy changes.
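The following sketch shows one common way patience, mode, and min_delta interact. It is an assumption about typical early-stopping semantics, not FitStream's actual rule: toy_early_stop is a hypothetical name, and details such as whether the last non-improving event is still emitted may differ in the library.

```python
from typing import Iterable, Iterator


def toy_early_stop(key: str, patience: int, mode: str = "min", min_delta: float = 0.0):
    """Stop the stream after `patience` consecutive events without improvement."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")

    def improved(value: float, best: float) -> bool:
        # A change smaller than min_delta does not count as an improvement.
        if mode == "min":
            return value < best - min_delta
        return value > best + min_delta

    def stage(events: Iterable[dict]) -> Iterator[dict]:
        best = None
        stale = 0  # consecutive events without improvement
        for event in events:
            yield event
            value = event[key]
            if best is None or improved(value, best):
                best, stale = value, 0
            else:
                stale += 1
                if stale >= patience:
                    return
    return stage


losses = [1.0, 0.8, 0.79995, 0.9, 0.85, 0.7]
events = [{"step": i, "val_loss": v} for i, v in enumerate(losses, start=1)]
kept = list(toy_early_stop("val_loss", patience=2, min_delta=1e-4)(events))
print([event["step"] for event in kept])
```

In this run the drop from 0.8 to 0.79995 is smaller than min_delta, so it counts as a stale epoch, and the stream ends after the following non-improving event.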
Side effects
Sometimes you want to log metrics (or write to an external system) without changing the stream.
Use tap(fn, every=...) and the built-in print_keys(...) helper:
from fitstream import epoch_stream, pipe, print_keys, tap, take
events = pipe(
    epoch_stream(...),
    tap(print_keys("train_loss"), every=5),
    take(20),
)
list(events)
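A side-effect stage of this kind can be sketched as follows. The name toy_tap is hypothetical, and the choice of which events trigger the callback (here the 5th, 10th, and so on) is an assumption; FitStream's tap may count differently.

```python
from typing import Callable, Iterable, Iterator


def toy_tap(fn: Callable[[dict], object], every: int = 1):
    """Call fn on every `every`-th event and pass all events through unchanged."""
    def stage(events: Iterable[dict]) -> Iterator[dict]:
        for count, event in enumerate(events, start=1):
            if count % every == 0:
                fn(event)  # side effect only; the return value is ignored
            yield event
    return stage


seen = []
events = [{"step": s, "train_loss": 1.0 / s} for s in range(1, 11)]
passed = list(toy_tap(lambda e: seen.append(e["step"]), every=5)(events))
print(seen)
```

The stage is lazy, so the callback fires only while something consumes the stream, which is why the example above ends with list(events).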
Sinks
Iterating over events and doing something yourself can be tedious, so we have some utilities to help you process the event stream.
It is typically useful to collect all events into a list, but exclude the model and keep just the metrics. We have
the collect sink for that:
from fitstream import collect, epoch_stream, take
# collect 100 epochs to a list
history = collect(take(100)(epoch_stream(...)))
We can also write them to a JSONL file:
from fitstream import collect_jsonl, epoch_stream, take
# write 100 epochs to a JSONL file
collect_jsonl(take(100)(epoch_stream(...)), 'runs/my_experiment.jsonl')
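A JSONL file holds one JSON object per line, so it can be read back with the standard library alone. The sketch below assumes each line is an event's metrics with keys such as step and train_loss; the exact keys collect_jsonl writes are not specified here, and read_jsonl and best_record are hypothetical helpers. The demo writes its own small file to a temporary directory.

```python
import json
import tempfile
from pathlib import Path


def read_jsonl(path) -> list[dict]:
    """Read a JSONL file into a list of dictionaries, skipping blank lines."""
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def best_record(records: list[dict], key: str = "train_loss") -> dict:
    """Return the record with the smallest value under `key`."""
    return min(records, key=lambda r: r[key])


# Demo with made-up metrics written to a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.jsonl"
    demo_path.write_text(
        '{"step": 1, "train_loss": 0.9}\n{"step": 2, "train_loss": 0.4}\n',
        encoding="utf-8",
    )
    demo_records = read_jsonl(demo_path)
    demo_best = best_record(demo_records)

print(demo_best)
```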
Documentation
Full documentation is available at https://fitstream.readthedocs.io/.
Development
- After cloning this repo, run make setup to create a virtual environment and install all dependencies.
- Building is done via uv build.
- Running tests is done via make test.
- Building documentation is done via make doc.
- Serving documentation locally (and opening it) is done via make doc-open.
- Linting is done via make lint.
File details
Details for the file fitstream-0.3.0.tar.gz.
File metadata
- Download URL: fitstream-0.3.0.tar.gz
- Upload date:
- Size: 12.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d27f4abcc5bee300e3c275d9c49e5efc0aca58b18b905cdd95a411a8cfc71aa4 |
| MD5 | 44a47447177b3c857c277c05ef5a2352 |
| BLAKE2b-256 | a84a183f10316af8b1fa65b8d94cb4332b432eadf6d63cfc93f5bc354d58b43d |
File details
Details for the file fitstream-0.3.0-py3-none-any.whl.
File metadata
- Download URL: fitstream-0.3.0-py3-none-any.whl
- Upload date:
- Size: 14.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d36de2d2ee64d7b4da3fbf981816b2df3a88e5d997e820c4ac1e16b85354bc19 |
| MD5 | 8f2cadb5e865e336579ba1a3ff63f266 |
| BLAKE2b-256 | fedb01fac7d5e7ebe372f5e4ea1eaae081d7b63c144cfaf39981033eec6af346 |