Project description
TorchLearn
TorchLearn is a Python package that adds a layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.
TorchLearn defines a Trainer abstraction that lets you express the processing logic as a Directed Acyclic Graph (DAG) and executes only the nodes needed to compute the loss and metrics. The example below defines a DAG and trains a model with it:
import torch
from functools import partial
from torch.nn import functional as F
from torch.optim import Adam
from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision

# Trainer, model, classes, epochs, train_loader and val_loader must be defined
# elsewhere; the import path of Trainer is not shown in this snippet.

def processing(model):
    # Each entry is (output name, function, input name(s)); only the nodes
    # needed for the requested outputs (e.g. "loss") are executed.
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            ("output", model, "input"),
            ("predicted", partial(torch.argmax, dim=-1), "output"),
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),
            # Element-wise product of per-sample loss and weight
            # (torch.prod would instead reduce a single tensor).
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            ("loss", torch.sum, "weighted_loss"),
        ),
    )

cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}
trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)
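The idea of executing only the nodes required for a given output can be illustrated with a minimal, hypothetical DAG evaluator in plain Python. `TinyGraph`, `compute` and `last_calls` are names invented for this sketch: it mimics the `(name, function, inputs)` layout of the example above, but it is not TorchLearn's `ProcessingGraph` implementation, and it assumes the graph is acyclic (there is no cycle detection).

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

Names = Union[str, Tuple[str, ...]]


class TinyGraph:
    """Hypothetical minimal DAG evaluator (not TorchLearn's ProcessingGraph)."""

    def __init__(self, inputs: Tuple[str, ...],
                 functions: Iterable[Tuple[str, Callable[..., Any], Names]]):
        self.inputs = set(inputs)
        # Map each node name to (function, tuple of argument names).
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        for name, fn, args in functions:
            self.nodes[name] = (fn, (args,) if isinstance(args, str) else tuple(args))
        self.last_calls: List[str] = []  # nodes evaluated by the last compute()

    def compute(self, targets: Iterable[str], **feed: Any) -> Dict[str, Any]:
        cache: Dict[str, Any] = dict(feed)  # graph inputs are known up front
        calls: List[str] = []

        def resolve(name: str) -> Any:
            # Reuse already computed values so shared nodes run only once.
            if name in cache:
                return cache[name]
            if name in self.inputs:
                raise KeyError(f"missing input {name!r}")
            fn, args = self.nodes[name]
            # Recursively compute only the dependencies of this node.
            value = fn(*(resolve(a) for a in args))
            calls.append(name)
            cache[name] = value
            return value

        result = {t: resolve(t) for t in targets}
        self.last_calls = calls
        return result


graph = TinyGraph(
    inputs=("x", "y"),
    functions=(
        ("s", lambda a, b: a + b, ("x", "y")),
        ("sq", lambda s: s * s, "s"),
        ("unused", lambda x: x / 0, "x"),  # would fail, but is never requested
    ),
)
out = graph.compute(["sq"], x=2, y=3)
print(out)               # {'sq': 25}
print(graph.last_calls)  # ['s', 'sq']
```

Requesting only `"sq"` evaluates `"s"` and `"sq"` and never touches `"unused"`, which is the same reason a graph like the one above can skip nodes such as `"predicted"` when only the loss is needed.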
Download files
Source Distribution
torchlearn-0.2.2.tar.gz (32.4 kB)
Built Distribution
torchlearn-0.2.2-py3-none-any.whl (30.8 kB)
File details
Details for the file torchlearn-0.2.2.tar.gz.
File metadata
- Download URL: torchlearn-0.2.2.tar.gz
- Upload date:
- Size: 32.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | ddbcec8c2e1ec30ee29b5efe7363f0976f1b9008a202ceb3240aa40be59d9cdb
MD5 | 60ed2e73721b5cf110da9586cb2f27b5
BLAKE2b-256 | 3b091be4bfab3589aacabeedae41192945b9e3869adbfd5d5fd5c3d4d7810012
File details
Details for the file torchlearn-0.2.2-py3-none-any.whl.
File metadata
- Download URL: torchlearn-0.2.2-py3-none-any.whl
- Upload date:
- Size: 30.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | b0c43b038536213e210268478589bb56b101a59b8b306e04516bf0bce6e0a94e
MD5 | 43a862e10e96a1162d1be7bb126e5df5
BLAKE2b-256 | 8f1e0d0c545c262392e73a08b3e744af4256b3981cc83a769d58bf199d2ace47
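The digests listed above can be used to check a downloaded file before installing it. Below is a minimal sketch using Python's standard `hashlib`. It assumes the file keeps its original name; `sha256_of`, `verify` and `EXPECTED_SHA256` are helper names chosen for this example.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the "File hashes" tables for version 0.2.2.
EXPECTED_SHA256 = {
    "torchlearn-0.2.2.tar.gz":
        "ddbcec8c2e1ec30ee29b5efe7363f0976f1b9008a202ceb3240aa40be59d9cdb",
    "torchlearn-0.2.2-py3-none-any.whl":
        "b0c43b038536213e210268478589bb56b101a59b8b306e04516bf0bce6e0a94e",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path) -> bool:
    """Compare a downloaded file against the published SHA256 digest."""
    return sha256_of(path) == EXPECTED_SHA256[path.name]
```

For example, `verify(Path("torchlearn-0.2.2.tar.gz"))` should return `True` for an intact download. To check the BLAKE2b-256 digest instead, `hashlib.blake2b(digest_size=32)` can replace `hashlib.sha256()`.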