
FusionSent: A Fusion-Based Multi-Task Sentence Embedding Model

Project description


Welcome to the FusionSent repository. FusionSent is an efficient few-shot learning model designed for multi-label classification of scientific documents with many classes.

Training Process of FusionSent

Figure 1: The training process of FusionSent comprises three steps:

  1. Fine-tune two different sentence embedding models from the same Pre-trained Language Model (PLM), with parameters θ₁, θ₂ respectively.
    • θ₁ is fine-tuned on pairs of training sentences using cosine similarity loss, and θ₂ is fine-tuned on pairs of training sentences and their corresponding label texts, using contrastive loss (an illustrative sketch of these two objectives follows this list).
    • Label texts can consist of simple label/class names or more extensive texts that semantically describe the meaning of a label/class.
  2. Merge parameter sets θ₁, θ₂ into θ₃ using Spherical Linear Interpolation (SLERP).
  3. Freeze θ₃ to embed the training sentences, which are then used as input features to train a classification head.
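
Step 1 above is ordinary sentence-embedding fine-tuning. The following is an illustrative sketch only (not FusionSent's internal code) of what the two objectives could look like with the sentence-transformers library, using the same PLM as in the usage example below; the example pairs and hyperparameters are placeholder assumptions:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

plm_id = "malteos/scincl"  # same PLM as in the usage example below

# θ₁ ('setfit' sub-model): pairs of training sentences with a similarity target,
# trained with cosine similarity loss (1.0 = same class, 0.0 = different classes).
theta_1 = SentenceTransformer(plm_id)
sentence_pairs = [
    InputExample(texts=["Quantum mechanics explores the behavior of particles at subatomic scales.",
                        "General relativity revolutionized our understanding of gravity."],
                 label=1.0),
]
loader_1 = DataLoader(sentence_pairs, batch_size=16, shuffle=True)
theta_1.fit(train_objectives=[(loader_1, losses.CosineSimilarityLoss(theta_1))], epochs=1)

# θ₂ ('label_embedding' sub-model): pairs of training sentences and label texts,
# trained with contrastive loss (1 = sentence belongs to that class, 0 = it does not).
theta_2 = SentenceTransformer(plm_id)
label_pairs = [
    InputExample(texts=["Quantum mechanics explores the behavior of particles at subatomic scales.",
                        "Physics"],
                 label=1),
]
loader_2 = DataLoader(label_pairs, batch_size=16, shuffle=True)
theta_2.fit(train_objectives=[(loader_2, losses.ContrastiveLoss(theta_2))], epochs=1)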

By fine-tuning sentence embedding models with contrastive learning, FusionSent achieves high performance even with limited labeled data. It first trains two distinct sub-models: one using regular contrastive learning on item pairs ('setfit'), and another using label embeddings trained on class-description pairs ('label_embedding'). These two models are then fused via spherical linear interpolation to create the robust FusionSent model, which excels across diverse classification tasks. For detailed insights into the model and its performance, please refer to our published paper.
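
The merging in step 2 is plain parameter interpolation. Below is a minimal, self-contained sketch of SLERP over two flattened parameter vectors; the function and the midpoint t=0.5 are illustrative assumptions, since the package performs the merge internally:

import torch

def slerp(theta_1: torch.Tensor, theta_2: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """Spherical linear interpolation between two flattened parameter vectors."""
    # Angle between the two (normalized) parameter vectors.
    omega = torch.arccos(torch.clamp(
        torch.dot(theta_1 / theta_1.norm(), theta_2 / theta_2.norm()), -1.0, 1.0))
    so = torch.sin(omega)
    if so.abs() < 1e-8:
        # (Nearly) parallel vectors: fall back to ordinary linear interpolation.
        return (1.0 - t) * theta_1 + t * theta_2
    return (torch.sin((1.0 - t) * omega) / so) * theta_1 + (torch.sin(t * omega) / so) * theta_2

# Example: merge two parameter sets into θ₃ at the midpoint.
theta_3 = slerp(torch.randn(10), torch.randn(10), t=0.5)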

Overview

FusionSent is integrated with the Hugging Face Hub and provides two main classes:

  • FusionSentModel: This class encapsulates the dual fine-tuning of the two sentence embedding models ('setfit' and 'label_embedding') and their fusion into a single model ('fusionsent'). It is the core model class for embedding sentences and performing classification tasks.
  • Trainer: Manages training and evaluation; it takes the model, the training arguments, and the training and evaluation datasets as input.

Installation

To install the fusionsent package, use pip:

pip install fusionsent
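
As a quick sanity check, the installation should expose the two main classes used in the example below (module paths taken from the usage example):

# Verify the installation by importing the two main classes.
from fusionsent.modeling import FusionSentModel
from fusionsent.trainer import Trainer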

Usage Example

from fusionsent.training_args import TrainingArguments
from fusionsent.modeling import FusionSentModel
from fusionsent.trainer import Trainer
from datasets import Dataset

# Example dataset objects with sentences belonging to classes: ["Computer Science", "Physics", "Philosophy"]
train_dataset = Dataset.from_dict({
    "text": [
        "Algorithms and data structures form the foundation of computer science.",
        "Quantum mechanics explores the behavior of particles at subatomic scales.",
        "The study of ethics is central to philosophical inquiry."
    ],
    "label": [
        [1, 0, 0],  # Computer Science
        [0, 1, 0],  # Physics
        [0, 0, 1]   # Philosophy
    ],
    "label_description": [
        ["Computer Science"],
        ["Physics"],
        ["Philosophy"]
    ]
})

eval_dataset = Dataset.from_dict({
    "text": [
        "Artificial intelligence is transforming the landscape of technology.",
        "General relativity revolutionized our understanding of gravity.",
        "Epistemology questions the nature and limits of human knowledge."
    ],
    "label": [
        [1, 0, 0],  # Computer Science
        [0, 1, 0],  # Physics
        [0, 0, 1]   # Philosophy
    ],
    "label_description": [
        ["Computer Science"],
        ["Physics"],
        ["Philosophy"]
    ]
})

# Load the model.
model_id = "malteos/scincl"
model = FusionSentModel._from_pretrained(model_id=model_id)

# Set training arguments.
training_args = TrainingArguments(
    batch_sizes=(16, 1),
    num_epochs=(1, 3),
    sampling_strategies="undersampling"
)

# Initialize trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

# Train the model.
trainer.train()

# Evaluate the model.
eval_scores = trainer.evaluate(
    x_eval=eval_dataset["text"],
    y_eval=eval_dataset["label"]
)

# Perform inference.
texts = [
    "Computational complexity helps us understand the efficiency of algorithms.",
    "Thermodynamics studies the relationships between heat, work, and energy.",
    "Existentialism explores the freedom and responsibility of individual existence."
]
predictions = model.predict(texts)
print(predictions)
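
The labels in the training data above are multi-hot vectors. If predict() returns the same multi-hot format, one vector per input text (verify this in your environment), a small hypothetical helper can map predictions back to readable class names:

# Hypothetical helper: maps multi-hot prediction vectors back to class names,
# assuming the class order matches the label columns used during training.
class_names = ["Computer Science", "Physics", "Philosophy"]

def decode_predictions(predictions, class_names):
    return [
        [name for name, flag in zip(class_names, row) if flag]
        for row in predictions
    ]

print(decode_predictions(predictions, class_names))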

For a more elaborate example, please refer to the Jupyter notebook of a Description-Embedding Experiment.

Citation

If you use FusionSent in your research, please cite the following paper:

@article{...,
  title={...},
  author={...},
  journal={...},
  year={...}
}

For additional details and advanced configurations, please refer to the original paper referenced above.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

fusionsent-0.0.7.tar.gz (24.8 kB, see file details below)

Uploaded Source

Built Distribution

fusionsent-0.0.7-py3-none-any.whl (23.7 kB, see file details below)

Uploaded Python 3

File details

Details for the file fusionsent-0.0.7.tar.gz.

File metadata

  • Download URL: fusionsent-0.0.7.tar.gz
  • Upload date:
  • Size: 24.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.20

File hashes

Hashes for fusionsent-0.0.7.tar.gz:
  • SHA256: 911ed78d45cd2fb7912c1c6e9b2e78b271004ffd1bdd14a81d70d9d4c5a99f74
  • MD5: a042de7a14457e5484e828ef24a359b2
  • BLAKE2b-256: a5c8ec7e607736b483a07f35392695e1d92eeda99472edc613283148009dae02

See more details on using hashes here.

File details

Details for the file fusionsent-0.0.7-py3-none-any.whl.

File metadata

  • Download URL: fusionsent-0.0.7-py3-none-any.whl
  • Upload date:
  • Size: 23.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.9.20

File hashes

Hashes for fusionsent-0.0.7-py3-none-any.whl:
  • SHA256: 88be62f485ed283b71079b87021ac8500639206edac54f844e643376829c6a7c
  • MD5: 0278e3a9029b7da77cedcc8eb0dfc71a
  • BLAKE2b-256: 233ca739b16a3ef149f294a503107422e404ac1f8c57909d964b6ba6f227103b

See more details on using hashes here.
