
Domain Adapted Language Modeling Toolkit

Manifesto

A great rift has emerged between general LLMs and the vector stores that provide them with contextual information. The unification of these systems is an important step in grounding AI systems in efficient, factual domains, where they are utilized not only for their generality, but for their specificity and uniqueness. To this end, we are excited to open source the Arcee Domain Adapted Language Model (DALM) toolkit for developers to build on top of our Arcee open source Domain Pretrained (DPT) LLMs. We believe that our efforts will help as we begin the next phase of language modeling, where organizations deeply tailor AI to operate according to their unique intellectual property and worldview.

Demo DALMs

Query example DALMs created by the Arcee Team.

DALM-Patent DALM-PubMed DALM-SEC DALM-Yours

Research Contents

This repository primarily contains code for fine-tuning a fully differentiable Retrieval Augmented Generation (RAG-end2end) architecture.

E2E

For the first time in the literature, we modified the initial RAG-end2end model (TACL paper, HuggingFace implementation) to work with decoder-only language models like Llama, Falcon, or GPT. We also incorporated the in-batch negatives concept alongside RAG's marginalization to make the entire process efficient.
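
To make the in-batch negatives idea concrete, here is a minimal PyTorch sketch (illustrative only, not the toolkit's actual training loop): each query's own passage is the positive, and the other passages in the same batch act as negatives, so the contrastive objective reduces to a cross-entropy over the batch similarity matrix.

import torch
import torch.nn.functional as F

def in_batch_negative_loss(query_emb: torch.Tensor,
                           passage_emb: torch.Tensor,
                           temperature: float = 0.05) -> torch.Tensor:
    # query_emb, passage_emb: (batch, dim); row i of each is a matching query-passage pair
    query_emb = F.normalize(query_emb, dim=-1)
    passage_emb = F.normalize(passage_emb, dim=-1)
    # (batch, batch) similarity matrix: diagonal entries are positives,
    # off-diagonal entries are the in-batch negatives
    logits = query_emb @ passage_emb.T / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)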

  • Inside the training folder, you'll find two training scripts: one for RAG-end2end and one for the retriever with contrastive learning.

  • All evaluations related to the Retriever and the Generator are located in the eval folder.

  • Additionally, data-processing and synthetic-data-generation scripts live inside the datasets folder.

Usage

To train and evaluate both the retriever-only model and the new rag-e2e model, follow the steps below.

Installation

You can install this package directly via pip install indomain.

Alternatively, for development or research, you can clone and install the repo locally:

git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install --upgrade -e .

This will install the DALM repo and all necessary dependencies.

Make sure things are installed correctly by running dalm version. On a non-Intel Mac, you may need to downgrade the transformers library: pip install transformers==4.30.

Data setup

tl;dr

You can run dalm qa-gen <path-to-dataset> to preprocess your dataset for training. See dalm qa-gen --help for more options.
If you do not have a dataset, you can start with ours:

# Note - our dataset already has queries and answers, so you don't actually need to run this.
# replace `toy_data_train.csv` with your dataset of titles and passages
dalm qa-gen dalm/datasets/toy_data_train.csv
  • Training and evaluation require a CSV file with two or three columns: Passage, Query (and Answer if running rag-e2e). You can use the script question_answer_generation.py to generate this CSV; a minimal sketch of the expected shape follows this list.
  • Retriever-only training uses only the passages and queries, whereas the rag-e2e training code uses all three columns.
  • In our experiments, we use BAAI/bge-large-en as the default retriever and meta-llama/Llama-2-7b-hf as the default generator. The code is compatible with any embedding model or autoregressive model available on the Hugging Face Hub (https://huggingface.co/models).
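
For illustration, here is a minimal sketch of the expected dataset shape. The rows, file name, and exact column names below are hypothetical examples; the eval commands later on this page show how to point the scripts at your own column names.

import pandas as pd

rows = [
    {
        "Passage": "A lithium-ion battery electrode comprising a silicon-carbon composite ...",
        "Query": "What material does the electrode use?",
        "Answer": "A silicon-carbon composite.",  # only needed for rag-e2e training
    },
]
pd.DataFrame(rows).to_csv("my_domain_train.csv", index=False)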

Training

You can use our training scripts directly if you'd like, or you can use the dalm CLI. The arguments for both are identical.

Train Retriever Only

Train BAAI/bge-large-en retriever with contrastive learning.

python dalm/training/retriever_only/train_retriever_only.py \
--dataset_path "./dalm/datasets/toy_data_train.csv" \
--retriever_name_or_path "BAAI/bge-large-en" \
--output_dir "retriever_only_checkpoints" \
--use_peft \
--with_tracking \
--report_to all \
--per_device_train_batch_size 150

or

dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \
--output-dir "retriever_only_checkpoints" \
--use-peft \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150

For all available arguments and options, see dalm train-retriever-only --help

Train Retriever and Generator Jointly (RAG-e2e)

Train Llama-2-7b generator jointly with the retriever model BAAI/bge-large-en.

python dalm/training/rag_e2e/train_rage2e.py \
  --dataset_path "./dalm/datasets/toy_data_train.csv" \
  --retriever_name_or_path "BAAI/bge-large-en" \
  --generator_name_or_path "meta-llama/Llama-2-7b-hf" \
  --output_dir "rag_e2e_checkpoints" \
  --with_tracking \
  --report_to all \
  --per_device_train_batch_size 150

or

dalm train-rag-e2e \
"./dalm/datasets/toy_data_train.csv" \
"BAAI/bge-large-en" \
"meta-llama/Llama-2-7b-hf" \
--output-dir "rag_e2e_checkpoints" \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150

For all available arguments and options, see dalm train-rag-e2e --help
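
Once training finishes, the saved retriever can be used to embed new passages. The snippet below is only a sketch, assuming the output directory contains a standard PEFT adapter on top of the base embedding model (which is how the eval commands below consume it via the retriever PEFT model path); see the eval scripts for the toolkit's own loading logic.

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

base = "BAAI/bge-large-en"
tokenizer = AutoTokenizer.from_pretrained(base)
# assumption: "retriever_only_checkpoints" holds a PEFT adapter saved by the training run
model = PeftModel.from_pretrained(AutoModel.from_pretrained(base), "retriever_only_checkpoints")
model.eval()

passages = ["A lithium-ion battery electrode comprising a silicon-carbon composite ..."]
batch = tokenizer(passages, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    # bge-style models are commonly used with the normalized [CLS] embedding
    embeddings = F.normalize(model(**batch).last_hidden_state[:, 0], dim=-1)
print(embeddings.shape)  # (num_passages, hidden_size)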

Evaluation

Here's a summary of evaluation results on a 200K-line test CSV of patent abstracts:

Type of Retriever                     Recall    Hit rate
Plain Retriever                       0.45984   0.45984
Retriever with contrastive learning   0.46037   0.46038
Retriever End2End                     0.73634   0.73634
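
Recall and hit rate are nearly identical above because, with a single gold passage per query (as in the passage-query pair format used here), recall@k and hit rate@k reduce to the same quantity: the fraction of queries whose gold passage appears in the top-k retrieved results. A minimal sketch of that computation (illustrative, not the eval script itself):

def hit_rate_at_k(retrieved_ids: list[list[str]], gold_ids: list[str], k: int = 10) -> float:
    # fraction of queries whose gold passage id appears among the top-k retrieved ids
    hits = sum(gold in retrieved[:k] for retrieved, gold in zip(retrieved_ids, gold_ids))
    return hits / len(gold_ids)

# Toy example: the single gold passage "p7" appears in the top 3 retrieved ids.
print(hit_rate_at_k([["p3", "p7", "p1"]], ["p7"], k=3))  # 1.0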

To run the retriever-only eval (make sure the checkpoints are in the project root):

 python dalm/eval/eval_retriever_only.py  \
 --dataset_path qa_pairs_test.csv \
 --retriever_name_or_path "BAAI/bge-large-en" \
 --passage_column_name Abstract \
 --query_column_name Question \
 --retriever_peft_model_path retriever_only_checkpoints

or

dalm eval-retriever qa_pairs_test.csv \
 --retriever-name-or-path "BAAI/bge-large-en" \
 --passage-column-name Abstract \
 --query-column-name Question \
 --retriever-peft-model-path retriever_only_checkpoints

See dalm eval-retriever --help for all available arguments

For the rag-e2e eval:

python dalm/eval/eval_rag.py  \
 --dataset_path qa_pairs_test_2.csv \
 --retriever_name_or_path "BAAI/bge-large-en" \
 --generator_name_or_path "meta-llama/Llama-2-7b-hf" \
 --passage_column_name Abstract \
 --query_column_name Question \
 --answer_column_name Answer \
 --evaluate_generator \
 --query_batch_size 5 \
 --retriever_peft_model_path rag_e2e_checkpoints/retriever \
 --generator_peft_model_path rag_e2e_checkpoints/generator

or

dalm eval-rag qa_pairs_test.csv \
 --retriever-name-or-path "BAAI/bge-large-en" \
 --generator-name-or-path "meta-llama/Llama-2-7b-hf" \
 --retriever-peft-model-path rag_e2e_checkpoints/retriever \
 --generator-peft-model-path rag_e2e_checkpoints/generator \
 --passage-column-name Abstract \
 --query-column-name Question \
 --answer-column-name Answer \
 --query-batch-size 5

See dalm eval-rag --help for all available arguments

Contributing

See CONTRIBUTING

