TorchLearn

TorchLearn is a Python package that adds a layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.

TorchLearn defines a Trainer abstraction that lets you define the processing logic as a Directed Acyclic Graph (DAG) and executes only the nodes needed to compute the loss and metrics. An example of DAG definition and training is shown below:

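from functools import partial

import torch
from torch.nn import functional as F
from torch.optim import Adam

from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision

# NOTE: Trainer is provided by torchlearn; its import is not shown in the original example.
# model, classes, epochs, train_loader and val_loader are assumed to be defined by the caller.

def processing(model):
    # Each entry appears to be (output name, callable, input name or names).
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            ("output", model, "input"),
            ("predicted", partial(torch.argmax, dim=-1), "output"),
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),
            # torch.mul multiplies element-wise; torch.prod would reduce a single tensor instead.
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            ("loss", torch.sum, "weighted_loss"),
        ),
    )

# Both metrics are built on the same confusion matrix.
cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}

trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)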
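To illustrate the idea of executing only the nodes that a requested output depends on, here is a minimal pure-Python sketch. It is not TorchLearn's implementation: `TinyGraph`, its `compute` method, and the `traced` helper are hypothetical names invented for this example, and the sketch assumes the graph is acyclic (it does no cycle detection).

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

Node = Tuple[str, Callable[..., Any], Union[str, Sequence[str]]]


class TinyGraph:
    """Hypothetical stand-in for a processing graph (not TorchLearn's code)."""

    def __init__(self, inputs: Sequence[str], functions: Sequence[Node]) -> None:
        self.inputs = tuple(inputs)
        # Map each node name to (function, tuple of argument names).
        self.nodes: Dict[str, Tuple[Callable[..., Any], Tuple[str, ...]]] = {}
        for name, fn, args in functions:
            arg_names = (args,) if isinstance(args, str) else tuple(args)
            self.nodes[name] = (fn, arg_names)

    def compute(self, targets: Sequence[str], **values: Any) -> Dict[str, Any]:
        """Evaluate only the nodes that the requested targets depend on."""
        cache: Dict[str, Any] = dict(values)  # starts with the supplied inputs

        def resolve(name: str) -> Any:
            if name in cache:
                return cache[name]
            if name not in self.nodes:
                raise KeyError(f"unknown input or node: {name!r}")
            fn, arg_names = self.nodes[name]
            # Resolve dependencies first, then memoize this node's result.
            cache[name] = fn(*(resolve(arg) for arg in arg_names))
            return cache[name]

        return {name: resolve(name) for name in targets}


calls: List[str] = []


def traced(label: str, fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap fn so that every call is recorded in `calls`."""
    def wrapper(*args: Any) -> Any:
        calls.append(label)
        return fn(*args)
    return wrapper


graph = TinyGraph(
    inputs=("x", "w"),
    functions=(
        ("scaled", traced("scaled", lambda x, w: x * w), ("x", "w")),
        ("loss", traced("loss", lambda s: s + 1), "scaled"),
        ("unused", traced("unused", lambda x: -x), "x"),
    ),
)

result = graph.compute(["loss"], x=3, w=2)
print(result)  # {'loss': 7}
print(calls)   # ['scaled', 'loss']
```

Requesting only `loss` runs `scaled` and `loss`, while the `unused` node is never called. Results are memoized per call to `compute`, so a node shared by several targets would be evaluated once.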
