Code used to train ColPali

Reason this release was yanked:

The `models` module is missing.

ColPali: Efficient Document Retrieval with Vision Language Models 👀

[Model card] [ViDoRe Benchmark] [ViDoRe Leaderboard] [Demo] [Blog Post]

[!TIP] If you want to try the pre-trained ColPali on your own documents, you should use the vidore-benchmark repository. It comes with a Python package and a CLI tool for convenient evaluation.

Associated Paper

ColPali: Efficient Document Retrieval with Vision Language Models

Manuel Faysse*, Hugues Sibille*, Tony Wu*, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (*Equal Contribution)

This repository contains the code used for training the vision retrievers in the paper. In particular, it contains the code for training the ColPali model, which is a vision retriever based on the ColBERT architecture.
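
ColBERT-style retrievers score a query against a document by late interaction: each query token embedding is compared with every document embedding (for ColPali, image patch embeddings), the best match per query token is kept, and those maxima are summed. The following is a minimal sketch in plain Python, assuming toy 2-dimensional embeddings. The names `dot` and `maxsim_score` are illustrative and are not part of `colpali_engine`.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query: List[Vector], doc: List[Vector]) -> float:
    """Late-interaction score: for each query vector, take its best-matching
    document vector, then sum those maxima over the query."""
    return sum(max(dot(q, d) for d in doc) for q in query)


# Toy data (assumed for illustration): 2 query token embeddings,
# 3 document patch embeddings.
query = [[1.0, 0.0], [0.0, 1.0]]
doc = [[0.5, 0.5], [1.0, 0.0], [0.0, 0.25]]

score = maxsim_score(query, doc)
print(score)  # 1.5  (1.0 for the first query vector + 0.5 for the second)
```

Because the maximum is taken per query token, a document only needs one well-matching patch per token to score highly, which is what makes this kind of scoring suited to multi-vector page embeddings.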

Setup

We used Python 3.11.6 and PyTorch 2.2.2 to train and test our models, but the codebase is expected to be compatible with Python >=3.9 and recent PyTorch versions.

The eval codebase depends on a few Python packages, which can be installed with the following command:

pip install colpali-engine

You can also install the package from source:

pip install git+https://github.com/illuin-tech/colpali

Usage

Example usage of the model is shown in the scripts directory.

python scripts/infer/run_inference_with_python.py

The script is reproduced below:

import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset


def main() -> None:
    """Example script to run inference with ColPali"""
    # Load model
    model_name = "vidore/colpali"
    model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # select images -> load_from_pdf(<pdf_path>),  load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
    images = load_from_dataset("vidore/docvqa_test_subsampled")
    queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # run inference - queries
    dataloader = DataLoader(
        queries,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
    )

    qs = []
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    print(scores.argmax(axis=1))


if __name__ == "__main__":
    typer.run(main)

More details are given in the ColPali Model card.
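
The last step of the script, `scores.argmax(axis=1)`, keeps only the single best page per query. The sketch below shows how the same score matrix could be turned into a top-k ranking. It assumes rows correspond to queries and columns to pages (which the argmax over axis 1 suggests), uses plain Python lists instead of tensors, and the helper `top_k_pages` and the score values are hypothetical.

```python
from typing import List


def top_k_pages(scores: List[List[float]], k: int = 1) -> List[List[int]]:
    """For each query (row), return the indices of the k highest-scoring pages."""
    ranked = []
    for row in scores:
        # Sort page indices by descending score for this query.
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranked.append(order[:k])
    return ranked


# Hypothetical scores: 2 queries (rows) x 3 pages (columns).
scores = [
    [12.1, 15.4, 9.8],
    [7.3, 6.9, 11.0],
]

best = top_k_pages(scores, k=2)
print(best)  # [[1, 0], [2, 0]]
```

With `k=1` this reduces to the argmax printed by the script.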

Training

USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml

or

accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml

Configurations

All training arguments can be set through a YAML configuration file.

The configuration object has the following structure:

@dataclass
class ColModelTrainingConfig:
    model: PreTrainedModel
    tr_args: TrainingArguments = None
    output_dir: str = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Idefics2Processor = None
    tokenizer: PreTrainedTokenizer = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None
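
The `loss_func` field holds the training objective. As an illustration of the kind of objective involved, the sketch below follows the in-batch pairwise loss described in the associated paper: for each query, the softplus of the hardest in-batch negative score minus the positive score, averaged over the batch. It works on a precomputed score matrix in plain Python. It is not the code of `ColbertLoss` or `ColbertPairwiseCELoss`, and the function names and score values are assumptions made for this example.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_ce_loss(scores: List[List[float]]) -> float:
    """scores[i][j] is the score of query i against document j.
    Document i is assumed to be the positive for query i; every other
    document in the batch is treated as a negative (batch size >= 2)."""
    losses = []
    for i, row in enumerate(scores):
        positive = row[i]
        hardest_negative = max(s for j, s in enumerate(row) if j != i)
        losses.append(softplus(hardest_negative - positive))
    return sum(losses) / len(losses)


# Hypothetical 2x2 score matrix where positives (diagonal) score highest.
good_scores = [[10.0, 2.0], [3.0, 8.0]]
# Same values with positives and negatives swapped.
bad_scores = [[2.0, 10.0], [8.0, 3.0]]

good_loss = pairwise_ce_loss(good_scores)
bad_loss = pairwise_ce_loss(bad_scores)
print(good_loss < bad_loss)  # True
```

The loss approaches zero as each positive outscores its hardest negative by a growing margin, and grows roughly linearly when the negative wins.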

Example

Example of a YAML configuration file:

config:
  (): colpali_engine.utils.train_colpali_engine_models.ColModelTrainingConfig
  output_dir: !path ../../../models/without_tabfquad/train_colpali-3b-mix-448
  processor:
    (): colpali_engine.utils.wrapper.AutoProcessorWrapper
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
    max_length: 50
  model:
    (): colpali_engine.utils.wrapper.AutoColModelWrapper
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
    training_objective: "colbertv1"
    # attn_implementation: "eager"
    torch_dtype:  !ext torch.bfloat16
#    device_map: "auto"
#    quantization_config:
#      (): transformers.BitsAndBytesConfig
#      load_in_4bit: true
#      bnb_4bit_quant_type: "nf4"
#      bnb_4bit_compute_dtype:  "bfloat16"
#      bnb_4bit_use_double_quant: true

  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset_loader: !import ../data/test_data.yaml

  max_length: 50
  run_eval: true
  add_suffix: true
  loss_func:
    (): colpali_engine.loss.colbert_loss.ColbertPairwiseCELoss
  tr_args: !import ../tr_args/default_tr_args.yaml
  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
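
In the YAML above, a mapping that contains a `()` key appears to name a class or callable by its dotted import path, with the sibling keys passed to it as keyword arguments. The snippet below sketches that idea on an already-parsed dictionary. `instantiate` is a hypothetical helper, not the loader the repository uses, and it does not handle the `!path`, `!ext`, or `!import` tags.

```python
import importlib
from typing import Any


def instantiate(node: Any) -> Any:
    """Recursively build objects from dicts that carry a "()" dotted-path key."""
    if isinstance(node, list):
        return [instantiate(item) for item in node]
    if not isinstance(node, dict):
        return node  # scalars pass through unchanged
    # Resolve nested entries first so inner objects exist before the outer call.
    kwargs = {key: instantiate(value) for key, value in node.items() if key != "()"}
    if "()" not in node:
        return kwargs  # ordinary mapping, nothing to construct
    module_path, _, attr_name = node["()"].rpartition(".")
    factory = getattr(importlib.import_module(module_path), attr_name)
    return factory(**kwargs)


# Standard-library stand-in for an entry such as `peft_config` in the example.
config = {
    "step": {"()": "datetime.timedelta", "hours": 1, "minutes": 30},
    "max_length": 50,
}

built = instantiate(config)
print(built["step"].total_seconds())  # 5400.0
print(built["max_length"])            # 50
```

Keeping class paths in the configuration file means a different model wrapper, loss, or PEFT setup can be swapped in without editing the training script.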

Local training

USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml

SLURM

sbatch --nodes=1 --cpus-per-task=16 --mem-per-cpu=32GB --time=20:00:00 --gres=gpu:1  -p gpua100 --job-name=colidefics --output=colidefics.out --error=colidefics.err --wrap="accelerate launch scripts/train/train_colbert.py  scripts/configs/train_colidefics_model.yaml"

sbatch --nodes=1  --time=5:00:00 -A cad15443 --gres=gpu:8  --constraint=MI250 --job-name=colpali --wrap="python scripts/train/train_colbert.py scripts/configs/train_colpali_model.yaml"

Citation

ColPali: Efficient Document Retrieval with Vision Language Models

  • First authors: Manuel Faysse*, Hugues Sibille*, Tony Wu* (*Equal Contribution)
  • Contributors: Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo
@misc{faysse2024colpaliefficientdocumentretrieval,
      title={ColPali: Efficient Document Retrieval with Vision Language Models}, 
      author={Manuel Faysse and Hugues Sibille and Tony Wu and Bilel Omrani and Gautier Viaud and Céline Hudelot and Pierre Colombo},
      year={2024},
      eprint={2407.01449},
      archivePrefix={arXiv},
      primaryClass={cs.IR},
      url={https://arxiv.org/abs/2407.01449}, 
}
