fastmtl
Multi-task learning utilities for fastai
Install
pip install fastmtl
Usage
Loss
Apply one loss function per model output and combine the results as a weighted sum. For instance, if the first model output is for classification and the second is for regression:
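from fastai.losses import CrossEntropyLossFlat, MSELossFlat
from fastmtl.loss import CombinedLoss

# cross-entropy on the first output (classification), MSE on the second (regression);
# the total loss is 1.0 * cross-entropy + 3.0 * MSE
loss_func = CombinedLoss(CrossEntropyLossFlat(), MSELossFlat(), weight=[1.0, 3.0])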
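To make the "weighted sum" concrete, here is a minimal, framework-agnostic sketch. `WeightedSumLoss` is a hypothetical name used only for illustration; it is not part of fastmtl and not necessarily how `CombinedLoss` is implemented. It assumes fastai's calling convention `loss_func(pred, *yb)`, where `pred` is the model's tuple of outputs and `yb` holds one target per output in the same order, and computes loss = sum_i weight[i] * loss_i(pred[i], target[i]).

```python
class WeightedSumLoss:
    """Hypothetical multi-task loss: sum_i weight[i] * loss_funcs[i](preds[i], targets[i]).

    A sketch only; this is not fastmtl's CombinedLoss implementation.
    """

    def __init__(self, *loss_funcs, weight=None):
        self.loss_funcs = loss_funcs
        # Default to equal weighting when no weights are given.
        self.weight = list(weight) if weight is not None else [1.0] * len(loss_funcs)
        if len(self.weight) != len(self.loss_funcs):
            raise ValueError("expected one weight per loss function")

    def __call__(self, preds, *targets):
        # fastai-style call: `preds` is the tuple of model outputs and
        # `targets` holds one target per output, in the same order.
        if not (len(preds) == len(targets) == len(self.loss_funcs)):
            raise ValueError("expected one prediction and one target per loss function")
        # Works for plain numbers as well as tensors (0 + tensor is a tensor).
        return sum(
            w * loss_fn(p, t)
            for w, loss_fn, p, t in zip(self.weight, self.loss_funcs, preds, targets)
        )
```

With PyTorch losses, `WeightedSumLoss(nn.CrossEntropyLoss(), nn.MSELoss(), weight=[1.0, 3.0])` would return `1.0 * cross_entropy + 3.0 * mse` for a `(logits, regression_pred)` output, mirroring the `CombinedLoss` call above.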
Metric
Apply metrics to each model output separately. For instance, a model that performs both classification and regression can have each output evaluated with the relevant metrics. Assuming the model outputs a tuple of tensors for classification and regression, respectively:
from fastai.learner import Learner
from fastai.metrics import F1Score, R2Score
from fastmtl.metric import mtl_metrics

# F1 scores for the classification output, renamed for the training log
clf_f1_macro = F1Score(average='macro')
clf_f1_macro.name = 'clf_f1(macro)'
clf_f1_micro = F1Score(average='micro')
clf_f1_micro.name = 'clf_f1(micro)'
# R2 score for the regression output
reg_r2 = R2Score()
reg_r2.name = 'reg_r2'

# metrics for classification in the first list
# metrics for regression in the second list
metrics = mtl_metrics([clf_f1_macro, clf_f1_micro], [reg_r2])

learn = Learner(
    ...,  # other Learner arguments elided
    metrics=metrics,
)
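The idea behind per-output metrics can be sketched as routing: each metric only sees the prediction and target at its own position in the output tuple. The helpers below (`per_output_metric`, `build_mtl_metrics`) are hypothetical names for illustration and are not taken from fastmtl. Note that fastai wraps a plain function like `wrapped` in `AvgMetric`, which averages per-batch values; that differs from dataset-level metrics such as `F1Score`, so this only illustrates the routing, not how `mtl_metrics` handles accumulation.

```python
def per_output_metric(metric_fn, index, name=None):
    """Hypothetical wrapper: evaluate `metric_fn` on output `index` only."""

    def wrapped(preds, *targets):
        # fastai-style call: preds is the tuple of model outputs and targets
        # holds one target per output; pick the pair that belongs to this task.
        return metric_fn(preds[index], targets[index])

    # fastai's AvgMetric takes its display name from the function's __name__.
    wrapped.__name__ = name or f"{getattr(metric_fn, '__name__', 'metric')}_{index}"
    return wrapped


def build_mtl_metrics(*metric_groups):
    """Hypothetical analogue of mtl_metrics: group i is applied to output i."""
    return [
        per_output_metric(metric_fn, index)
        for index, group in enumerate(metric_groups)
        for metric_fn in group
    ]
```

As with `mtl_metrics`, the first group is applied to the first output and the second group to the second output.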
Tutorials
TODO
- Support tabular learner
- Support fastai>=2.7
Download files
Source Distribution
fastmtl-1.2.1.tar.gz (10.0 kB)
File details
Details for the file fastmtl-1.2.1.tar.gz.
File metadata
- Download URL: fastmtl-1.2.1.tar.gz
- Upload date:
- Size: 10.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.10.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | f3e100a9933f9b268845e8c26b3ec652133dcb842a2d2172472067d4c1bc2567
MD5 | bbf1b443bdf76b77eed179a1a61e3414
BLAKE2b-256 | 4fa80b71003543840b5e96359c5678070e2445583ca759eba009013b9b320d5c
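To check a downloaded copy of the source distribution against the SHA256 digest listed above, a short standard-library sketch is enough. It assumes the archive has already been saved locally (for example with `pip download fastmtl==1.2.1 --no-deps --no-binary :all:`); the helper names `sha256_of` and `verify` are illustrative.

```python
import hashlib

# SHA256 digest listed above for fastmtl-1.2.1.tar.gz
EXPECTED_SHA256 = "f3e100a9933f9b268845e8c26b3ec652133dcb842a2d2172472067d4c1bc2567"


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

Calling `verify("fastmtl-1.2.1.tar.gz")` should return `True` for an untampered download of this release.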