A flexible multi-task learning library for natural language processing with support for hierarchical task structures.

multitask

A flexible multi-task learning library for NLP supporting mixed task types (binary/multiclass/multilabel/regression), multiple encoders, and distributed training.

Features

  • Mixed task types in single model (binary, multiclass, multilabel, regression)
  • Multiple encoders: TensorFlow Hub, HuggingFace, pre-computed embeddings
  • Automatic class imbalance and task weighting
  • Automatic threshold optimization for binary/multilabel tasks
  • Cross-platform distributed training with automatic GPU detection
  • Model persistence with threshold saving
  • Comprehensive type hints

Installation

From PyPI

pip install multitask-py

This installs the core library and its dependencies:

  • TensorFlow (+ TF Hub, TF Text, tf-models-official)
  • NumPy
  • scikit-learn
  • pandas
  • PyYAML

To install with all optional backends and features:

pip install 'multitask-py[huggingface,openai,cohere,voyageai,regression]'

Or install extras individually as needed:

Extra       | What it adds
----------- | ------------------------------------
huggingface | transformers, torch, huggingface-hub
openai      | OpenAI embeddings API client
cohere      | Cohere embeddings API client
voyageai    | Voyage AI embeddings API client
regression  | scipy (Pearson r in evaluation)
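To see which optional extras are already usable in the current environment, a quick check with the standard library's `importlib.util.find_spec` can help. This is a minimal sketch and not part of multitask-py. The mapping from extras to import names is an assumption based on the table above (for example, the `huggingface-hub` distribution is imported as `huggingface_hub`).

```python
import importlib.util

# Assumed import names for the packages each optional extra installs
# (derived from the extras table; not read from the package metadata).
EXTRA_MODULES = {
    "huggingface": ["transformers", "torch", "huggingface_hub"],
    "openai": ["openai"],
    "cohere": ["cohere"],
    "voyageai": ["voyageai"],
    "regression": ["scipy"],
}


def available_extras(extra_modules=EXTRA_MODULES):
    """Return {extra_name: bool}, True only if every module of that extra is importable."""
    return {
        extra: all(importlib.util.find_spec(mod) is not None for mod in modules)
        for extra, modules in extra_modules.items()
    }


if __name__ == "__main__":
    for extra, ok in available_extras().items():
        status = "installed" if ok else f"missing (pip install 'multitask-py[{extra}]')"
        print(f"{extra:12s} {status}")
```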

From source

git clone https://github.com/yourusername/multitask.git
cd multitask
pip install .                  # core only
pip install '.[huggingface]'   # core + HuggingFace support

GPU support

  • Apple Silicon: pip install tensorflow-macos tensorflow-metal
  • NVIDIA: pip install 'tensorflow[and-cuda]'

Developers

The Conda environment pins every dependency for reproducibility:

conda env create -f environment.yml
conda activate multitask
pip install -e .

This includes all optional extras, dev tools (pytest, black, mypy), and system libraries (MPI, BLAS, HDF5) that pip cannot provide.

Package management overview

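File            | Audience   | Purpose
--------------- | ---------- | ----------------------------------------------------------------
pyproject.toml  | Users      | Declares runtime dependencies (loose bounds) and optional extras
environment.yml | Developers | Pins exact versions for a reproducible conda environment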

Quick Start

from multitask import ModelConfig, TrainingConfig, MultiTaskModel, Trainer, EncoderConfig
from multitask.config import EncoderIntegration, EncoderInputType

# Configure universal sentence encoder
encoder_config = EncoderConfig(
    encoder_integration=EncoderIntegration.TFHUB,
    encoder_input_type=EncoderInputType.RAW_STRING,
    encoder_identifier='https://tfhub.dev/google/universal-sentence-encoder/4',
)

# Three tasks → task_structure must have three outputs (see docs for hierarchical layouts)
config = ModelConfig(
    task_structure=[[3]],
    task_names=['sentiment', 'toxicity', 'emotions'],
    task_types=['multiclass', 'binary', 'multilabel'],
    num_classes_per_task=[3, 2, 5],  # binary tasks use num_classes=2; multilabel uses the number of labels
    encoder_config=encoder_config,
)

model = MultiTaskModel(config)

# Train: build tf.data.Dataset batches of (inputs, labels_dict); see docs/DATA_FORMAT.md
train_dataset = ...  # e.g. tf.data.Dataset.from_tensor_slices(...) then .batch(...)
val_dataset = ...
trainer = Trainer(model, TrainingConfig())
history, thresholds = trainer.fit(train_dataset, val_dataset, num_tasks=3)

# Predict: pass tensors or numpy matching the model input (e.g. text or embeddings)
x_test = ...  # test inputs in the same form as the training inputs
predictions = model.predict(x_test)  # dict[str, Tensor], one entry per task

Data Format

Text Input

A DataFrame with a text column and one column per task:

import pandas as pd

data = pd.DataFrame({
    'text': ['great product', 'terrible', ...],
    'sentiment': [2, 0, ...],              # 0-2 for 3-class
    'toxicity': [0, 1, ...],               # 0-1 for binary
    'emotions': ['joy anger', 'fear', ...],  # Space-separated for multilabel
})
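Multilabel columns hold space-separated label names. One common way to turn them into training targets is a multi-hot matrix, sketched below with NumPy. The helper `multi_hot` and the five-name `emotion_vocab` are hypothetical illustrations; how multitask-py itself encodes multilabel targets is described in DATA_FORMAT.md and may differ.

```python
import numpy as np


def multi_hot(label_strings, vocabulary):
    """Hypothetical helper: space-separated label strings -> (n_rows, n_labels) 0/1 matrix."""
    index = {label: i for i, label in enumerate(vocabulary)}
    out = np.zeros((len(label_strings), len(vocabulary)), dtype=np.float32)
    for row, labels in enumerate(label_strings):
        for label in labels.split():
            out[row, index[label]] = 1.0  # a KeyError flags a label outside the vocabulary
    return out


# Assumed label vocabulary; its order fixes the column order of the matrix.
emotion_vocab = ["joy", "anger", "fear", "sadness", "surprise"]
print(multi_hot(["joy anger", "fear"], emotion_vocab))
```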

Pre-computed Embeddings

Add an embedding column containing one NumPy array per row:

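import numpy as np
from multitask import EncoderConfig, ModelConfig
from multitask.config import EncoderInputType

# Placeholder vectors: replace with your real 768-dimensional embeddings
data['embedding'] = [np.random.randn(768) for _ in range(len(data))]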

encoder_config = EncoderConfig(
    encoder_input_type=EncoderInputType.PRECOMPUTED,
    embedding_dim=768
)

# PRECOMPUTED path: no encoder_identifier is required
# (inputs are embedding vectors provided directly by your dataset)

config = ModelConfig(
    task_structure=[[1]],
    task_names=['sentiment'],
    task_types=['multiclass'],
    num_classes_per_task=[3],
    encoder_config=encoder_config
)
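Before building a dataset from pre-computed embeddings, it can help to stack the `embedding` column into a single 2-D `float32` array and confirm that its width matches `embedding_dim`. A minimal sketch, assuming the model expects a dense `(n_rows, embedding_dim)` array; `embeddings_to_matrix` is a hypothetical helper, not a library function.

```python
import numpy as np
import pandas as pd

EMBEDDING_DIM = 768  # must match EncoderConfig(embedding_dim=...)


def embeddings_to_matrix(column, embedding_dim=EMBEDDING_DIM):
    """Stack a column of 1-D vectors into an (n_rows, embedding_dim) float32 matrix."""
    matrix = np.stack([np.asarray(v, dtype=np.float32) for v in column])
    if matrix.ndim != 2 or matrix.shape[1] != embedding_dim:
        raise ValueError(f"expected shape (n, {embedding_dim}), got {matrix.shape}")
    return matrix


# Placeholder data: random vectors stand in for real embeddings.
demo = pd.DataFrame({"text": ["great product", "terrible"]})
demo["embedding"] = [np.random.randn(EMBEDDING_DIM) for _ in range(len(demo))]
X = embeddings_to_matrix(demo["embedding"])
print(X.shape)  # (2, 768)
```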

Missing Labels

Mark missing labels with a sentinel value; entries equal to it are excluded from the loss. Any float can serve as the marker. This example uses -1:

data = pd.DataFrame({
    'text': ['text1', 'text2', 'text3'],
    'task1': [0, 1, -1],     # task1 missing for text3
    'task2': [1, -1, 0],     # task2 missing for text2
})
# Optional: set a custom missing-label value globally with mask_value.
# The label columns must then use that value (here -999.0) instead of -1.
config = ModelConfig(
    task_structure=[[2]],
    task_names=['task1', 'task2'],
    task_types=['binary', 'multiclass'],
    num_classes_per_task=[2, 3],
    mask_value=-999.0,
)
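Excluding masked labels from the loss amounts to computing the loss only over rows whose label differs from the mask value. The NumPy sketch below shows the idea for a single binary task. It illustrates the general technique, not multitask-py's TensorFlow loss, and the helper name is hypothetical.

```python
import numpy as np

MASK_VALUE = -1.0  # must match the value used in the label columns


def masked_binary_cross_entropy(y_true, y_prob, mask_value=MASK_VALUE, eps=1e-7):
    """Mean binary cross-entropy over only the rows whose label is not the mask value."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1 - eps)
    keep = y_true != mask_value
    if not keep.any():
        return 0.0  # every label missing: this task contributes nothing
    y, p = y_true[keep], y_prob[keep]
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


# task1 labels from the DataFrame above; the third row is missing.
print(masked_binary_cross_entropy([0, 1, -1], [0.2, 0.9, 0.5]))
```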

See DATA_FORMAT.md for complete documentation.

Encoder Types

TensorFlow Hub

from multitask.config import EncoderIntegration, EncoderInputType

# Universal Sentence Encoder
encoder_config = EncoderConfig(
    encoder_identifier="https://tfhub.dev/google/universal-sentence-encoder/4",
    encoder_input_type=EncoderInputType.RAW_STRING,
    encoder_integration=EncoderIntegration.TFHUB
)

config = ModelConfig(
    task_structure=[[1]],
    task_names=['sentiment'],
    task_types=['multiclass'],
    num_classes_per_task=[3],
    encoder_config=encoder_config
)

HuggingFace

from multitask.config import EncoderIntegration, EncoderInputType

encoder_config = EncoderConfig(
    encoder_identifier="bert-base-uncased",
    encoder_input_type=EncoderInputType.HUGGINGFACE_TOKENS,
    encoder_integration=EncoderIntegration.HUGGINGFACE
)

config = ModelConfig(
    task_structure=[[1]],
    task_names=['sentiment'],
    task_types=['multiclass'],
    num_classes_per_task=[3],
    encoder_config=encoder_config
)

Pre-computed Embeddings

from multitask.config import EncoderInputType

encoder_config = EncoderConfig(
    embedding_dim=768,
    encoder_input_type=EncoderInputType.PRECOMPUTED,
)

config = ModelConfig(
    task_structure=[[1]],
    encoder_config=encoder_config
)

See ENCODER_SUPPORT.md for all options.

Training

Basic

training_config = TrainingConfig(batch_size=32, epochs=10, learning_rate=2e-5)
trainer = Trainer(model, training_config)
history, thresholds = trainer.fit(train_dataset, val_dataset, num_tasks=N)

Automatic Weight Computation

By default, the trainer computes:

  • Task weights: inversely proportional to the number of non-masked samples per task
  • Class weights: inversely proportional to class frequency
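The two default weighting rules can be sketched in NumPy as follows. Because the library's exact formulas are not given here, this rests on two assumptions: task weights are normalized to average 1, and class weights use the common "balanced" formula `n_samples / (n_classes * count_c)`. Both helpers and the -1 mask value are hypothetical.

```python
import numpy as np

MASK_VALUE = -1  # hypothetical missing-label marker


def task_weights(label_columns, mask_value=MASK_VALUE):
    """Weight each task by 1 / (number of non-masked labels), normalized to mean 1."""
    counts = np.array([np.sum(np.asarray(col) != mask_value) for col in label_columns], dtype=float)
    raw = 1.0 / np.maximum(counts, 1.0)  # guard against tasks with no labels at all
    return raw / raw.mean()


def class_weights(labels, mask_value=MASK_VALUE):
    """'Balanced' class weights n_samples / (n_classes * count_c), ignoring masked rows."""
    labels = np.asarray(labels)
    labels = labels[labels != mask_value]
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


task1 = [0, 0, 0, 1, -1]
task2 = [1, -1, -1, 0, 2]
print(task_weights([task1, task2]))  # task2 has fewer labels, so it gets the larger weight
print(class_weights(task1))          # the rare class 1 gets weight 2.0
```

The outputs mirror the shapes of the explicit overrides shown next (one dict per task for `class_weights`, one number per task for `task_weights`), which makes it easy to compare hand-set weights against these rules.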

Override with explicit weights:

trainer.fit(
    train_dataset, val_dataset, num_tasks=N,
    class_weights=[{0: 1.0, 1: 2.0}, ...],  # One dict per task
    task_weights=[1.0, 2.0, ...],           # One weight per task
)

Threshold Optimization

Find optimal thresholds for binary/multilabel tasks:

history, thresholds = trainer.fit(
    train_dataset, val_dataset, num_tasks=N,
    optimize_thresholds=True,  # Uses Youden's J statistic
)
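Youden's J statistic is J = TPR - FPR (equivalently, sensitivity + specificity - 1), and the chosen threshold is the score cut-off that maximizes it. Below is a standalone NumPy sketch for one binary task, or one label of a multilabel task. The candidate set (every observed score) and tie-breaking (the first maximum wins) are assumptions and may differ from the library's implementation.

```python
import numpy as np


def youden_threshold(y_true, y_score):
    """Return the score threshold that maximizes Youden's J = TPR - FPR."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    if y_true.all() or not y_true.any():
        raise ValueError("need both positive and negative examples")
    best_t, best_j = 0.5, -np.inf
    for t in np.unique(y_score):       # every observed score is a candidate cut-off
        pred = y_score >= t
        tpr = pred[y_true].mean()      # sensitivity
        fpr = pred[~y_true].mean()     # 1 - specificity
        j = tpr - fpr
        if j > best_j:
            best_t, best_j = float(t), j
    return best_t


print(youden_threshold([0, 0, 1, 1], [0.1, 0.2, 0.7, 0.9]))  # 0.7
```

For a multilabel task, the same function would be applied column by column, giving one threshold per label.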

Distributed Training

Automatic GPU detection and strategy selection:

from multitask import get_distribution_strategy, setup_gpu_memory_growth

setup_gpu_memory_growth()
strategy, info = get_distribution_strategy('auto', verbose=True)

with strategy.scope():
    model = MultiTaskModel(config)
    trainer = Trainer(model, training_config, strategy=strategy)

trainer.fit(train_dataset, val_dataset, num_tasks=N)

See DISTRIBUTED_TRAINING.md.

Checkpointing

trainer.fit(
    train_dataset, val_dataset, num_tasks=N,
    checkpoint_dir='checkpoints/exp1',
    checkpoint_monitor='val_loss',
)

Verbosity

trainer.fit(train_dataset, val_dataset, num_tasks=N, verbose=0)  # Silent
trainer.fit(train_dataset, val_dataset, num_tasks=N, verbose=1)  # Progress (default)
trainer.fit(train_dataset, val_dataset, num_tasks=N, verbose=2)  # One line per epoch

Model Persistence

Save

model.save('models/my_model', thresholds=thresholds)

Load

from multitask import load_model

model, thresholds = load_model('models/my_model')
predictions = model.predict(x_test)
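The saved thresholds are what turn predicted probabilities into 0/1 labels. The sketch below assumes `thresholds` is a dict keyed by task name, holding one float per binary task and one float per label for multilabel tasks. That structure is an assumption, so check MODEL_PERSISTENCE.md for the actual format. `apply_thresholds` is a hypothetical helper; tensors returned by `model.predict` can be converted with `np.asarray` first.

```python
import numpy as np


def apply_thresholds(probabilities, thresholds):
    """Convert per-task probabilities into 0/1 predictions using tuned cut-offs.

    Tasks without a threshold (e.g. multiclass or regression) are skipped.
    """
    return {
        task: (np.asarray(probs) >= np.asarray(thresholds[task])).astype(int)
        for task, probs in probabilities.items()
        if task in thresholds
    }


probs = {
    "toxicity": np.array([0.12, 0.81, 0.47]),                  # binary: one score per row
    "emotions": np.array([[0.9, 0.2, 0.6], [0.1, 0.7, 0.3]]),  # multilabel: one score per label
}
thr = {"toxicity": 0.45, "emotions": [0.5, 0.5, 0.4]}  # hypothetical threshold values
print(apply_thresholds(probs, thr))
```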

See MODEL_PERSISTENCE.md.

Configuration

ModelConfig

ModelConfig(
    task_structure=[[3]],            # Branching layout; sets number of outputs
    task_names=['a', 'b', 'c'],
    task_types=['binary', 'multiclass', 'multilabel'],
    num_classes_per_task=[2, 5, 4],
    shared_layer_sizes=[256, 128],   # Optional dense stack before branches
    default_layer_sizes=[128, 64],   # Per-branch dense stacks (if not using branch_layer_sizes)
    encoder_config=encoder_config,
)

dropout_rate is not a ModelConfig field; pass it to MultiTaskModel(...) instead:

model = MultiTaskModel(config, dropout_rate=training_config.dropout_rate)

If dropout_rate is not passed, it defaults to 0.1, matching the default of TrainingConfig.dropout_rate.

EncoderConfig

EncoderConfig(
    encoder_integration=EncoderIntegration.TFHUB,
    encoder_input_type=EncoderInputType.TFHUB_TOKENS,
    encoder_identifier='...',
    embedding_dim=768,                # Used with PRECOMPUTED inputs (no encoder)
)

TrainingConfig

TrainingConfig(
    batch_size=32,
    epochs=10,
    learning_rate=2e-5,
    weight_decay_rate=0.01,
    dropout_rate=0.1,
    early_stopping=True,
    early_stopping_patience=3,
)
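`early_stopping_patience=3` most likely follows the usual patience rule: training stops once the monitored validation loss has gone three consecutive epochs without improving on its best value. The pure-Python sketch below assumes Keras-style `EarlyStopping` semantics, which the library may not follow exactly; `early_stopping_epoch` is hypothetical and only shows at which epoch a run would stop.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Stops once the loss has failed to improve on its best value by more than
    `min_delta` for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0   # new best: reset the counter
        else:
            wait += 1              # no improvement this epoch
            if wait >= patience:
                return epoch
    return None


print(early_stopping_epoch([0.9, 0.7, 0.72, 0.71, 0.75, 0.6], patience=3))  # 4
```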

Troubleshooting

Import errors: install the missing package or the matching extra:

pip install tensorflow-hub             # ModuleNotFoundError: tensorflow_hub
pip install transformers               # ModuleNotFoundError: transformers
pip install 'multitask-py[huggingface]'   # ...or install the extra (includes transformers + torch)
pip install 'multitask-py[openai]'        # ModuleNotFoundError: openai
pip install 'multitask-py[regression]'    # ModuleNotFoundError: scipy

GPU not detected: check which GPUs TensorFlow can see:

import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))

Out of memory:

from multitask import setup_gpu_memory_growth
setup_gpu_memory_growth()
# Or reduce the batch size: TrainingConfig(batch_size=16)

NaN loss:

  • Check that class labels are integers in the valid range (0 to num_classes - 1), apart from the missing-label value
  • Reduce the learning rate, e.g. learning_rate=1e-5
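A quick pre-flight check for the first NaN-loss cause: every non-missing classification label should be a whole number in `[0, num_classes - 1]`. The helper below is a hypothetical NumPy sketch that assumes -1 marks missing labels; pass your own `mask_value` if you configured a different one.

```python
import numpy as np


def check_class_labels(labels, num_classes, mask_value=-1):
    """Raise ValueError if any non-missing label is not an integer in [0, num_classes - 1]."""
    labels = np.asarray(labels, dtype=float)
    present = labels[labels != mask_value]   # drop missing labels (NaN is kept and caught below)
    if not np.all(np.isfinite(present)):
        raise ValueError("labels contain NaN or inf")
    if not np.all(present == np.round(present)):
        raise ValueError("labels must be whole numbers")
    bad = present[(present < 0) | (present > num_classes - 1)]
    if bad.size:
        raise ValueError(f"labels out of range [0, {num_classes - 1}]: {np.unique(bad)}")


check_class_labels([2, 0, 1, -1], num_classes=3)  # passes: -1 is treated as missing
```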

API Reference

Package exports (from multitask import ...): MultiTaskModel, Trainer, ModelConfig, TrainingConfig, EncoderConfig, load_model, and the distributed helpers shown above (get_distribution_strategy, setup_gpu_memory_growth).

Enums (import from multitask.config): EncoderIntegration, EncoderInputType.

Utils: load_model(), get_distribution_strategy(), setup_gpu_memory_growth(), check_distributed_compatibility(), print_device_info()

Examples

  • examples/minimal_pipeline_example.py: Shortest full pipeline (NumPy + tf.data, train, evaluate, save/load)
  • examples/hierarchical_pipeline_example.py: Two-level task_structure, same end-to-end flow
  • examples/dataframe_pipeline_example.py: Pandas DataFrame → tf.data → train → evaluate
  • examples/save_load_example.py: Model persistence
  • examples/distributed_training.py: Distributed training

Citation

@software{multitask,
  title = {multitask: A Flexible Multi-Task Learning Library for NLP},
  author = {Mehlhaff, Isaac D. and Morucci, Marco},
  year = {2025},
}

See LICENSE.

Pull requests welcome.

Download files

Download the file for your platform.

Source Distribution

multitask_py-1.0.0.tar.gz (60.3 kB)

Uploaded Source

Built Distribution

multitask_py-1.0.0-py3-none-any.whl (59.0 kB)

Uploaded Python 3

File details

Details for the file multitask_py-1.0.0.tar.gz.

File metadata

  • Download URL: multitask_py-1.0.0.tar.gz
  • Upload date:
  • Size: 60.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.3

File hashes

Hashes for multitask_py-1.0.0.tar.gz
Algorithm Hash digest
SHA256 2543f9482f85821137c7180f5c6e415e2a1aafc2aa52a34a1a1651225ad7760d
MD5 4694d6bda5d55a14c3181791955deefa
BLAKE2b-256 6456030dded97ab68df820bf3756e61e810fa46b8a162f675d34aa7af7fc80d8

File details

Details for the file multitask_py-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: multitask_py-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 59.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.3

File hashes

Hashes for multitask_py-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d86e8b9cacc9a48c512ecf3fb3ab0089a3ac33ab572a805b7be0f880bcf9897c
MD5 d80178a193c447aba05e9bd3fde840be
BLAKE2b-256 b122ee3c9075a6a6702b0e0ad342d986bdf878e1e466ec8fb6b4eb694872ec8e
