
Domain Adapted Language Modeling Toolkit

Manifesto

A great rift has emerged between general LLMs and the vector stores that provide them with contextual information. Unifying these systems is an important step toward grounding AI systems in efficient, factual domains, where they are valued not only for their generality but also for their specificity and uniqueness. To this end, we are excited to open source the Arcee Domain Adapted Language Model (DALM) toolkit, which developers can use to build on top of our open source Arcee Domain Pretrained (DPT) LLMs. We believe our efforts will help as we begin the next phase of language modeling, in which organizations deeply tailor AI to operate according to their unique intellectual property and worldview.

Demo DALMs

Query example DALMs created by the Arcee Team.

  • DALM-Patent
  • DALM-PubMed
  • DALM-SEC
  • DALM-Yours

Research Contents

This repository primarily contains code for fine-tuning a fully differentiable Retrieval Augmented Generation (RAG-end2end) architecture.

E2E

For the first time in the literature, we modified the original RAG-end2end model (TACL paper, Hugging Face implementation) to work with decoder-only language models such as Llama, Falcon, and GPT. We also combined in-batch negatives with RAG's marginalization over retrieved passages to make the entire process efficient.
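The combination of in-batch negatives and RAG marginalization can be sketched in plain Python. This is a minimal, hypothetical illustration, not DALM's implementation: it assumes dot-product retrieval scores, treats every passage in the batch as a candidate for every query (so the other queries' passages act as negatives), and assumes the generator's per-passage log-likelihoods log p(answer | query, passage) are already computed. The function names are made up for this example.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of scores."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def rag_marginal_nll(
    query_embs: List[List[float]],
    passage_embs: List[List[float]],
    gen_log_likelihood: List[List[float]],
) -> float:
    """Hypothetical sketch: RAG marginalization over in-batch passages.

    query_embs[i]:            embedding of query i
    passage_embs[j]:          embedding of passage j in the batch
    gen_log_likelihood[i][j]: log p(answer_i | query_i, passage_j)
    Returns the mean negative marginal log-likelihood over the batch.
    """
    total = 0.0
    for i, q in enumerate(query_embs):
        # Retriever scores: dot product of query i with every in-batch passage.
        scores = [sum(a * b for a, b in zip(q, p)) for p in passage_embs]
        log_p_z = log_softmax(scores)  # log p(passage_j | query_i)
        # log p(answer_i | query_i) = logsumexp_j [log p(z_j|x) + log p(y|x, z_j)]
        joint = [lz + lg for lz, lg in zip(log_p_z, gen_log_likelihood[i])]
        total += -logsumexp(joint)
    return total / len(query_embs)


# Two queries, two passages; passage j is the gold passage for query j.
queries = [[1.0, 0.0], [0.0, 1.0]]
passages = [[1.0, 0.0], [0.0, 1.0]]
gen_ll = [[math.log(0.9), math.log(0.1)],
          [math.log(0.2), math.log(0.8)]]
print(rag_marginal_nll(queries, passages, gen_ll))
```

Because the loss flows through the softmax over retrieval scores, the retriever is pushed toward passages under which the generator assigns the answer high likelihood. A real training loop would do this with batched tensors in a deep learning framework rather than Python lists.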

  • The training folder contains two scripts: one trains RAG-end2end, and the other trains the retriever with contrastive learning.

  • All evaluations for the retriever and the generator are in the eval folder.

  • The datasets folder contains data-processing code and synthetic data generation code.

Usage

To train and evaluate both the retriever model and the new RAG-e2e model, follow the steps below.

Installation

You can install this package directly via pip:

pip install indomain

Alternatively, for development or research, you can clone and install the repo locally:

git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install --upgrade -e .

This will install the DALM repo and all necessary dependencies.

Verify the installation by running:

dalm version

Data setup

tl;dr

You can run dalm qa-gen <path-to-dataset> to preprocess your dataset for training. See dalm qa-gen --help for more options.
If you do not have a dataset, you can start with ours:

dalm qa-gen dalm/datasets/toy_data_train.csv

  • Training and evaluation require a CSV file with two or three columns: Passage and Query, plus Answer if you are running e2e training. You can use the question_answer_generation.py script to generate this CSV.
  • Retriever-only training uses only the passages and queries, whereas RAG-e2e training uses all three columns.
  • In our experiments, we use BAAI/bge-large-en as the default retriever and meta-llama/Llama-2-7b-hf as the default generator. The code is designed to work with any embedding model or autoregressive model available in the Hugging Face model repository at https://huggingface.co/models.
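To illustrate the expected input, the snippet below builds a one-row CSV with the three columns and checks which required columns are present. It is a hypothetical helper that uses only the standard library; the column names come from the description above, and the scripts may accept other names through flags (the evaluation commands, for example, use --passage_column_name and --query_column_name).

```python
import csv
import io
from typing import List

# Column sets described in the Data setup notes (assumed exact, case-sensitive names).
RETRIEVER_COLUMNS = ["Passage", "Query"]
E2E_COLUMNS = ["Passage", "Query", "Answer"]


def missing_columns(csv_text: str, required: List[str]) -> List[str]:
    """Return the required columns that are absent from the CSV header."""
    reader = csv.DictReader(io.StringIO(csv_text))
    header = reader.fieldnames or []
    return [col for col in required if col not in header]


# A tiny hypothetical dataset with all three columns.
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=E2E_COLUMNS)
writer.writeheader()
writer.writerow({
    "Passage": "A battery electrode is coated with a lithium-rich layer.",
    "Query": "What is the battery electrode coated with?",
    "Answer": "A lithium-rich layer.",
})
sample = buf.getvalue()

print(missing_columns(sample, RETRIEVER_COLUMNS))  # []
print(missing_columns(sample, E2E_COLUMNS))        # []
```

A CSV with only Passage and Query would pass the retriever-only check but report ["Answer"] as missing for e2e training.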

Training

You can run our scripts directly or use the dalm CLI. Both accept the same options; the CLI takes the models and dataset as positional arguments and uses hyphens instead of underscores in flag names.

Train Retriever Only

Train BAAI/bge-large-en retriever with contrastive learning.

python dalm/training/retriever_only/train_retriever_only.py \
--dataset_path "./dalm/datasets/toy_data_train.csv" \
--model_name_or_path "BAAI/bge-large-en" \
--output_dir "retriever_only_checkpoints" \
--use_peft \
--with_tracking \
--report_to all \
--per_device_train_batch_size 150

or

dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \
--output-dir "retriever_only_checkpoints" \
--use-peft \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150

For all available arguments and options, see dalm train-retriever-only --help.
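Contrastive retriever training with in-batch negatives is commonly implemented as an InfoNCE (cross-entropy) loss over query-passage similarities. The sketch below is a generic version of that technique, not DALM's code; cosine similarity and the default temperature of 0.05 are assumptions. Under this scheme, a batch size of 150 (as in the commands above) would give each query 149 in-batch negatives.

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def in_batch_contrastive_loss(
    query_embs: List[List[float]],
    passage_embs: List[List[float]],
    temperature: float = 0.05,  # assumed value, not taken from DALM
) -> float:
    """InfoNCE loss with in-batch negatives.

    query_embs[i] is paired with passage_embs[i]; every other passage in
    the batch serves as a negative for query i.
    """
    loss = 0.0
    for i, q in enumerate(query_embs):
        logits = [cosine(q, p) / temperature for p in passage_embs]
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        loss += log_z - logits[i]  # cross-entropy with the positive at index i
    return loss / len(query_embs)


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[1.0, 0.0], [0.0, 1.0]]   # each query paired with a matching passage
swapped = [[0.0, 1.0], [1.0, 0.0]]   # each query paired with the wrong passage
print(in_batch_contrastive_loss(queries, aligned) < in_batch_contrastive_loss(queries, swapped))  # True
```

Reusing the other queries' positives as negatives avoids mining separate negatives, which is why larger batches generally give a harder and more informative contrastive signal.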

Train Retriever and Generator Jointly (RAG-e2e)

Train the Llama-2-7b generator jointly with the BAAI/bge-large-en retriever.

python dalm/training/rag_e2e/train_rage2e.py \
  --dataset_path "./dalm/datasets/toy_data_train.csv" \
  --retriever_name_or_path "BAAI/bge-large-en" \
  --generator_name_or_path "meta-llama/Llama-2-7b-hf" \
  --output_dir "rag_e2e_checkpoints" \
  --with_tracking \
  --report_to all \
  --per_device_train_batch_size 150

or

dalm train-rag-e2e \
"./dalm/datasets/toy_data_train.csv" \
"BAAI/bge-large-en" \
"meta-llama/Llama-2-7b-hf" \
--output-dir "rag_e2e_checkpoints" \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150

For all available arguments and options, see dalm train-rag-e2e --help.

Evaluation

Here is a summary of evaluation results on a 200K-line test CSV of patent abstracts:

Type of Retriever                      Recall    Hit rate
Plain Retriever                        0.45984   0.45984
Retriever with contrastive learning    0.46037   0.46038
Retriever End2End                      0.73634   0.73634
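The metrics and the cutoff k are not defined here. One common definition is sketched below with hypothetical helpers: hit rate@k is the fraction of queries whose gold passage appears in the top k results, and recall@k is the fraction of each query's relevant passages found in the top k, averaged over queries. When each query has exactly one relevant passage the two coincide, which may explain why the columns are nearly identical; the small gap in the contrastive-learning row suggests the eval scripts compute them somewhat differently.

```python
from typing import Sequence, Set


def hit_rate_at_k(ranked: Sequence[Sequence[str]], gold: Sequence[str], k: int) -> float:
    """Fraction of queries whose single gold passage is in their top-k results."""
    hits = sum(1 for results, g in zip(ranked, gold) if g in results[:k])
    return hits / len(gold)


def recall_at_k(ranked: Sequence[Sequence[str]], relevant: Sequence[Set[str]], k: int) -> float:
    """Average over queries of (relevant passages in top-k) / (all relevant passages)."""
    per_query = [
        len(set(results[:k]) & rel) / len(rel)
        for results, rel in zip(ranked, relevant)
    ]
    return sum(per_query) / len(per_query)


# Hypothetical rankings for three queries (passage IDs, best first).
ranked = [["p1", "p7", "p3"], ["p4", "p2", "p9"], ["p5", "p6", "p8"]]
gold = ["p3", "p2", "p0"]

print(hit_rate_at_k(ranked, gold, k=2))               # 0.3333333333333333
print(recall_at_k(ranked, [{g} for g in gold], k=2))  # 0.3333333333333333
print(hit_rate_at_k(ranked, gold, k=3))               # 0.6666666666666666
```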

To run the retriever-only evaluation (make sure the checkpoints are in the project root):

python dalm/eval/eval_retriever_only.py \
--dataset_path qa_pairs_test.csv \
--retriever_name_or_path "BAAI/bge-large-en" \
--passage_column_name Abstract \
--query_column_name Question \
--retriever_peft_model_path retriever_only_checkpoints

For the end-to-end evaluation:

python dalm/eval/eval_rag.py \
--dataset_path qa_pairs_test_2.csv \
--retriever_name_or_path "BAAI/bge-large-en" \
--generator_model_name_or_path "meta-llama/Llama-2-7b-hf" \
--passage_column_name Abstract \
--query_column_name Question \
--answer_column_name Answer \
--evaluate_generator \
--query_batch_size 5 \
--retriever_peft_model_path rag_e2e_checkpoints/retriever \
--generator_peft_model_path rag_e2e_checkpoints/generator

Contributing

See CONTRIBUTING



Download files


Source Distribution

indomain-0.0.4.tar.gz (3.0 MB)

Built Distribution

indomain-0.0.4-py3-none-any.whl (46.0 kB)

File details

Details for the file indomain-0.0.4.tar.gz.

File metadata

  • Download URL: indomain-0.0.4.tar.gz
  • Size: 3.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for indomain-0.0.4.tar.gz
Algorithm Hash digest
SHA256 2bdfe425773ae4df8dc3c611517e0b4e0d667dcfdcf5d9610b0bc48c01275be5
MD5 94734f2c6118440654cbfbed64b84e9f
BLAKE2b-256 9e41a36f529f9f671f95e9c1224b91c90e4f8c314fb3b037fd777e9abe14397e


File details

Details for the file indomain-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: indomain-0.0.4-py3-none-any.whl
  • Size: 46.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for indomain-0.0.4-py3-none-any.whl
Algorithm Hash digest
SHA256 8e963dc439bec8e377e164533c055643dbdaa1d76d1be6ef36b5880ea241ef9b
MD5 d1f995977064a0c433822c44f340e758
BLAKE2b-256 1bcac20a9fb7f2b2016bdeb12a99234db762eb3b682da9169257ea50559427c0

