Easiest way of fine-tuning HuggingFace video classification models.

🚀 Features

video-transformers uses:

and supports:

⌛ Incoming Features

🏁 Installation

  • Install PyTorch:
conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch
  • Install video-transformers:
pip install video-transformers

🔥 Usage

  • Arrange your video classification dataset in the following folder structure (.avi and .mp4 extensions are supported):
train_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
val_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
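Before training, it can save time to sanity-check that the dataset actually follows this layout. A minimal stdlib-only sketch is shown below; the helper name `check_split` is ours, not part of the video-transformers API.

```python
from pathlib import Path
import tempfile

VIDEO_EXTENSIONS = {".avi", ".mp4"}

def check_split(root):
    """Return {label: num_videos} for one split root, ignoring non-video files."""
    root = Path(root)
    counts = {}
    for label_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        videos = [f for f in label_dir.iterdir() if f.suffix.lower() in VIDEO_EXTENSIONS]
        counts[label_dir.name] = len(videos)
    return counts

# Example: build a tiny dummy tree and check it.
tmp = Path(tempfile.mkdtemp())
for label in ("label_1", "label_2"):
    (tmp / "train_root" / label).mkdir(parents=True)
    (tmp / "train_root" / label / "video_1.mp4").touch()

print(check_split(tmp / "train_root"))  # {'label_1': 1, 'label_2': 1}
```

Running this over both `train_root` and `val_root` quickly surfaces empty label folders or mismatched label sets between the splits.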
  • Fine-tune CvT (from HuggingFace) + Transformer-based video classifier:
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import TransformerNeck
from video_transformers.trainer import trainer_factory

backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0))
neck = TransformerNeck(
    num_features=backbone.num_features,
    num_timesteps=8,
    transformer_enc_num_heads=4,
    transformer_enc_num_layers=2,
    dropout_p=0.1,
)

datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    video_timesteps=8,
    video_crop_size=224,
    video_means=backbone.mean,
    video_stds=backbone.std,
    video_min_short_side_scale=256,
    video_max_short_side_scale=320,
    video_horizontal_flip_p=0.5,
)

head = LinearHead(hidden_size=neck.num_features, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)

Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
)

trainer.fit()
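To make the backbone → neck → head pipeline concrete, here is a stdlib-only sketch of how tensor shapes flow through the model above. The feature size 384 is illustrative for `microsoft/cvt-13`, and the assumption that the neck pools the time axis away before the linear head is ours, inferred from the API rather than verified against the implementation:

```python
def shape_flow(batch=4, timesteps=8, crop=224, backbone_features=384, num_classes=6):
    """Trace the (symbolic) tensor shapes through backbone -> neck -> head."""
    clips = (batch, timesteps, 3, crop, crop)       # output of the datamodule
    # TimeDistributed backbone: each frame is encoded independently,
    # so the spatial dims collapse into a per-frame feature vector.
    after_backbone = (batch, timesteps, backbone_features)
    # TransformerNeck attends over the time axis, then pools it away.
    after_neck = (batch, backbone_features)
    # LinearHead maps pooled features to class logits.
    logits = (batch, num_classes)
    return [clips, after_backbone, after_neck, logits]

for shape in shape_flow():
    print(shape)
```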
  • Fine-tune MobileViTV2 (from Timm) + GRU-based video classifier:
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.timm import TimmBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import GRUNeck
from video_transformers.trainer import trainer_factory

backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0))
neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True)

datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    video_timesteps=8,
    video_crop_size=224,
    video_means=backbone.mean,
    video_stds=backbone.std,
    video_min_short_side_scale=256,
    video_max_short_side_scale=320,
    video_horizontal_flip_p=0.5,
)

head = LinearHead(hidden_size=neck.num_features, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)

Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
)

trainer.fit()
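A note on the datamodule settings used in both examples: with `clip_duration=2` and `video_timesteps=8`, and assuming frames are sampled uniformly across each clip (a property of the underlying dataloader we have not verified), consecutive sampled frames are about 0.25 s apart:

```python
clip_duration = 2.0   # seconds covered by each sampled clip
video_timesteps = 8   # frames drawn from each clip

seconds_per_frame = clip_duration / video_timesteps
effective_fps = video_timesteps / clip_duration

print(f"one frame every {seconds_per_frame:.2f} s (~{effective_fps:.0f} fps)")
```

Increasing `video_timesteps` at a fixed `clip_duration` therefore raises the effective sampling rate (and memory use) rather than the temporal span of the clip.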

🤗 Full HuggingFace Integration

  • Push your fine-tuned model to the hub:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")

model.push_to_hub('model_name')
  • Load any pretrained video-transformer model from the hub:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained('account_name/model_name')
  • (Incoming feature) Automatically create a Gradio app and push it to a HuggingFace Space:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.push_to_space('account_name/app_name')

📈 Multiple tracker support

  • Tensorboard tracker is enabled by default.

  • To add other trackers, such as Neptune or Weights & Biases:

from video_transformers.tracking import NeptuneTracker
from accelerate.tracking import WandBTracker

trackers = [
    NeptuneTracker(EXPERIMENT_NAME, api_token=NEPTUNE_API_TOKEN, project=NEPTUNE_PROJECT),
    WandBTracker(project_name=WANDB_PROJECT)
]

trainer = Trainer(
    datamodule,
    model,
    trackers=trackers
)

🕸️ ONNX support

  • Convert your trained models into ONNX format for deployment:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.to_onnx(quantize=False, opset_version=12, export_dir="runs/exports/", export_filename="model.onnx")
