
FusionSent: A Fusion-Based Multi-Task Sentence Embedding Model

Welcome to the FusionSent repository. FusionSent is an efficient few-shot learning model designed for multi-label classification of scientific documents with many classes.

Training Process of FusionSent

Figure 1: The training process of FusionSent comprises three steps:

  1. Fine-tune two different sentence embedding models from the same Pre-trained Language Model (PLM), with parameters θ₁, θ₂ respectively.
    • θ₁ is fine-tuned on pairs of training sentences using cosine similarity loss, and θ₂ is fine-tuned on pairs of training sentences and their corresponding label texts, using contrastive loss.
    • Label texts can consist of simple label/class names or more extensive texts that semantically describe the meaning of a label/class.
  2. Merge parameter sets θ₁, θ₂ into θ₃ using Spherical Linear Interpolation (SLERP).
  3. Freeze θ₃ and use it to embed the training sentences; these embeddings then serve as input features for training a classification head.
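Step 2 can be illustrated with a minimal sketch of SLERP applied to two parameter sets. It is not FusionSent's implementation. The dict-of-flat-lists representation, the helper names `slerp` and `merge_state_dicts`, interpolating each parameter tensor separately, and the default factor `t = 0.5` are all assumptions made for illustration; the source does not say how the interpolation factor is chosen.

```python
import math
from typing import Dict, List

# Hypothetical representation: parameter name -> flat list of floats (stands in for real tensors).
StateDict = Dict[str, List[float]]


def _lerp(v0: List[float], v1: List[float], t: float) -> List[float]:
    return [(1 - t) * a + t * b for a, b in zip(v0, v1)]


def slerp(v0: List[float], v1: List[float], t: float, eps: float = 1e-8) -> List[float]:
    """Spherical linear interpolation between two flat parameter vectors."""
    norm0 = math.sqrt(sum(x * x for x in v0))
    norm1 = math.sqrt(sum(x * x for x in v1))
    if norm0 < eps or norm1 < eps:
        # The angle is undefined for a zero vector: fall back to linear interpolation.
        return _lerp(v0, v1, t)
    cos_omega = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    cos_omega = max(-1.0, min(1.0, cos_omega))  # guard against rounding outside [-1, 1]
    omega = math.acos(cos_omega)
    sin_omega = math.sin(omega)
    if sin_omega < eps:
        # (Nearly) parallel vectors: the SLERP weights would divide by ~0.
        return _lerp(v0, v1, t)
    w0 = math.sin((1 - t) * omega) / sin_omega
    w1 = math.sin(t * omega) / sin_omega
    return [w0 * a + w1 * b for a, b in zip(v0, v1)]


def merge_state_dicts(theta1: StateDict, theta2: StateDict, t: float = 0.5) -> StateDict:
    """Hypothetical theta3 = SLERP(theta1, theta2, t), applied tensor by tensor."""
    if theta1.keys() != theta2.keys():
        raise ValueError("both models must come from the same PLM (identical parameter names)")
    return {name: slerp(theta1[name], theta2[name], t) for name in theta1}
```

For two orthogonal unit vectors, `merge_state_dicts({"w": [1.0, 0.0]}, {"w": [0.0, 1.0]})` yields roughly `[0.7071, 0.7071]`, which keeps unit norm, whereas plain linear interpolation would give `[0.5, 0.5]`. Preserving the norm is the usual motivation for SLERP over simple averaging; the linear fallback only covers degenerate (nearly parallel or zero) vectors.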

By fine-tuning sentence embedding models with contrastive learning, FusionSent achieves high performance even with limited labeled data. The model first trains two distinct sub-models: one using regular contrastive learning on pairs of training sentences ('setfit'), and another using label embeddings learned from pairs of sentences and class descriptions ('label_embedding'). The two sub-models are then fused via (spherical) linear interpolation into the final FusionSent model, which performs robustly across diverse classification tasks. For detailed insights into the model and its performance, please refer to our published paper.
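The two kinds of training pairs, and losses of the kind named in step 1, can be sketched in plain Python. Everything here is hypothetical: the helper names, the 1.0/0.0 targets, the rule that two multi-label sentences form a positive pair when they share at least one class, and the loss formulations (squared error on cosine similarity; a margin-based contrastive loss on cosine distance with margin 0.5, similar in form to losses found in the sentence-transformers library). FusionSent's actual pair sampling (e.g. the `undersampling` strategy in the usage example) and loss settings may differ.

```python
import math
from itertools import combinations
from typing import List, Sequence, Tuple

Pair = Tuple[str, str, float]  # (first text, second text, similarity target)


def setfit_pairs(texts: Sequence[str], labels: Sequence[Sequence[int]]) -> List[Pair]:
    """Sentence-sentence pairs: target 1.0 if the two texts share at least one class (assumption)."""
    pairs = []
    for i, j in combinations(range(len(texts)), 2):
        shared = any(a and b for a, b in zip(labels[i], labels[j]))
        pairs.append((texts[i], texts[j], 1.0 if shared else 0.0))
    return pairs


def label_embedding_pairs(texts: Sequence[str], labels: Sequence[Sequence[int]],
                          label_texts: Sequence[str]) -> List[Pair]:
    """Sentence-label-text pairs: positive for assigned classes, negative for all others."""
    pairs = []
    for text, label_vec in zip(texts, labels):
        for label_text, assigned in zip(label_texts, label_vec):
            pairs.append((text, label_text, 1.0 if assigned else 0.0))
    return pairs


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def cosine_similarity_loss(u, v, target: float) -> float:
    """Squared error between the embeddings' cosine similarity and the pair target."""
    return (cosine(u, v) - target) ** 2


def contrastive_loss(u, v, target: float, margin: float = 0.5) -> float:
    """Pull positive pairs together; push negatives until their cosine distance exceeds the margin."""
    d = 1.0 - cosine(u, v)
    return 0.5 * (target * d ** 2 + (1 - target) * max(0.0, margin - d) ** 2)
```

In a real model, `u` and `v` would be the embeddings of the two texts in a pair; the training loop would minimize the corresponding loss over all pairs to obtain θ₁ and θ₂.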

Overview

FusionSent is integrated with the Hugging Face Hub and provides two main classes:

  • FusionSentModel: This class encapsulates the dual fine-tuning of the two sentence embedding models ('setfit' and 'label_embedding') and their fusion into a single model ('fusionsent'). It is the core model class for embedding sentences and performing classification tasks.
  • FusionTrainer: Responsible for loading, cleaning, and preparing datasets for training and evaluation.
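Step 3 of the training process, in which the frozen fused encoder supplies features for a classification head, can be illustrated with a small multi-label head. This sketch assumes a one-vs-rest logistic regression trained by batch gradient descent in pure Python; the class name `OneVsRestHead`, the learning rate, the epoch count, and the 0.5 threshold are illustrative choices, and the toy vectors stand in for real sentence embeddings. The source only states that a classification head is trained, not which kind.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


class OneVsRestHead:
    """Hypothetical multi-label head: one logistic regression per class on frozen embeddings."""

    def __init__(self, num_features: int, num_classes: int):
        self.weights = [[0.0] * num_features for _ in range(num_classes)]
        self.biases = [0.0] * num_classes

    def _score(self, c: int, x: Sequence[float]) -> float:
        return sum(w * xk for w, xk in zip(self.weights[c], x)) + self.biases[c]

    def fit(self, X: Sequence[Sequence[float]], Y: Sequence[Sequence[int]],
            lr: float = 0.5, epochs: int = 200) -> "OneVsRestHead":
        n = len(X)
        for _ in range(epochs):
            for c in range(len(self.biases)):
                grad_w = [0.0] * len(self.weights[c])
                grad_b = 0.0
                for x, y in zip(X, Y):
                    # Gradient of binary cross-entropy w.r.t. the logit: sigmoid(score) - target.
                    err = sigmoid(self._score(c, x)) - y[c]
                    for k, xk in enumerate(x):
                        grad_w[k] += err * xk
                    grad_b += err
                self.weights[c] = [w - lr * g / n for w, g in zip(self.weights[c], grad_w)]
                self.biases[c] -= lr * grad_b / n
        return self

    def predict(self, X: Sequence[Sequence[float]], threshold: float = 0.5) -> List[List[int]]:
        """Return one multi-hot vector per input embedding."""
        return [[1 if sigmoid(self._score(c, x)) >= threshold else 0
                 for c in range(len(self.biases))] for x in X]
```

Because each class gets an independent binary classifier, a sentence can receive several labels at once, which matches the multi-label setting FusionSent targets.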

Installation

To install the fusionsent package, use pip:

```shell
pip install fusionsent
```

Usage Example

```python
from fusionsent.training_args import TrainingArguments
from fusionsent.modeling import FusionSentModel
from fusionsent.trainer import Trainer
from datasets import Dataset

# Example datasets with sentences belonging to the classes ["Computer Science", "Physics", "Philosophy"].
# "label" is a multi-hot vector over these classes; "label_description" holds the label text(s) of each sentence.
train_dataset = Dataset.from_dict({
    "text": [
        "Algorithms and data structures form the foundation of computer science.",
        "Quantum mechanics explores the behavior of particles at subatomic scales.",
        "The study of ethics is central to philosophical inquiry."
    ],
    "label": [
        [1, 0, 0],  # Computer Science
        [0, 1, 0],  # Physics
        [0, 0, 1]   # Philosophy
    ],
    "label_description": [
        ["Computer Science"],
        ["Physics"],
        ["Philosophy"]
    ]
})

eval_dataset = Dataset.from_dict({
    "text": [
        "Artificial intelligence is transforming the landscape of technology.",
        "General relativity revolutionized our understanding of gravity.",
        "Epistemology questions the nature and limits of human knowledge."
    ],
    "label": [
        [1, 0, 0],  # Computer Science
        [0, 1, 0],  # Physics
        [0, 0, 1]   # Philosophy
    ],
    "label_description": [
        ["Computer Science"],
        ["Physics"],
        ["Philosophy"]
    ]
})

# Load the pre-trained language model (PLM) from the Hugging Face Hub.
model_id = "malteos/scincl"
model = FusionSentModel._from_pretrained(model_id=model_id)

# Set training arguments.
training_args = TrainingArguments(
    batch_sizes=(16, 1),
    num_epochs=(1, 3),
    sampling_strategies="undersampling"
)

# Initialize the trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

# Train the model.
trainer.train()

# Evaluate the model on the held-out examples and show the scores.
eval_scores = trainer.evaluate(
    x_eval=eval_dataset["text"],
    y_eval=eval_dataset["label"]
)
print(eval_scores)

# Perform inference on unseen sentences.
texts = [
    "Computational complexity helps us understand the efficiency of algorithms.",
    "Thermodynamics studies the relationships between heat, work, and energy.",
    "Existentialism explores the freedom and responsibility of individual existence."
]
predictions = model.predict(texts)
print(predictions)
```
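The example datasets follow a visible convention: `label` is a multi-hot vector over the three classes, and `label_description` lists the label text(s) for each sentence. A small, hypothetical pre-training check of that convention is sketched below. The function name and the rule that each row carries exactly one description per positive class are assumptions inferred from the example rows, not documented requirements of FusionSent.

```python
from typing import Dict, List, Sequence


def check_multilabel_rows(data: Dict[str, Sequence], num_classes: int) -> List[str]:
    """Hypothetical pre-training check; returns a list of problems (empty means consistent)."""
    texts, labels, descriptions = data["text"], data["label"], data["label_description"]
    if not (len(texts) == len(labels) == len(descriptions)):
        return ["columns 'text', 'label' and 'label_description' differ in length"]
    problems = []
    for i, (label_vec, desc) in enumerate(zip(labels, descriptions)):
        if len(label_vec) != num_classes:
            problems.append(f"row {i}: expected {num_classes} label entries, got {len(label_vec)}")
            continue
        if any(v not in (0, 1) for v in label_vec):
            problems.append(f"row {i}: 'label' must be multi-hot (only 0 and 1)")
            continue
        # Assumed convention: one description per positive class, as in the example rows.
        if len(desc) != sum(label_vec):
            problems.append(f"row {i}: {sum(label_vec)} positive labels but {len(desc)} descriptions")
    return problems
```

It operates on the plain dictionaries passed to `Dataset.from_dict`, so it can run before the datasets are built; applied to the dictionary used for `train_dataset` with `num_classes=3`, it returns an empty list.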

For a more elaborate example, please refer to the Jupyter notebook of a Description-Embedding Experiment.

Citation

If you use FusionSent in your research, please cite the following paper:

@article{...,
  title={...},
  author={...},
  journal={...},
  year={...}
}

For additional details and advanced configurations, please refer to the original paper linked at the beginning of this document.

Download files

Source Distribution

fusionsent-0.0.8.tar.gz (24.8 kB)

Built Distribution

fusionsent-0.0.8-py3-none-any.whl (23.7 kB)

File details

Details for the file fusionsent-0.0.8.tar.gz.

File metadata

  • Download URL: fusionsent-0.0.8.tar.gz
  • Upload date:
  • Size: 24.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.20

File hashes

Hashes for fusionsent-0.0.8.tar.gz
Algorithm Hash digest
SHA256 353e2cc2d3d640b99c2730e9f3a434655c3588762ce5efc50db1d64a2c201401
MD5 545f9b07e737f66ec06d6010b54eaf90
BLAKE2b-256 43c293156505e591e9800ca2e3ff2d7da499bae6956b427dfa8d118eeb192af9


File details

Details for the file fusionsent-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: fusionsent-0.0.8-py3-none-any.whl
  • Upload date:
  • Size: 23.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.20

File hashes

Hashes for fusionsent-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 eb6a3573c58eebd7dca4f13392cb142ea7feb4bc8ad411b06cb5548e5eb223e1
MD5 b2098320c33d35208f7e6e871ce82451
BLAKE2b-256 fb0d7c5d30bde32e7ec0f9dd9332399293d08a2d8e1616b24bb95349ac07c686
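The SHA256 digests listed above can be checked after downloading a file. Below is a minimal sketch using Python's standard `hashlib` module; the script layout and the `verify_download` helper are illustrative, and the expected digests are copied from the hash tables above.

```python
import hashlib
import sys
from pathlib import Path

# SHA256 digests copied from the file hashes listed on this page.
EXPECTED_SHA256 = {
    "fusionsent-0.0.8.tar.gz": "353e2cc2d3d640b99c2730e9f3a434655c3588762ce5efc50db1d64a2c201401",
    "fusionsent-0.0.8-py3-none-any.whl": "eb6a3573c58eebd7dca4f13392cb142ea7feb4bc8ad411b06cb5548e5eb223e1",
}


def sha256_of(path: Path, chunk_size: int = 65536) -> str:
    """Hash the file in chunks so large files never need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Path, expected: str) -> bool:
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    for arg in sys.argv[1:]:
        p = Path(arg)
        expected = EXPECTED_SHA256.get(p.name)
        if expected is None:
            print(f"{p.name}: no expected digest listed")
        else:
            print(f"{p.name}: {'OK' if verify_download(p, expected) else 'MISMATCH'}")
```

Run it as, for example, `python verify_fusionsent.py fusionsent-0.0.8.tar.gz` (the script name is hypothetical). pip can enforce the same check at install time in hash-checking mode: a requirements line such as `fusionsent==0.0.8 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt` (in this mode every requirement, including dependencies, must be pinned with a hash).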
