
TorchLearn

TorchLearn is a Python package that adds a layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.

TorchLearn defines a Trainer abstraction in which the processing logic is declared as a Directed Acyclic Graph (DAG), and only the nodes needed to compute the loss and the metrics are executed. An example of defining a DAG and training a model is shown below:

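from functools import partial

import torch
from torch.nn import functional as F
from torch.optim import Adam

from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision
# Trainer is also part of torchlearn; its import path is not shown in this description.

def processing(model):
    # Each node is (output_name, function, input_name_or_names).
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            # Forward pass: class scores (logits) from the model
            ("output", model, "input"),
            # Index of the highest-scoring class
            ("predicted", partial(torch.argmax, dim=-1), "output"),
            # Per-sample cross-entropy loss (no reduction)
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),
            # Element-wise weighting of each sample's loss
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            # Scalar loss used for backpropagation
            ("loss", torch.sum, "weighted_loss"),
        ),
    )

# model, classes, epochs, train_loader and val_loader are defined by the user.
cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}

trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)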
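To illustrate the "only execute the nodes that are necessary" idea, here is a minimal, hypothetical DAG evaluator in plain Python. It is not torchlearn's ProcessingGraph implementation: the class name `LazyGraph`, its `run` method and the `calls` log are assumptions made for illustration. It only mirrors the `(output_name, function, input_name_or_names)` node format used above and evaluates a node on demand, depth-first, when a requested target depends on it.

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

# A node's inputs may be a single name or a tuple of names, as in the example above.
NodeInputs = Union[str, Tuple[str, ...]]


class LazyGraph:
    """Hypothetical minimal DAG evaluator (not torchlearn's ProcessingGraph).

    Nodes are (output_name, function, input_name_or_names). A node is only
    executed when one of the requested targets depends on it.
    """

    def __init__(self, inputs: Iterable[str],
                 functions: Iterable[Tuple[str, Callable[..., Any], NodeInputs]]):
        self.inputs = tuple(inputs)
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        for name, fn, args in functions:
            # Normalise a single input name to a one-element tuple.
            self.nodes[name] = (fn, (args,) if isinstance(args, str) else tuple(args))
        self.calls: List[str] = []  # names of nodes executed by the last run()

    def run(self, feed: Dict[str, Any], targets: Iterable[str]) -> Dict[str, Any]:
        self.calls = []
        # Graph inputs come straight from the feed; computed values are cached.
        cache: Dict[str, Any] = {name: feed[name] for name in self.inputs}

        def evaluate(name: str, path: frozenset) -> Any:
            if name in cache:
                return cache[name]
            if name in path:
                raise ValueError(f"cycle detected at node {name!r}")
            fn, args = self.nodes[name]
            # Depth-first: compute every dependency before this node.
            values = [evaluate(arg, path | {name}) for arg in args]
            cache[name] = fn(*values)
            self.calls.append(name)
            return cache[name]

        return {target: evaluate(target, frozenset()) for target in targets}


graph = LazyGraph(
    inputs=("x", "w"),
    functions=(
        ("double", lambda x: 2 * x, "x"),
        ("weighted", lambda d, w: d * w, ("double", "w")),
        ("unused", lambda x: x ** 10, "x"),  # not needed to compute "weighted"
    ),
)
result = graph.run({"x": 3, "w": 0.5}, targets=["weighted"])
print(result)       # {'weighted': 3.0}
print(graph.calls)  # ['double', 'weighted']
```

Because `unused` is not an ancestor of `weighted`, it is never called. In this sketch, caching each computed value also means a node shared by several targets (for example a loss and a metric) runs only once per call to `run`.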


Download files


Source Distribution

torchlearn-0.2.2.tar.gz (32.4 kB)

Built Distribution

torchlearn-0.2.2-py3-none-any.whl (30.8 kB)

File details

Details for the file torchlearn-0.2.2.tar.gz.

File metadata

  • Download URL: torchlearn-0.2.2.tar.gz
  • Size: 32.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for torchlearn-0.2.2.tar.gz
Algorithm     Hash digest
SHA256        ddbcec8c2e1ec30ee29b5efe7363f0976f1b9008a202ceb3240aa40be59d9cdb
MD5           60ed2e73721b5cf110da9586cb2f27b5
BLAKE2b-256   3b091be4bfab3589aacabeedae41192945b9e3869adbfd5d5fd5c3d4d7810012
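The published digests can be checked locally before installing. The sketch below uses only Python's standard `hashlib` and checks the SHA256 digest; the helper names `stream_sha256` and `verify_download` are hypothetical, and the example call assumes the source distribution was downloaded into the current directory.

```python
import hashlib
from typing import BinaryIO

# Published SHA256 digest of torchlearn-0.2.2.tar.gz, copied from the table above.
EXPECTED_SDIST_SHA256 = "ddbcec8c2e1ec30ee29b5efe7363f0976f1b9008a202ceb3240aa40be59d9cdb"


def stream_sha256(stream: BinaryIO, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a binary stream, read in fixed-size chunks."""
    digest = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, expected: str) -> bool:
    """Hypothetical helper: True if the file at `path` has the expected SHA256 digest."""
    with open(path, "rb") as fh:
        return stream_sha256(fh) == expected.strip().lower()


# Example (assumes the file is present locally):
# verify_download("torchlearn-0.2.2.tar.gz", EXPECTED_SDIST_SHA256)
```

Reading in chunks keeps memory use constant regardless of file size. pip can also enforce such digests by itself through its hash-checking mode (`--require-hashes`).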


File details

Details for the file torchlearn-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: torchlearn-0.2.2-py3-none-any.whl
  • Size: 30.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for torchlearn-0.2.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        b0c43b038536213e210268478589bb56b101a59b8b306e04516bf0bce6e0a94e
MD5           43a862e10e96a1162d1be7bb126e5df5
BLAKE2b-256   8f1e0d0c545c262392e73a08b3e744af4256b3981cc83a769d58bf199d2ace47

