fastmtl

Multi-task learning utilities for fastai

Install

pip install fastmtl

Usage

Loss

Apply one loss function to each model output and take the weighted sum of the results. For example, if the first model output is for classification and the second is for regression:

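from fastai.losses import CrossEntropyLossFlat, MSELossFlat
from fastmtl.loss import CombinedLoss
# loss = 1.0 * cross-entropy(first output) + 3.0 * MSE(second output)
loss_func = CombinedLoss(CrossEntropyLossFlat(), MSELossFlat(), weight=[1.0, 3.0])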
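The idea behind a combined loss can be sketched in plain Python. The class below is a hypothetical illustration, not fastmtl's `CombinedLoss` source: it assumes the model returns a tuple of outputs, that the targets arrive in the same order, and that the loss is called the way a fastai `Learner` calls it, `loss_func(pred, *yb)`. The names `WeightedSumLoss`, `squared_error` and `abs_error` are made up for the example.

```python
from typing import Callable, Optional, Sequence


class WeightedSumLoss:
    """Hypothetical sketch (not fastmtl's CombinedLoss): apply loss_funcs[i] to
    (preds[i], targets[i]) and return the weighted sum of the results."""

    def __init__(self, *loss_funcs: Callable, weight: Optional[Sequence[float]] = None):
        self.loss_funcs = loss_funcs
        # Assumption for this sketch: with no weights given, every task counts equally.
        self.weight = list(weight) if weight is not None else [1.0] * len(loss_funcs)
        if len(self.weight) != len(self.loss_funcs):
            raise ValueError("need exactly one weight per loss function")

    def __call__(self, preds: Sequence, *targets):
        # preds: tuple of model outputs; targets: one target per output, same order
        if not (len(preds) == len(targets) == len(self.loss_funcs)):
            raise ValueError("need one prediction and one target per loss function")
        return sum(
            w * f(p, t)
            for w, f, p, t in zip(self.weight, self.loss_funcs, preds, targets)
        )


# Toy per-task losses on plain floats, standing in for real tensor losses.
def squared_error(pred, target):
    return (pred - target) ** 2


def abs_error(pred, target):
    return abs(pred - target)


loss = WeightedSumLoss(squared_error, abs_error, weight=[1.0, 3.0])
print(loss((2.0, 5.0), 1.0, 3.0))  # 1.0 * (2 - 1)**2 + 3.0 * |5 - 3| = 7.0
```

In the README example, `weight=[1.0, 3.0]` plays the same role: the second (regression) term counts three times as much as the first (classification) term in the total loss.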

Metric

Apply separate metrics to each model output. For example, a model that performs both classification and regression can have each output evaluated with the metrics relevant to it. Assuming the model returns a tuple of tensors (classification output, regression output):

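from fastai.learner import Learner
from fastai.metrics import F1Score, R2Score
from fastmtl.metric import mtl_metrics

# Metrics for the classification output (first element of the output tuple).
# Custom names keep the two F1 scores distinguishable in the training log.
clf_f1_macro = F1Score(average='macro')
clf_f1_macro.name = 'clf_f1(macro)'
clf_f1_micro = F1Score(average='micro')
clf_f1_micro.name = 'clf_f1(micro)'

# Metric for the regression output (second element of the output tuple).
reg_r2 = R2Score()
reg_r2.name = 'reg_r2'

# First list: classification metrics; second list: regression metrics.
metrics = mtl_metrics([clf_f1_macro, clf_f1_micro], [reg_r2])

learn = Learner(
    ...,  # other Learner arguments
    metrics=metrics,
)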
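A helper like `mtl_metrics` has to make each metric see only its own output. One way to do that is sketched below in plain Python; it is a hypothetical illustration, not fastmtl's source. The names `per_output`, `mtl_metric_list`, `accuracy` and `mae` are made up, and unlike fastai metrics (which accumulate over batches), these toy metrics are computed on a single batch.

```python
from typing import Callable, List, Sequence


def per_output(metric: Callable, index: int) -> Callable:
    """Wrap `metric` so it only sees prediction/target number `index`."""
    def wrapped(preds: Sequence, targets: Sequence):
        return metric(preds[index], targets[index])
    # Distinct names, analogous to setting `.name` on fastai metrics.
    wrapped.__name__ = f"{metric.__name__}_out{index}"
    return wrapped


def mtl_metric_list(*metric_groups: List[Callable]) -> List[Callable]:
    """metric_groups[i] holds the metrics for model output i."""
    return [per_output(m, i) for i, group in enumerate(metric_groups) for m in group]


# Toy metrics on plain lists.
def accuracy(pred_labels, true_labels):
    return sum(p == t for p, t in zip(pred_labels, true_labels)) / len(true_labels)


def mae(pred_vals, true_vals):
    return sum(abs(p - t) for p, t in zip(pred_vals, true_vals)) / len(true_vals)


metrics = mtl_metric_list([accuracy], [mae])
preds = ([0, 1, 1, 0], [0.5, 2.0])    # (class labels, regression values)
targets = ([0, 1, 0, 0], [1.0, 2.0])
for m in metrics:
    print(m.__name__, m(preds, targets))
# accuracy_out0 0.75
# mae_out1 0.25
```

Wrapping each metric with a fixed index keeps the metrics themselves unaware of the multi-task setup, so existing single-output metrics can be reused unchanged.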

Tutorials

Video distortion detection

TODO

  • Support tabular learner
  • Support fastai>=2.7

Download files

Source Distribution

fastmtl-1.2.1.tar.gz (10.0 kB)

File details

Details for the file fastmtl-1.2.1.tar.gz.

File metadata

  • Download URL: fastmtl-1.2.1.tar.gz
  • Size: 10.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.9

File hashes

Hashes for fastmtl-1.2.1.tar.gz

  • SHA256: f3e100a9933f9b268845e8c26b3ec652133dcb842a2d2172472067d4c1bc2567
  • MD5: bbf1b443bdf76b77eed179a1a61e3414
  • BLAKE2b-256: 4fa80b71003543840b5e96359c5678070e2445583ca759eba009013b9b320d5c
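To check a downloaded copy of the archive against the SHA256 digest listed above, Python's standard `hashlib` module is enough. This is a minimal sketch; the function names `sha256_of` and `verify` and the local file path are illustrative assumptions.

```python
import hashlib

# SHA256 digest published for fastmtl-1.2.1.tar.gz
EXPECTED_SHA256 = "f3e100a9933f9b268845e8c26b3ec652133dcb842a2d2172472067d4c1bc2567"


def sha256_of(path, chunk_size=1 << 20):
    """Hash the file in chunks so large archives need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches `expected`."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("fastmtl-1.2.1.tar.gz")` should return `True` for an unmodified download saved in the current directory.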
