
CLIP-JAX

This repository is used to train vision models with JAX:

  • many types of model architectures
  • any sharding strategy
  • training with a contrastive loss such as CLIP, a chunked sigmoid loss, or a captioning loss such as CapPa (see the sketch below)
  • downstream fine-tuning

Refer to the report "CapPa: Training vision models as captioners" for the open-source reproduction of CapPa.
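
For intuition, here is a minimal sketch of the sigmoid contrastive loss from "Sigmoid Loss for Language Image Pre-Training" (see Citations), without the chunking across devices; the function name and arguments are illustrative, not this repository's API:

import jax
import jax.numpy as jnp

def sigmoid_contrastive_loss(image_emb, text_emb, temperature, bias):
    # image_emb, text_emb: (n, d) L2-normalized embeddings of matching pairs.
    # Pairwise logits between every image and every text in the batch.
    logits = temperature * image_emb @ text_emb.T + bias
    # Targets: +1 on the diagonal (positive pairs), -1 elsewhere (negatives).
    n = image_emb.shape[0]
    targets = 2.0 * jnp.eye(n) - 1.0
    # Sum of pairwise sigmoid losses, averaged per example.
    return -jnp.sum(jax.nn.log_sigmoid(targets * logits)) / n

Unlike the softmax-based CLIP loss, each image-text pair is scored independently, which is what makes chunking the loss across devices straightforward.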

Installation

pip install clip-jax

Note: this package is under active development; install from source for the latest version.

Usage

Use a trained model

Refer to utils/demo_cappa.ipynb.

You can find the model weights on Hugging Face.
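
If you only need the weights locally, a minimal sketch using huggingface_hub (the repo id and filenames below are placeholders; check the demo notebook for the actual ones):

from huggingface_hub import hf_hub_download

# Placeholders: substitute the repo id and filenames used in utils/demo_cappa.ipynb.
config_path = hf_hub_download(repo_id="<org>/<cappa-model>", filename="config.json")
params_path = hf_hub_download(repo_id="<org>/<cappa-model>", filename="flax_model.msgpack")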

Download training data

You can download training data from DataComp:

# clone and install datacomp

# download data
python download_upstream.py \
    --scale small --data_dir gs://my_bucket/datacomp/small --metadata_dir metadata \
    --image_size 256 --resize_mode center_crop --skip_bbox_blurring --no_resize_only_if_bigger \
    --encode_format webp --output_format tfrecord

Alternatively, you can use your own dataset; in that case, build it with img2dataset using output_format="tfrecord" (see the sketch below).
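
As a sketch, a matching img2dataset call producing TFRecord shards (the input list, columns, and worker counts are illustrative):

from img2dataset import download

# The URL list is a parquet/csv of image URLs and captions; paths are illustrative.
download(
    url_list="my_urls.parquet",
    input_format="parquet",
    url_col="url",
    caption_col="caption",
    output_folder="gs://my_bucket/my_dataset/shards",
    output_format="tfrecord",  # this repository expects tfrecord shards
    image_size=256,
    resize_mode="center_crop",
    encode_format="webp",
    processes_count=16,
    thread_count=64,
)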

Train a model

Use training/train.py to train a model. Here is an example command to train on a TPU v3-8:

python train.py \
    --assert_TPU_available \
    --config_name ../configs/small-patch16.json --dtype float32 \
    --do_train --train_folder gs://my_bucket/datacomp/small/shards \
    --output_dir gs://my_bucket/clip_model/$(date +"%Y%m%d%H%M%S") \
    --num_train_epochs 10 \
    --tokenizer_name openai/clip-vit-base-patch32 \
    --batch_size_per_node 4096 --gradient_accumulation_steps 1 \
    --learning_rate 0.00001 --warmup_steps 2000 --lr_offset 0 \
    --optim distributed_shampoo --beta1 0.9 --beta2 0.99 --weight_decay 0.0 \
    --block_size_text 512 --block_size_vision 512 --nesterov \
    --graft_type rmsprop_normalized --preconditioning_compute_steps 20 \
    --mp_devices 1 --shard_shampoo_across 2d \
    --activation_partitioning_dims 1 --parameter_partitioning_dims 1 \
    --loss_type sigmoid \
    --gradient_checkpointing \
    --unroll 100 \
    --logging_steps 100 --save_steps 5000
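
Before launching a long run, it helps to confirm that JAX actually sees the TPU, which is presumably what --assert_TPU_available guards; a quick check:

import jax

# On a TPU v3-8, this should list 8 TpuDevice entries.
devices = jax.devices()
print(devices)
assert devices[0].platform == "tpu", "JAX is not running on a TPU"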

Acknowledgements

Citations

@misc{radford2021learning,
      title={Learning Transferable Visual Models From Natural Language Supervision},
      author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
      year={2021},
      eprint={2103.00020},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{zhai2023sigmoid,
      title={Sigmoid Loss for Language Image Pre-Training},
      author={Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
      year={2023},
      eprint={2303.15343},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{zhai2022scaling,
      title={Scaling Vision Transformers}, 
      author={Xiaohua Zhai and Alexander Kolesnikov and Neil Houlsby and Lucas Beyer},
      year={2022},
      eprint={2106.04560},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{tschannen2023image,
      title={Image Captioners Are Scalable Vision Learners Too}, 
      author={Michael Tschannen and Manoj Kumar and Andreas Steiner and Xiaohua Zhai and Neil Houlsby and Lucas Beyer},
      year={2023},
      eprint={2306.07915},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{darcet2023vision,
      title={Vision Transformers Need Registers}, 
      author={Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
      year={2023},
      eprint={2309.16588},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{dehghani2023patch,
      title={Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution}, 
      author={Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey Gritsenko and Mario Lučić and Neil Houlsby},
      year={2023},
      eprint={2307.06304},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{mckinzie2024mm1,
      title={MM1: Methods, Analysis & Insights from Multimodal LLM Pre-training}, 
      author={Brandon McKinzie and Zhe Gan and Jean-Philippe Fauconnier and Sam Dodge and Bowen Zhang and Philipp Dufter and Dhruti Shah and Xianzhi Du and Futang Peng and Floris Weers and Anton Belyi and Haotian Zhang and Karanjeet Singh and Doug Kang and Ankur Jain and Hongyu Hè and Max Schwarzer and Tom Gunter and Xiang Kong and Aonan Zhang and Jianyu Wang and Chong Wang and Nan Du and Tao Lei and Sam Wiseman and Guoli Yin and Mark Lee and Zirui Wang and Ruoming Pang and Peter Grasch and Alexander Toshev and Yinfei Yang},
      year={2024},
      eprint={2403.09611},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{hsieh2023sugarcrepefixinghackablebenchmarks,
      title={SugarCrepe: Fixing Hackable Benchmarks for Vision-Language Compositionality}, 
      author={Cheng-Yu Hsieh and Jieyu Zhang and Zixian Ma and Aniruddha Kembhavi and Ranjay Krishna},
      year={2023},
      eprint={2306.14610},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
