

Project description

TorchLearn

TorchLearn is a Python package that provides an additional layer on top of vanilla PyTorch to simplify and speed up the development and deployment of machine learning training pipelines.

TorchLearn defines a Trainer abstraction that lets you describe the processing logic as a directed acyclic graph (DAG) and executes only the nodes needed to compute the loss and the requested metrics. The example below defines such a DAG and trains a model with it:

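from functools import partial

import torch
from torch.nn import functional as F
from torch.optim import Adam

from torchlearn.processing.graph import ProcessingGraph
from torchlearn.metric.classification import ConfusionMatrix, Accuracy, Precision
# Trainer is also provided by torchlearn; its import path is not shown in this README.

def processing(model):
    # Each entry is (output_name, function, input_name or tuple of input names).
    return ProcessingGraph(
        inputs=("input", "target", "sample_weight"),
        functions=(
            ("output", model, "input"),  # model logits
            ("predicted", partial(torch.argmax, dim=-1), "output"),  # predicted class index
            ("sample_loss", partial(F.cross_entropy, reduction="none"), ("output", "target")),  # per-sample loss
            # Element-wise product: torch.prod reduces a single tensor and cannot weight two.
            ("weighted_loss", torch.mul, ("sample_loss", "sample_weight")),
            ("loss", torch.sum, "weighted_loss"),  # scalar training loss
        ),
    )

# model, classes, epochs, train_loader and val_loader are assumed to be defined elsewhere.
cm = ConfusionMatrix(classes)
metrics = {"acc": Accuracy(cm), "precision": Precision(cm, 1)}

trainer = Trainer(model, Adam)
results = trainer.train(processing, epochs, train_loader, val_loader, metrics=metrics)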
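This README does not show how ProcessingGraph decides which nodes to run. The stdlib-only sketch below illustrates the general idea under that uncertainty: starting from the requested outputs, it resolves dependencies depth-first, memoizes each node's value, and never touches nodes that no requested output depends on. `LazyGraph` and `run` are hypothetical names invented for this example; this is not torchlearn's implementation.

```python
from typing import Callable, Dict, Iterable, List, Mapping, Sequence, Tuple, Union

# Hypothetical node spec mirroring the README's tuples: (name, function, input name(s)).
Node = Tuple[str, Callable, Union[str, Sequence[str]]]


class LazyGraph:
    """Illustrative DAG that evaluates only the nodes the requested outputs depend on."""

    def __init__(self, inputs: Iterable[str], functions: Iterable[Node]):
        self.inputs = set(inputs)
        self.nodes: Dict[str, Tuple[Callable, Tuple[str, ...]]] = {}
        for name, fn, args in functions:
            # A single input name may be given as a bare string; normalise to a tuple.
            self.nodes[name] = (fn, (args,) if isinstance(args, str) else tuple(args))

    def run(self, feed: Mapping[str, object], outputs: Iterable[str]) -> Tuple[Dict[str, object], List[str]]:
        cache: Dict[str, object] = dict(feed)  # graph inputs plus memoized node values
        evaluated: List[str] = []  # nodes actually computed, in evaluation order

        def resolve(name: str, visiting: frozenset) -> object:
            if name in cache:
                return cache[name]
            if name in self.inputs:
                raise KeyError(f"missing graph input {name!r}")
            if name in visiting:
                raise ValueError(f"cycle detected at {name!r}")
            fn, args = self.nodes[name]
            # Depth-first: compute each dependency before calling this node's function.
            value = fn(*(resolve(arg, visiting | {name}) for arg in args))
            cache[name] = value
            evaluated.append(name)
            return value

        return {name: resolve(name, frozenset()) for name in outputs}, evaluated


graph = LazyGraph(
    inputs=("x", "w"),
    functions=(
        ("double", lambda x: 2 * x, "x"),
        ("weighted", lambda d, w: d * w, ("double", "w")),
        ("expensive", lambda x: sum(range(x)), "x"),  # not needed for "weighted"
    ),
)
values, evaluated = graph.run({"x": 3, "w": 0.5}, ["weighted"])
print(values, evaluated)  # {'weighted': 3.0} ['double', 'weighted']
```

Because values are memoized in `cache`, a node shared by the loss and a metric is computed once per call in this sketch; whether torchlearn behaves the same way is not stated in this README.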



Download files


Source Distribution

torchlearn-0.2.1.tar.gz (32.3 kB)

Built Distribution

torchlearn-0.2.1-py3-none-any.whl (30.8 kB)

File details

Details for the file torchlearn-0.2.1.tar.gz.

File metadata

  • Download URL: torchlearn-0.2.1.tar.gz
  • Upload date:
  • Size: 32.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for torchlearn-0.2.1.tar.gz
Algorithm    Hash digest
SHA256       a8f4dae6778cb8aa1f358160ce7d5e063206e4585a1b842173d219a21f056306
MD5          d261883bb9b34d6dc1b0293a20f64f0b
BLAKE2b-256  18c73f79cd4b10206fe9cf33b3f4231c7747ffd59a9743936177c5a43d643ea2
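A downloaded archive can be checked against the SHA256 digest listed for it. The sketch below uses only Python's standard `hashlib` and reads the file in chunks so large files need not fit in memory; `sha256_of` and `verify` are hypothetical helper names, and the file path you pass is an assumption about where the download was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for torchlearn-0.2.1.tar.gz.
EXPECTED_SHA256 = "a8f4dae6778cb8aa1f358160ce7d5e063206e4585a1b842173d219a21f056306"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        # iter() with a sentinel keeps reading until read() returns b"" at EOF.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `verify(Path("torchlearn-0.2.1.tar.gz"))` returns True only for an unmodified copy of the source distribution. pip can enforce the same check at install time through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).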


File details

Details for the file torchlearn-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: torchlearn-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 30.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.3

File hashes

Hashes for torchlearn-0.2.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       717b1b9e3db0480d2cf9abc6ff0453064053c3b4887685668d8253f2194fc04a
MD5          699089cb7ef1044db0ed22a4c75de07f
BLAKE2b-256  2cc43c7158c64a97cbf6eb97da59272d50e4b934134d83b64d978491e2c9eca2

