
A highly memory-efficient contrastive loss.


Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss

If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏


💡 Some other multimodal foundation model projects from our team may interest you ✨.

Video-LLaMA: An Instruction-tuned Audio-Visual Language Model for Video Understanding
Hang Zhang, Xin Li, Lidong Bing

VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding
Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing

VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs
Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing

The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio
Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing

📰 News

  • [2024.10.18] Released the training and evaluation code of Inf-CLIP.

🛠️ Requirements and Installation

Basic Dependencies:

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • CUDA Version >= 11.8

[Remote] Install Inf-CL from PyPI:

# install the released package from PyPI
pip install inf_cl

[Local] Install Inf-CL from the root of the cloned repository (clone it as shown below):

pip install -e .

Install required packages:

git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt

⭐ Features

inf_cl is the Triton implementation of the Inf-CL loss.

inf_clip is the CLIP training codebase with the Inf-CL loss and other training features.
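Why computing the loss tile by tile saves memory can be sketched in a few lines. The NumPy code below is a minimal single-process sketch of an online (running-max) log-sum-exp over tiles of the similarity matrix, assuming diagonal labels and that q and k have the same number of rows. It illustrates the general technique only and is not the inf_cl Triton kernel; `tiled_contrastive_loss`, `dense_contrastive_loss` and `tile` are hypothetical names.

```python
import numpy as np


def dense_contrastive_loss(q, k, scale):
    """Reference: materializes the full (n x n) logit matrix."""
    logits = scale * q @ k.T
    row_max = logits.max(axis=1, keepdims=True)
    lse = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    return float(np.mean(lse - np.diag(logits)))


def tiled_contrastive_loss(q, k, scale, tile=1024):
    """Same loss, but only one (tile x tile) block of logits exists at a time."""
    n = q.shape[0]
    row_max = np.full(n, -np.inf)        # running max of each query row
    row_sum = np.zeros(n)                # running sum of exp(logit - row_max)
    pos = scale * np.sum(q * k, axis=1)  # positive logits (diagonal labels)
    for i in range(0, n, tile):
        rows = slice(i, i + tile)
        for j in range(0, k.shape[0], tile):
            block = scale * q[rows] @ k[j:j + tile].T
            new_max = np.maximum(row_max[rows], block.max(axis=1))
            # Rescale the old partial sum to the new max, then add this block.
            row_sum[rows] = (row_sum[rows] * np.exp(row_max[rows] - new_max)
                             + np.exp(block - new_max[:, None]).sum(axis=1))
            row_max[rows] = new_max
    lse = row_max + np.log(row_sum)
    return float(np.mean(lse - pos))


rng = np.random.default_rng(0)
q = rng.standard_normal((300, 64))
k = rng.standard_normal((300, 64))
q /= np.linalg.norm(q, axis=1, keepdims=True)
k /= np.linalg.norm(k, axis=1, keepdims=True)
scale = 1 / 0.07
print(np.isclose(tiled_contrastive_loss(q, k, scale, tile=128),
                 dense_contrastive_loss(q, k, scale)))  # True
```

Only a `tile x tile` block of logits is alive at once, so the memory for the logits no longer grows with the square of the batch size. A real implementation also needs a memory-efficient backward pass, which this forward-only sketch omits.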

🔑 Usage

A simple example of how to use our Inf-CL loss for contrastive learning. Run it with the following command:

torchrun --nproc_per_node 2 tests/example.py

Example code:
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3        # Number of attention heads
    seq_length_q = 32768 # Global number of queries (split evenly across ranks)
    seq_length_k = 32768 # Global number of keys (split evenly across ranks)
    d_model = 256        # Dimension of each head

    # Randomly initialize this rank's shard of the inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    # Log of the initial logit scale (temperature 0.07)
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key
    l = l.requires_grad_() # Log of the logit scale

    return q, k, l


if __name__ == "__main__":
    # Initialize the distributed environment (torchrun sets the required environment variables)
    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    torch.cuda.set_device(rank)

    # Taking image-text contrastive learning as an example: q holds the global image features,
    # k holds the text features, and l is the log of the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # By default, the labels are the diagonal elements (query i matches key i):
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)
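In the example, each rank only creates a 1/world_size shard of the queries and keys. The NumPy sketch below simulates, in a single process, how a loss with globally diagonal labels decomposes across ranks: each rank scores its local queries against all keys (as if they had been gathered), and its local query i is labelled with the global index rank * n_local + i. Whether cal_inf_loss applies exactly this label offset internally is an assumption not confirmed on this page; `row_cross_entropy` and `sharded_loss` are hypothetical helpers.

```python
import numpy as np


def row_cross_entropy(logits, labels):
    """Per-row cross-entropy with integer labels, computed stably."""
    row_max = logits.max(axis=1, keepdims=True)
    lse = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    return lse - logits[np.arange(len(labels)), labels]


def sharded_loss(q, k, scale, world_size):
    """Simulate world_size ranks, each holding an equal shard of the queries."""
    n = q.shape[0]
    assert n % world_size == 0, "equal shards assumed"
    n_local = n // world_size
    losses = []
    for rank in range(world_size):
        q_local = q[rank * n_local:(rank + 1) * n_local]
        logits = scale * q_local @ k.T                # local queries vs. all keys
        labels = rank * n_local + np.arange(n_local)  # offset diagonal labels
        losses.append(row_cross_entropy(logits, labels))
    return float(np.concatenate(losses).mean())


rng = np.random.default_rng(1)
q = rng.standard_normal((64, 16))
k = rng.standard_normal((64, 16))
q /= np.linalg.norm(q, axis=1, keepdims=True)
k /= np.linalg.norm(k, axis=1, keepdims=True)
scale = 1 / 0.07
global_loss = float(row_cross_entropy(scale * q @ k.T, np.arange(64)).mean())
print(np.isclose(sharded_loss(q, k, scale, world_size=4), global_loss))  # True
```

The concatenated per-rank losses reproduce the global loss exactly. A symmetric CLIP-style objective would additionally average this with the text-to-image direction.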

🚀 Main Results

Memory Cost

* denotes adopting the "data offload" strategy.

Max Supported Batch Size

Speed

Batch Size Scaling

Training at a larger data scale requires a larger batch size.
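For a sense of scale, here is back-of-the-envelope arithmetic (ours, not taken from the results above) for the memory needed just to store the similarity logits. The helper names and the 4096 x 4096 tile size are hypothetical; only forward logits in float32 are counted, ignoring activations and gradients, and "256k" is taken to mean 262,144.

```python
GIB = 2 ** 30


def dense_logits_bytes(global_batch, world_size=1, bytes_per_elem=4):
    """Bytes for one rank's rows of a dense (global_batch x global_batch) logit matrix."""
    local_rows = global_batch // world_size
    return local_rows * global_batch * bytes_per_elem


def tile_logits_bytes(tile_rows, tile_cols, bytes_per_elem=4):
    """Bytes for a single (tile_rows x tile_cols) block of logits."""
    return tile_rows * tile_cols * bytes_per_elem


print(dense_logits_bytes(262144) / GIB)                # 256.0
print(dense_logits_bytes(262144, world_size=8) / GIB)  # 32.0
print(tile_logits_bytes(4096, 4096) / GIB)             # 0.0625
```

The dense cost grows quadratically with the batch size, and sharding it across 8 ranks only divides it by 8, whereas a fixed-size tile costs the same regardless of batch size.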

🏗️ Training & Evaluation

Quick Start

To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate it on mainstream CLIP benchmarks.

  1. Training Data Structure:
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│   │   ├── 00000.tar
│   │   ├── 00001.tar
│   │   ├── ...
│   │   └── 41407.tar
  2. Command:
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
  3. Evaluation Data Structure:
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of ImageNet
│   │   └── val/
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   ├── clip-benchmark/ # bash datasets/benchmarks_download.sh
│   │   ├── wds_mscoco_captions
│   │   ├── wds_flickr8k
│   │   ├── wds_flickr30k
│   │   ├── wds_imagenet1k
│   │   ├── wds_imagenetv2
│   │   ├── wds_imagenet_sketch
│   │   ├── wds_imagenet-a
│   │   ├── wds_imagenet-r
│   │   ├── wds_imagenet-o
│   │   └── wds_objectnet
  4. Command:
# ImageNet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
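Before launching the training scripts, it can help to confirm that all tar shards are in place. A minimal sketch, assuming shards are named with zero-padded consecutive indices as in the directory tree above; `missing_shards` is a hypothetical helper, and the expected counts (e.g. 302 for cc3m, from 0000.tar to 0301.tar) are read off the tree and may differ for your own download.

```python
from pathlib import Path


def missing_shards(shard_dir, expected_count, width):
    """Return the expected shard file names (e.g. '0000.tar') not found in shard_dir."""
    present = {p.name for p in Path(shard_dir).glob("*.tar")}
    expected = (f"{i:0{width}d}.tar" for i in range(expected_count))
    return [name for name in expected if name not in present]


if __name__ == "__main__":
    # Counts and zero-padding widths read off the directory tree above;
    # run from the Inf-CLIP root so that "datasets/" resolves.
    for name, count, width in [("cc3m", 302, 4), ("cc12m", 1045, 4), ("laion400m", 41408, 5)]:
        shard_dir = Path("datasets") / name
        if shard_dir.is_dir():
            missing = missing_shards(shard_dir, count, width)
            print(f"{name}: {len(missing)} missing shard(s)")
```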

📑 Citation

If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:

@article{damonlpsg2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:},
  year={2024},
  url = {https://arxiv.org/abs/}
}

👍 Acknowledgement

The codebase of Inf-CLIP is adapted from OpenCLIP. We are also grateful to the projects from which Inf-CL arose.

🔒 License

This project is released under the Apache 2.0 license, as found in the LICENSE file. The service is a research preview intended for non-commercial use ONLY, subject to the model license of CLIP, the Terms of Use of the data generated by OpenAI, and the terms of LAION. Please get in touch with us if you find any potential violations.



Download files


Source Distribution

inf_cl-1.1.tar.gz (16.6 kB)

Uploaded Source

Built Distribution

inf_cl-1.1-py3-none-any.whl (15.3 kB)

Uploaded Python 3

File details

Details for the file inf_cl-1.1.tar.gz.

File metadata

  • Download URL: inf_cl-1.1.tar.gz
  • Upload date:
  • Size: 16.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.8.10

File hashes

Hashes for inf_cl-1.1.tar.gz
Algorithm Hash digest
SHA256 ffd70ee6d269165add0e1c9a1bb2dce5220cda467ceca521a6618d534da73790
MD5 93d08b4f4bb851444add8f721587bc64
BLAKE2b-256 21d199e8ea570a8bf37cc1224f9f5e8396a48446f3ec07f2e7821b8e87151c98


File details

Details for the file inf_cl-1.1-py3-none-any.whl.

File metadata

  • Download URL: inf_cl-1.1-py3-none-any.whl
  • Upload date:
  • Size: 15.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.8.10

File hashes

Hashes for inf_cl-1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 2086c7cdff7033a5c0ea130b79891aa765124c52b24341eda4ec0dbd10558beb
MD5 256668bd06596718833605db8e7fa698
BLAKE2b-256 e2e31753a3bb8fbb4d29732d5235c19e1cf7d6446f172cd8b9d646c4323da7e4

