Domain Adapted Language Modeling Toolkit
Manifesto
A great rift has emerged between general LLMs and the vector stores that are providing them with contextual information. The unification of these systems is an important step in grounding AI systems in efficient, factual domains, where they are utilized not only for their generality, but for their specificity and uniqueness. To this end, we are excited to open source the Arcee Domain Adapted Language Model (DALM) toolkit for developers to build on top of our Arcee open source Domain Pretrained (DPT) LLMs. We believe that our efforts will help as we begin the next phase of language modeling, where organizations deeply tailor AI to operate according to their unique intellectual property and worldview.
Demo DALMs
Query example DALMs created by the Arcee Team.
DALM-Patent | DALM-PubMed | DALM-SEC | DALM-Yours |
---|---|---|---|
Research Contents
This repository primarily contains code for fine-tuning a fully differentiable Retrieval Augmented Generation (RAG-end2end) architecture.
For the first time in the literature, we modified the initial RAG-end2end model (TACL paper, HuggingFace implementation) to work with decoder-only language models like Llama, Falcon, or GPT. We also incorporated in-batch negatives alongside RAG's marginalization to make the entire process efficient.
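To make the combination of in-batch negatives and RAG marginalization concrete, here is a minimal pure-Python sketch of one way the two can fit together; it is not taken from the DALM code. It assumes that, for a batch of B queries, every passage in the batch is a retrieval candidate for every query, that the retriever yields a B×B score matrix, and that the generator yields a B×B matrix of answer log-likelihoods log p(y_i | x_i, z_j). The function names are hypothetical. The loss is the negative marginal log-likelihood, -log Σ_j p(z_j | x_i) p(y_i | x_i, z_j), computed in log space for numerical stability.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of retriever scores."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(s - m) for s in row))
    return [s - log_z for s in row]


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def rag_in_batch_marginal_nll(
    retriever_scores: List[List[float]],
    generator_logliks: List[List[float]],
) -> float:
    """Average negative marginal log-likelihood over a batch (hypothetical helper).

    retriever_scores[i][j]: similarity of query i and passage j; every passage
        in the batch is a candidate, so passages j != i act as in-batch negatives.
    generator_logliks[i][j]: log p(answer_i | query_i, passage_j) from the
        decoder-only generator.
    """
    losses = []
    for scores, logliks in zip(retriever_scores, generator_logliks):
        log_p_z = log_softmax(scores)  # log p(z_j | x_i)
        # log p(y_i | x_i) = log sum_j p(z_j | x_i) * p(y_i | x_i, z_j)
        log_marginal = logsumexp([lz + ly for lz, ly in zip(log_p_z, logliks)])
        losses.append(-log_marginal)
    return sum(losses) / len(losses)
```

In a real training loop these would be tensors rather than Python lists; because the loss depends on both the retriever scores and the generator log-likelihoods, minimizing it updates both models jointly, which is the point of the end-to-end setup.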
- Inside the `training` folder, you'll find two scripts: one to train RAG-end2end and one to train the retriever alone with contrastive learning.
- All evaluations related to the retriever and the generator are located in the `eval` folder.
- Additionally, the `datasets` folder contains data processing code and synthetic data generation code.
Usage
To train and evaluate both the retriever model and the new RAG-e2e model, follow the steps below.
Installation
You can install this repo directly via `pip install indomain`.

Alternatively, for development or research, you can clone and install the repo locally:

```bash
git clone https://github.com/arcee-ai/DALM.git && cd DALM
pip install --upgrade -e .
```

This will install the DALM repo and all necessary dependencies.

Make sure things are installed correctly by running `dalm version`.
Data setup
tl;dr
You can run `dalm qa-gen <path-to-dataset>` to preprocess your dataset for training. See `dalm qa-gen --help` for more options.

If you do not have a dataset, you can start with ours:

```bash
dalm qa-gen dalm/datasets/toy_data_train.csv
```
- Training and evaluation only require a CSV file with two or three columns: `Passage`, `Query` (and `Answer` if running e2e). You can use the script `question_answer_generation.py` to generate this CSV.
- The retriever-only training method uses only the passages and queries, whereas the RAG-e2e training code uses all three columns.
- In our experiments, we use `BAAI/bge-large-en` as the default retriever and `meta-llama/Llama-2-7b-hf` as the default generator. The code is designed to be compatible with any embedding model or autoregressive model available in the Hugging Face model repository at https://huggingface.co/models.
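Before training, it can help to check that a CSV has the columns the chosen training mode needs. The sketch below uses only Python's standard `csv` module. `missing_columns` is a hypothetical helper, not part of DALM, and it assumes the column names `Passage`, `Query` and `Answer` described above (the evaluation commands later in this README pass other names such as `Abstract` and `Question` via flags, so adjust the lists if your data differs).

```python
import csv
from typing import Iterable, List

# Assumed column names, taken from the description above; adjust if your CSV differs.
RETRIEVER_COLUMNS = ["Passage", "Query"]
E2E_COLUMNS = ["Passage", "Query", "Answer"]


def missing_columns(lines: Iterable[str], e2e: bool = False) -> List[str]:
    """Return the required columns absent from the CSV header row.

    `lines` can be an open file handle or any iterable of CSV lines.
    Retriever-only training needs Passage and Query; RAG-e2e also needs Answer.
    """
    required = E2E_COLUMNS if e2e else RETRIEVER_COLUMNS
    # Read only the header row; strip stray spaces such as "Passage, Query".
    header = [name.strip() for name in next(csv.reader(lines), [])]
    return [col for col in required if col not in header]


# Example usage (hypothetical path):
# with open("my_data.csv", newline="", encoding="utf-8") as f:
#     print(missing_columns(f, e2e=True))
```

An empty list means the file has every column the selected mode needs.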
Training
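You can use our scripts directly if you'd like, or you can use the `dalm` CLI. Both accept the same arguments, with two differences: the CLI takes the model and dataset as positional arguments, and its option names use hyphens instead of underscores.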
Train Retriever Only
Train the `BAAI/bge-large-en` retriever with contrastive learning.
```bash
python dalm/training/retriever_only/train_retriever_only.py \
  --dataset_path "./dalm/datasets/toy_data_train.csv" \
  --model_name_or_path "BAAI/bge-large-en" \
  --output_dir "retriever_only_checkpoints" \
  --use_peft \
  --with_tracking \
  --report_to all \
  --per_device_train_batch_size 150
```

or

```bash
dalm train-retriever-only "BAAI/bge-large-en" "./dalm/datasets/toy_data_train.csv" \
  --output-dir "retriever_only_checkpoints" \
  --use-peft \
  --with-tracking \
  --report-to all \
  --per-device-train-batch-size 150
```
For all available arguments and options, see `dalm train-retriever-only --help`.
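Contrastive training with in-batch negatives can be sketched as an InfoNCE loss: for query i, passage i in the same batch is the positive and every other passage in the batch is a negative, so with `--per_device_train_batch_size 150` each query would be contrasted against 149 negatives per device. The pure-Python sketch below illustrates that idea and is not DALM's implementation; the cosine similarity, the `temperature` default of 0.05, and the function names are assumptions for illustration.

```python
import math
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v: List[float]) -> List[float]:
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def in_batch_contrastive_loss(
    query_embs: List[List[float]],
    passage_embs: List[List[float]],
    temperature: float = 0.05,  # assumed value, for illustration only
) -> float:
    """InfoNCE loss: passage i is the positive for query i, and the other
    passages in the batch serve as negatives (hypothetical helper)."""
    q = [l2_normalize(v) for v in query_embs]
    p = [l2_normalize(v) for v in passage_embs]
    total = 0.0
    for i, qi in enumerate(q):
        # Cosine similarity of query i with every passage, scaled by temperature.
        logits = [dot(qi, pj) / temperature for pj in p]
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]  # equals -log softmax(logits)[i]
    return total / len(q)
```

Larger batches supply more negatives per query at no extra encoding cost, which is one reason contrastive retrievers are often trained with large batch sizes.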
Train Retriever and Generator Jointly (RAG-e2e)
Train the `Llama-2-7b` generator jointly with the retriever model `BAAI/bge-large-en`.
```bash
python dalm/training/rag_e2e/train_rage2e.py \
  --dataset_path "./dalm/datasets/toy_data_train.csv" \
  --retriever_name_or_path "BAAI/bge-large-en" \
  --generator_name_or_path "meta-llama/Llama-2-7b-hf" \
  --output_dir "rag_e2e_checkpoints" \
  --with_tracking \
  --report_to all \
  --per_device_train_batch_size 150
```

or

```bash
dalm train-rag-e2e \
  "./dalm/datasets/toy_data_train.csv" \
  "BAAI/bge-large-en" \
  "meta-llama/Llama-2-7b-hf" \
  --output-dir "rag_e2e_checkpoints" \
  --with-tracking \
  --report-to all \
  --per-device-train-batch-size 150
```
For all available arguments and options, see `dalm train-rag-e2e --help`.
Evaluation
Here is a summary of retriever evaluation results on a 200K-line test CSV of patent abstracts:
Type of Retriever | Recall | Hit rate |
---|---|---|
Plain Retriever | 0.45984 | 0.45984 |
Retriever with contrastive learning | 0.46037 | 0.46038 |
Retriever End2End | 0.73634 | 0.73634 |
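This README does not spell out how recall and hit rate are computed, so the following is a hedged sketch of the usual definitions rather than the code in `dalm/eval`. For each query, hit rate counts whether at least one relevant passage appears among the top-k retrieved results, and recall is the fraction of that query's relevant passages found in the top k; both are averaged over queries. The value of k behind the table above is not stated, and `hit_rate_and_recall` is a hypothetical helper. When every query has exactly one relevant passage, the two metrics coincide, which is consistent with the nearly identical columns above.

```python
from typing import Dict, List, Sequence, Set


def hit_rate_and_recall(
    retrieved: Sequence[List[str]],
    relevant: Sequence[Set[str]],
    k: int,
) -> Dict[str, float]:
    """Average hit rate@k and recall@k over queries (hypothetical helper).

    retrieved[i]: passage ids ranked by the retriever for query i, best first.
    relevant[i]:  ids of the passages relevant to query i (must be non-empty).
    """
    hits = 0
    recall_sum = 0.0
    for ranked, gold in zip(retrieved, relevant):
        found = len(set(ranked[:k]) & gold)  # relevant passages in the top k
        hits += 1 if found > 0 else 0
        recall_sum += found / len(gold)
    n = len(retrieved)
    return {"hit_rate": hits / n, "recall": recall_sum / n}
```

For example, with two queries whose relevant sets are `{"a", "z"}` and `{"d"}` and rankings `["a", "b"]` and `["c", "d"]`, k=2 gives a hit rate of 1.0 but a recall of 0.75, because the first query only recovers one of its two relevant passages.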
To run the retriever-only eval (make sure the checkpoints are in the project root):

```bash
python dalm/eval/eval_retriever_only.py \
  --dataset_path qa_pairs_test.csv \
  --retriever_model_name_or_path "BAAI/bge-large-en" \
  --passage_column_name Abstract \
  --query_column_name Question \
  --retriever_peft_model_path retriever_only_checkpoints
```

For the e2e eval:

```bash
python dalm/eval/eval_rag.py \
  --dataset_path qa_pairs_test_2.csv \
  --retriever_model_name_or_path "BAAI/bge-large-en" \
  --generator_model_name_or_path "meta-llama/Llama-2-7b-hf" \
  --passage_column_name Abstract \
  --query_column_name Question \
  --answer_column_name Answer \
  --evaluate_generator \
  --query_batch_size 5 \
  --retriever_peft_model_path rag_e2e_checkpoints/retriever \
  --generator_peft_model_path rag_e2e_checkpoints/generator
```
Contributing
See CONTRIBUTING