

Project description

Domain Adapted Language Modeling Toolkit

Manifesto

A great rift has emerged between general-purpose LLMs and the vector stores that supply them with contextual information. Unifying these systems is an important step toward grounding AI in efficient, factual domains, where models are valued not only for their generality but also for their specificity and uniqueness. To this end, we are excited to open-source the Arcee Domain Adapted Language Model (DALM) toolkit, which developers can build on top of our open-source Arcee Domain Pretrained (DPT) LLMs. We believe these efforts will help as we begin the next phase of language modeling, in which organizations deeply tailor AI to operate according to their unique intellectual property and worldview.

Demo DALMs

Query example DALMs created by the Arcee Team.

DALM-Patent DALM-PubMed DALM-SEC DALM-Yours

Research Contents

This repository primarily contains code for fine-tuning a fully differentiable Retrieval Augmented Generation (RAG-end2end) architecture.

E2E

For the first time in the literature, we modified the original RAG-end2end model (TACL paper, Hugging Face implementation) to work with decoder-only language models such as Llama, Falcon, or GPT. We also combined in-batch negatives with RAG's marginalization to make the entire process efficient.
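To make the marginalization concrete, here is a minimal pure-Python sketch of the loss under stated assumptions. It is not DALM's training code, and the function names and toy numbers are hypothetical. For each query, the candidate passages are assumed to be the gold passages of the whole batch (the in-batch negatives); the retriever's scores become a distribution p(z | x) via a softmax, and the answer likelihood is summed over passages: p(y | x) = sum_j p(z_j | x) * p(y | x, z_j). In real training these would be tensors, so the gradient of this loss reaches both the generator (through the log-likelihoods) and the retriever (through the scores), which is what makes the pipeline end-to-end differentiable.

```python
import math
from typing import List, Sequence


def logsumexp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def rag_marginal_nll(
    retrieval_scores: List[List[float]],
    gen_logliks: List[List[float]],
) -> float:
    """Mean negative log-likelihood under RAG marginalization with in-batch negatives.

    retrieval_scores[i][j]: retriever similarity between query i and passage j,
        where the candidates are the gold passages of the whole batch, so every
        j != i acts as an in-batch negative for query i.
    gen_logliks[i][j]: generator log-likelihood log p(answer_i | query_i, passage_j).
    """
    total = 0.0
    for scores, logliks in zip(retrieval_scores, gen_logliks):
        # log p(z_j | x_i): softmax of the retrieval scores over the batch's passages.
        log_norm = logsumexp(scores)
        log_prior = [s - log_norm for s in scores]
        # log p(y_i | x_i) = log sum_j p(z_j | x_i) * p(y_i | x_i, z_j)
        total -= logsumexp([lp + ll for lp, ll in zip(log_prior, logliks)])
    return total / len(retrieval_scores)


if __name__ == "__main__":
    # Hypothetical batch of two examples; each answer is most likely under its own passage.
    scores = [[2.0, 0.5], [0.3, 1.8]]
    logliks = [[-1.2, -4.0], [-3.5, -0.9]]
    print(f"{rag_marginal_nll(scores, logliks):.4f}")
```

Reusing the batch's own passages as negatives means no extra passages have to be encoded for the softmax, which is one plausible reading of why the combination is described as efficient.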

  • The training folder contains two scripts: one trains RAG-end2end and the other trains the retriever with contrastive learning.

  • All evaluations related to the Retriever and the Generator are located in the eval folder.

  • The datasets folder also contains data processing scripts and a synthetic data generation script.

Usage

To train and evaluate both the retriever model and the new RAG-e2e model, follow the steps below.

Installation

You can install this repo directly via pip install indomain.

Alternatively, for development or research, you can clone and install the repo locally:

git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install --upgrade -e .

This will install the DALM repo and all necessary dependencies.

Verify the installation by running dalm version.
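If you prefer to check from Python (for example inside a notebook), the sketch below uses only the standard library's importlib.metadata and no DALM API. It assumes the installed distribution is named indomain, as in the pip command above; the helper name installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # "indomain" is the distribution name used by `pip install indomain`.
    print(installed_version("indomain") or "indomain is not installed")
```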

Data setup

tl;dr

You can run dalm qa-gen <path-to-dataset> to preprocess your dataset for training. See dalm qa-gen --help for more options.
If you do not have a dataset, you can start with ours:

dalm qa-gen dalm/datasets/toy_data_train.csv
  • Setting up training and evaluation is straightforward provided you have a CSV file with two or three columns: Passage, Query (and Answer if running e2e). You can use the script question_answer_generation.py to generate this CSV.
  • Note that retriever-only training uses only the passages and queries, whereas the RAG-e2e training code uses all three columns.
  • In our experiments, we use BAAI/bge-large-en as the default retriever and meta-llama/Llama-2-7b-hf as the default generator. The code is designed to work with any embedding model or autoregressive model available on the Hugging Face Hub at https://huggingface.co/models.
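Before launching a long training run, it can help to confirm that your CSV actually has the columns each mode needs. The sketch below is a hypothetical helper, not part of DALM. It defaults to the Passage/Query/Answer names described above, but the evaluation commands later in this README pass Abstract/Question/Answer explicitly, so treat the names as configurable rather than fixed.

```python
import csv
from typing import List, Sequence

# Column sets described above; substitute your own header names if they differ.
RETRIEVER_ONLY = ("Passage", "Query")
RAG_E2E = ("Passage", "Query", "Answer")


def missing_columns(header: Sequence[str], required: Sequence[str]) -> List[str]:
    """Return the required column names that are absent from a CSV header."""
    return [col for col in required if col not in header]


def read_header(csv_path: str) -> List[str]:
    """Read only the first row (the header) of a CSV file."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        return next(csv.reader(f), [])
```

For example, missing_columns(read_header("qa_pairs_test.csv"), ("Abstract", "Question", "Answer")) returns an empty list when all three columns are present.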

Training

You can run our scripts directly or use the dalm CLI; both accept the same arguments.
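Comparing the script and CLI commands below, the CLI appears to take the required arguments positionally and to spell options with hyphens where the scripts use underscores. The sketch below is a hypothetical helper (not part of DALM) that builds a dalm argument list from script-style option names under that assumption; it only uses the standard library.

```python
import shlex
from typing import Dict, List, Sequence, Union

OptionValue = Union[str, int, float, bool]


def to_dalm_argv(
    subcommand: str,
    positional: Sequence[str],
    options: Dict[str, OptionValue],
) -> List[str]:
    """Build a `dalm` command line from script-style option names (hypothetical).

    Assumes, from the README examples, that the CLI spells options with hyphens
    instead of underscores and treats boolean options as bare switches.
    """
    argv = ["dalm", subcommand, *positional]
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if isinstance(value, bool):
            if value:
                argv.append(flag)  # switch present, e.g. --use-peft
            continue  # switch omitted when False
        argv.extend([flag, str(value)])
    return argv


argv = to_dalm_argv(
    "train-retriever-only",
    ["BAAI/bge-large-en", "./dalm/datasets/toy_data_train.csv"],
    {
        "output_dir": "./dalm/training/rag_e2e/retriever_only_checkpoints",
        "use_peft": True,
        "with_tracking": True,
        "report_to": "all",
        "per_device_train_batch_size": 150,
    },
)
# Print for inspection; pass argv to subprocess.run(argv, check=True) to launch.
print(shlex.join(argv))
```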

Train Retriever Only

Train BAAI/bge-large-en retriever with contrastive learning.

python dalm/training/retriever_only/train_retriever_only.py \
--train_dataset_csv_path "./dalm/datasets/toy_data_train.csv" \
--model_name_or_path "BAAI/bge-large-en" \
--output_dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \
--use_peft \
--with_tracking \
--report_to all \
--per_device_train_batch_size 150

or

dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \
--output-dir "./dalm/training/rag_e2e/retriever_only_checkpoints" \
--use-peft \
--with-tracking \
--report-to all \
--per-device-train-batch-size 150

For all available arguments and options, see dalm train-retriever-only --help.
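As an illustration of what "contrastive learning" with in-batch negatives usually means for a retriever, here is a minimal pure-Python InfoNCE sketch. It is an assumption about the general technique, not the loss implemented in train_retriever_only.py; the function names and the temperature of 0.05 are hypothetical choices for illustration. Passage i is the positive for query i, and every other passage in the batch acts as a negative.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def in_batch_contrastive_loss(
    query_embs: List[Sequence[float]],
    passage_embs: List[Sequence[float]],
    temperature: float = 0.05,  # assumed value, for illustration only
) -> float:
    """InfoNCE loss: passage i is the positive for query i; all other passages
    in the batch serve as negatives."""
    total = 0.0
    for i, q in enumerate(query_embs):
        logits = [cosine(q, p) / temperature for p in passage_embs]
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_norm - logits[i]  # equals -log softmax(logits)[i]
    return total / len(query_embs)
```

Under this formulation, a per-device batch size of 150 means each query is contrasted against 149 in-batch negatives, which is one reason contrastive retriever training tends to favor large batches.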

Train Retriever and Generator Jointly (RAG-e2e)

Train the Llama-2-7b generator jointly with the BAAI/bge-large-en retriever.

python dalm/training/rag_e2e/train_rage2e.py \
  --dataset_path "./dalm/datasets/toy_data_train.csv" \
  --retriever_name_or_path "BAAI/bge-large-en" \
  --generator_name_or_path "meta-llama/Llama-2-7b-hf" \
  --output_dir "./dalm/training/rag_e2e/rag_e2e_checkpoints" \
  --with_tracking \
  --report_to all \
  --per_device_train_batch_size 24

or

dalm train-rag-e2e \
"./dalm/datasets/toy_data_train.csv" \
"BAAI/bge-large-en" \
"meta-llama/Llama-2-7b-hf" \
--output-dir "./dalm/training/rag_e2e/rag_e2e_checkpoints" \
--with-tracking \
--report-to all \
--per-device-train-batch-size 24

For all available arguments and options, see dalm train-rag-e2e --help.

Evaluation

Here is a summary of evaluation results on a 200K-line test CSV of patent abstracts:

Type of Retriever                      Recall    Hit rate
Plain Retriever                        0.45984   0.45984
Retriever with contrastive learning    0.46037   0.46038
Retriever End2End                      0.73634   0.73634
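The README does not spell out how Recall and Hit rate are computed. The sketch below uses common definitions, which is an assumption and not necessarily what dalm/eval/eval_retriever_only.py does: hit rate counts a query as a success if any relevant passage appears among its retrieved results, while recall averages the fraction of each query's relevant passages that were retrieved. When every query has exactly one relevant passage, as is natural for a CSV that pairs each query with a single passage, the two definitions give the same number, which may explain why the two columns are nearly identical. The passage IDs are hypothetical.

```python
from typing import List, Sequence, Set


def hit_rate(retrieved: Sequence[Sequence[str]], relevant: Sequence[Set[str]]) -> float:
    """Fraction of queries with at least one relevant passage among those retrieved."""
    hits = sum(1 for r, rel in zip(retrieved, relevant) if rel & set(r))
    return hits / len(retrieved)


def recall(retrieved: Sequence[Sequence[str]], relevant: Sequence[Set[str]]) -> float:
    """Mean, over queries, of the fraction of relevant passages that were retrieved."""
    per_query: List[float] = [
        len(rel & set(r)) / len(rel) for r, rel in zip(retrieved, relevant)
    ]
    return sum(per_query) / len(per_query)


if __name__ == "__main__":
    # Hypothetical top-2 results for three queries, each with one gold passage.
    retrieved = [["p1", "p7"], ["p3", "p4"], ["p9", "p2"]]
    relevant = [{"p1"}, {"p5"}, {"p2"}]
    print(hit_rate(retrieved, relevant), recall(retrieved, relevant))
```

Here both values are 2/3. If a query instead had two relevant passages and only one was retrieved, it would count fully toward hit rate but only half toward recall.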

To run the retriever-only evaluation (make sure the checkpoints are in the project root):

python dalm/eval/eval_retriever_only.py \
  --dataset_path qa_pairs_test.csv \
  --retriever_model_name_or_path "BAAI/bge-large-en" \
  --passage_column_name Abstract \
  --query_column_name Question \
  --retriever_peft_model_path retriever_only_checkpoints

For the end-to-end evaluation:

python dalm/eval/eval_rag.py \
  --dataset_path qa_pairs_test_2.csv \
  --retriever_model_name_or_path "BAAI/bge-large-en" \
  --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" \
  --passage_column_name Abstract \
  --query_column_name Question \
  --answer_column_name Answer \
  --evaluate_generator \
  --query_batch_size 5 \
  --retriever_peft_model_path retriever_only_checkpoints \
  --generator_peft_model_path generator_only_checkpoints

Contributing

See CONTRIBUTING



Download files


Source Distribution

indomain-0.0.1.tar.gz (40.8 kB)

Uploaded Source

Built Distribution

indomain-0.0.1-py3-none-any.whl (45.4 kB)

Uploaded Python 3

File details

Details for the file indomain-0.0.1.tar.gz.

File metadata

  • Download URL: indomain-0.0.1.tar.gz
  • Size: 40.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for indomain-0.0.1.tar.gz
Algorithm Hash digest
SHA256 bcfef81953762e9a47fdcc5e723769dff2278f679b79f414ed61d91c69b7278f
MD5 8db0308d278ee5a49f4cabaea1f2c5a8
BLAKE2b-256 1a4035f0c9411ad2c4ab35460d84427142b760995b881542858d6b42eaa09df3


File details

Details for the file indomain-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: indomain-0.0.1-py3-none-any.whl
  • Size: 45.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for indomain-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 6213ef412c4d7cb78f29283bc25aa1f880d716ac722f3b290bdb473bc971ca67
MD5 eef46649fc3d0b43d8720f46c6f0dc79
BLAKE2b-256 f2410c1657520e3d63bd296dcbdf5f73b366cad3b3297c78fd23da0c4717a5c7

