TorchLearn
TorchLearn is a Python package that adds a layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.
TorchLearn defines a Trainer abstraction that lets you express the processing logic as a Directed Acyclic Graph (DAG) and executes only the nodes needed to compute the loss and metrics. An example of a DAG definition and training run is shown below:
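from functools import partial

import torch
from torch.nn import functional as F
from torch.optim import Adam

from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision
# Trainer also comes from torchlearn; its import path is not shown in this description.
# model, classes, epochs, train_loader and val_loader are assumed to be defined elsewhere.


def processing(model):
    # Each entry is (output name, function, input name(s)).
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            ("output", model, "input"),
            ("predicted", partial(torch.argmax, dim=-1), "output"),
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),
            # Element-wise product of per-sample loss and weight, assuming the graph passes
            # both inputs positionally (torch.prod would reduce a single tensor instead).
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            ("loss", torch.sum, "weighted_loss"),
        ),
    )


# Both metrics are built on the same confusion matrix.
cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}
trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)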
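To illustrate the idea of executing only the nodes that a requested output depends on, here is a minimal pure-Python sketch. It is not TorchLearn's implementation: `LazyGraph`, its `compute` method, and the `evaluated` attribute are hypothetical names introduced for this example, and the sketch assumes each node's function receives its inputs as positional arguments.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

# A node is (name, function, input name or tuple of input names),
# mirroring the triples used in the example above.
Node = Tuple[str, Callable[..., Any], Union[str, Sequence[str]]]


class LazyGraph:
    """Hypothetical stand-in for a processing graph (not TorchLearn's class)."""

    def __init__(self, inputs: Sequence[str], functions: Sequence[Node]):
        self.inputs = tuple(inputs)
        self.evaluated: List[str] = []
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        for name, func, deps in functions:
            # Normalise a single input name to a one-element tuple.
            deps = (deps,) if isinstance(deps, str) else tuple(deps)
            self.nodes[name] = (func, deps)

    def compute(self, targets: Sequence[str], **values: Any) -> Dict[str, Any]:
        """Evaluate only the nodes that the requested targets depend on."""
        cache: Dict[str, Any] = dict(values)  # graph inputs are already known
        self.evaluated = []

        def resolve(name: str) -> Any:
            if name not in cache:
                func, deps = self.nodes[name]
                # Resolve dependencies first, then apply the node's function.
                cache[name] = func(*(resolve(dep) for dep in deps))
                self.evaluated.append(name)
            return cache[name]

        return {target: resolve(target) for target in targets}


graph = LazyGraph(
    inputs=("x", "w"),
    functions=(
        ("doubled", lambda x: [2 * v for v in x], "x"),
        ("weighted", lambda d, w: [a * b for a, b in zip(d, w)], ("doubled", "w")),
        ("loss", sum, "weighted"),
        ("unused", max, "x"),  # never requested, so never executed
    ),
)
result = graph.compute(["loss"], x=[1, 2, 3], w=[1.0, 0.5, 0.0])
print(result, graph.evaluated)
```

Requesting only `loss` walks back through `weighted` and `doubled` and leaves `unused` untouched. The cache also means a node shared by several targets is computed once per call. The sketch does not detect cycles or missing inputs, which a real implementation would need to handle.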
Download files
Source Distribution
torchlearn-0.2.0.tar.gz (26.7 kB)
Built Distribution
File details
Details for the file torchlearn-0.2.0.tar.gz.
File metadata
- Download URL: torchlearn-0.2.0.tar.gz
- Upload date:
- Size: 26.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8b1f16f17d8728051983162957507c97bfe1eaab8745f8cc3055fef0be07266e
MD5 | 69d35697e7725587f6133f7f3dc5e41a
BLAKE2b-256 | d7834a97902b27df1d5fa73b4819f73133b41331257a346328ca216caf006c69
File details
Details for the file torchlearn-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torchlearn-0.2.0-py3-none-any.whl
- Upload date:
- Size: 30.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | e676d6256ad87e6cc6f20d6036dcb1655b4d3be007f83b2da3c52c06d883455e
MD5 | bee86295c44c96d1e6201de291515d35
BLAKE2b-256 | f2ff6543a34bad575e688ce43c40b1416916f65aca76de04b6d6a73830c8cef7