TorchLearn
TorchLearn is a Python package that adds a layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.
TorchLearn defines a Trainer abstraction that lets you express the processing logic as a Directed Acyclic Graph (DAG) and execute only the nodes needed to compute the loss and metrics. An example of a DAG definition and a training run is shown below:
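```python
from functools import partial

import torch
from torch.nn import functional as F
from torch.optim import Adam

from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision

# NOTE: the example uses Trainer without showing its import; import it from the
# torchlearn module that defines it in your installed version.
# `model`, `classes`, `epochs`, `train_loader` and `val_loader` are placeholders
# for your own network, class list, epoch count and DataLoaders.


def processing(model):
    # Each entry is (node name, function, input node name(s)); a node is only
    # evaluated when the loss or a metric depends on it.
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            ("output", model, "input"),
            ("predicted", partial(torch.argmax, dim=-1), "output"),
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),
            # Element-wise product of per-sample loss and weight
            # (torch.prod reduces a single tensor and cannot multiply two tensors).
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            ("loss", torch.sum, "weighted_loss"),
        ),
    )


# One confusion matrix shared by the metrics derived from it.
cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}

trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)
```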
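The lazy-evaluation idea behind the processing graph can be illustrated without PyTorch. The following is a minimal, hypothetical sketch: `LazyGraph`, `run` and `calls` are names invented here, not TorchLearn's API, and the helper functions are toy stand-ins for the model and cross-entropy. It accepts the same (name, function, inputs) triples as the example and evaluates only the nodes reachable from the requested outputs, caching each result so a shared node such as "output" runs once.

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple, Union

# A node's inputs: one node name or a tuple of names (same shape as the example).
NodeInputs = Union[str, Tuple[str, ...]]


class LazyGraph:
    """Toy DAG evaluator that computes only the nodes a request depends on.

    Hypothetical illustration, not torchlearn's ProcessingGraph. It assumes
    the graph is acyclic and performs no cycle detection.
    """

    def __init__(
        self,
        inputs: Sequence[str],
        functions: Iterable[Tuple[str, Callable[..., Any], NodeInputs]],
    ) -> None:
        self.inputs = set(inputs)
        # node name -> (function, tuple of argument node names)
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        for name, fn, args in functions:
            self.nodes[name] = (fn, (args,) if isinstance(args, str) else tuple(args))
        self.calls: List[str] = []  # nodes evaluated during the last run()

    def run(self, data: Dict[str, Any], outputs: Sequence[str]) -> Dict[str, Any]:
        # Seed the cache with the declared graph inputs only.
        cache: Dict[str, Any] = {k: v for k, v in data.items() if k in self.inputs}
        self.calls = []

        def evaluate(name: str) -> Any:
            if name in cache:  # graph input or node already computed in this run
                return cache[name]
            if name not in self.nodes:
                raise KeyError(f"missing input or node: {name!r}")
            fn, args = self.nodes[name]
            # Depth-first: resolve dependencies before calling this node's function.
            value = fn(*[evaluate(arg) for arg in args])
            self.calls.append(name)
            cache[name] = value
            return value

        return {name: evaluate(name) for name in outputs}


# Usage mirroring the example graph, with plain lists instead of tensors.
def identity_model(scores):
    return scores


def argmax(rows):
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


def toy_loss(outputs, targets):
    # 1 - score of the true class: a stand-in for cross-entropy.
    return [1.0 - row[t] for row, t in zip(outputs, targets)]


def weight(losses, weights):
    return [loss * w for loss, w in zip(losses, weights)]


graph = LazyGraph(
    inputs=("input", "target", "sample_weight"),
    functions=(
        ("output", identity_model, "input"),
        ("predicted", argmax, "output"),
        ("sample_loss", toy_loss, ("output", "target")),
        ("weighted_loss", weight, ("sample_loss", "sample_weight")),
        ("loss", sum, "weighted_loss"),
    ),
)

batch = {"input": [[0.75, 0.25], [0.25, 0.75]], "target": [0, 0], "sample_weight": [1.0, 2.0]}
print(graph.run(batch, ["loss"]))  # {'loss': 1.75}
print(graph.calls)  # ['output', 'sample_loss', 'weighted_loss', 'loss']
```

Requesting only "loss" never touches "predicted", which is the saving the DAG design aims for: nodes used only by metrics can be skipped when metrics are not being computed.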
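In the example, Accuracy and Precision are both built from the same ConfusionMatrix, and `Precision(cm, 1)` presumably reports precision for class index 1. A minimal pure-Python sketch of how both metrics can be read from one confusion matrix follows; it is not TorchLearn's implementation, and the function names and the rows-are-true-classes convention are assumptions.

```python
from typing import List, Sequence

Matrix = List[List[int]]


def confusion_matrix(targets: Sequence[int], predictions: Sequence[int], num_classes: int) -> Matrix:
    # Assumed convention: rows are true classes, columns are predicted classes.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, predictions):
        cm[t][p] += 1
    return cm


def accuracy(cm: Matrix) -> float:
    # Correct predictions sit on the diagonal.
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total if total else 0.0


def precision(cm: Matrix, cls: int) -> float:
    # Of all samples predicted as `cls` (column `cls`), the fraction truly of class `cls`.
    predicted_as_cls = sum(row[cls] for row in cm)
    return cm[cls][cls] / predicted_as_cls if predicted_as_cls else 0.0


cm = confusion_matrix(targets=[0, 1, 1, 0, 1], predictions=[0, 1, 0, 0, 1], num_classes=2)
print(cm)                # [[2, 0], [1, 2]]
print(accuracy(cm))      # 0.8
print(precision(cm, 1))  # 1.0
```

Sharing one matrix means predictions are tallied once per batch and every derived metric reads the same accumulated counts, which is likely why the example passes `cm` to both metrics.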
Download files
Source Distribution
torchlearn-0.2.1.tar.gz (32.3 kB)
Built Distribution
torchlearn-0.2.1-py3-none-any.whl (30.8 kB)
File details: torchlearn-0.2.1.tar.gz
File metadata
- Download URL: torchlearn-0.2.1.tar.gz
- Upload date:
- Size: 32.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | a8f4dae6778cb8aa1f358160ce7d5e063206e4585a1b842173d219a21f056306
MD5 | d261883bb9b34d6dc1b0293a20f64f0b
BLAKE2b-256 | 18c73f79cd4b10206fe9cf33b3f4231c7747ffd59a9743936177c5a43d643ea2
File details: torchlearn-0.2.1-py3-none-any.whl
File metadata
- Download URL: torchlearn-0.2.1-py3-none-any.whl
- Upload date:
- Size: 30.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 717b1b9e3db0480d2cf9abc6ff0453064053c3b4887685668d8253f2194fc04a
MD5 | 699089cb7ef1044db0ed22a4c75de07f
BLAKE2b-256 | 2cc43c7158c64a97cbf6eb97da59272d50e4b934134d83b64d978491e2c9eca2