
Training of CLIP in JAX


CLIP-JAX

This repository is used to train CLIP models from 🤗 transformers using JAX.

Installation

From the root of a local clone of the repository:

pip install -e .

Usage

  1. Use dataset/prepare_dataset.ipynb to prepare your dataset.
  2. Train the model with training/train_clip.py.
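The contents of training/train_clip.py are not shown on this page. The following is a minimal sketch, assuming the script optimizes the standard symmetric contrastive objective that CLIP is trained with. The toy linear projections (`image_proj`, `text_proj`), the helper names (`clip_loss`, `init_params`, `loss_fn`, `train_step`), the learning rate, and plain SGD are all hypothetical stand-ins for the real encoders and optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical learning rate for this sketch; not taken from train_clip.py.
LEARNING_RATE = 0.01


def clip_loss(image_emb, text_emb, logit_scale):
    """Symmetric contrastive loss over a batch of matching image/text pairs.

    Row i of image_emb and row i of text_emb are assumed to be the same pair.
    logit_scale is the log of the inverse temperature.
    """
    # L2-normalize so dot products are cosine similarities.
    image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
    text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)
    # logits[i, j]: scaled similarity between image i and text j.
    logits = jnp.exp(logit_scale) * image_emb @ text_emb.T
    # The correct match for every row (image->text) and column (text->image)
    # lies on the diagonal.
    loss_i2t = -jnp.mean(jnp.diagonal(jax.nn.log_softmax(logits, axis=1)))
    loss_t2i = -jnp.mean(jnp.diagonal(jax.nn.log_softmax(logits, axis=0)))
    return (loss_i2t + loss_t2i) / 2


def init_params(key, image_dim, text_dim, embed_dim):
    """Hypothetical toy towers: one linear projection per modality."""
    k_img, k_txt = jax.random.split(key)
    return {
        "image_proj": jax.random.normal(k_img, (image_dim, embed_dim)) / jnp.sqrt(image_dim),
        "text_proj": jax.random.normal(k_txt, (text_dim, embed_dim)) / jnp.sqrt(text_dim),
        # CLIP initializes the temperature to 0.07, stored here as log(1 / 0.07).
        "logit_scale": jnp.log(1 / 0.07),
    }


def loss_fn(params, image_features, text_features):
    image_emb = image_features @ params["image_proj"]
    text_emb = text_features @ params["text_proj"]
    return clip_loss(image_emb, text_emb, params["logit_scale"])


@jax.jit
def train_step(params, image_features, text_features):
    loss, grads = jax.value_and_grad(loss_fn)(params, image_features, text_features)
    # Plain SGD applied to every leaf of the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Usage on random stand-in features for a batch of 4 image/text pairs.
key = jax.random.PRNGKey(0)
k_params, k_img, k_txt = jax.random.split(key, 3)
params = init_params(k_params, image_dim=32, text_dim=16, embed_dim=8)
image_features = jax.random.normal(k_img, (4, 32))
text_features = jax.random.normal(k_txt, (4, 16))

params, first_loss = train_step(params, image_features, text_features)
for _ in range(100):
    params, loss = train_step(params, image_features, text_features)
print(float(first_loss), float(loss))
```

In a real run the linear projections would be replaced by the CLIP vision and text towers and SGD by a proper optimizer; the loss itself does not change.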

Supported downstream tasks

  • Image classification with FlaxCLIPVisionModelForImageClassification
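How FlaxCLIPVisionModelForImageClassification is built is not documented on this page, so the sketch below does not reproduce it. As a hedged illustration of the general idea only, it trains a linear classification head on pooled image features. The features are random stand-in data; in practice they would come from a CLIP vision encoder (for example `FlaxCLIPModel.get_image_features` in 🤗 transformers). The names `init_head`, `head_logits`, `head_loss`, and `probe_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_head(key, feature_dim, num_classes):
    """Hypothetical linear head: logits = features @ w + b."""
    return {
        "w": jax.random.normal(key, (feature_dim, num_classes)) * 0.01,
        "b": jnp.zeros(num_classes),
    }


def head_logits(params, features):
    return features @ params["w"] + params["b"]


def head_loss(params, features, labels):
    """Mean softmax cross-entropy against integer class labels."""
    log_probs = jax.nn.log_softmax(head_logits(params, features))
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def probe_step(params, features, labels):
    loss, grads = jax.value_and_grad(head_loss)(params, features, labels)
    # Plain SGD on the head only; the encoder features are treated as fixed.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


# Random stand-ins for pooled image features: 8 images, 16-dim features, 3 classes.
k_head, k_feat, k_lab = jax.random.split(jax.random.PRNGKey(0), 3)
features = jax.random.normal(k_feat, (8, 16))
labels = jax.random.randint(k_lab, (8,), 0, 3)
params = init_head(k_head, feature_dim=16, num_classes=3)

params, first_loss = probe_step(params, features, labels)
for _ in range(200):
    params, loss = probe_step(params, features, labels)
predictions = jnp.argmax(head_logits(params, features), axis=-1)
print(float(first_loss), float(loss), predictions)
```

Freezing the encoder and training only the head is the cheapest variant; fine-tuning the whole vision tower would pass the image pixels through the encoder inside `head_loss` instead.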

TODO

  • Add guides
  • Add pre-trained models
  • Add more downstream tasks

Download files

Source distribution: clip_jax-0.0.1.tar.gz (22.7 kB)

Built distribution: clip_jax-0.0.1-py2.py3-none-any.whl (24.5 kB), for Python 2 and Python 3