Project description

Easiest way of fine-tuning HuggingFace video classification models.

🚀 Features

video-transformers uses:

  • PyTorch for model building and training,
  • 🤗 Accelerate utilities (e.g. its experiment trackers),

and supports:

  • fine-tuning video classifiers built from HuggingFace (transformers) and Timm image backbones,
  • experiment tracking with Tensorboard, Neptune, W&B and other trackers,
  • exporting trained models to ONNX,
  • pushing trained models to the HuggingFace Hub and loading them back.

⌛ Incoming Features

  • Automatic Gradio app / HuggingFace Space deployment for trained models (see the HuggingFace Integration section below).

🏁 Installation

  • Install PyTorch:
conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch
  • Install video-transformers:
pip install video-transformers
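A quick import check (a minimal, illustrative snippet) confirms the installation before moving on:
# Illustrative sanity check: both packages should import without errors.
import torch
import video_transformers

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())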

🔥 Usage

  • Prepare your video classification dataset in the following folder structure (.avi and .mp4 extensions are supported); a quick sanity-check sketch follows the layout below:
train_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
val_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
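The sanity-check below (a minimal sketch; train_root and val_root are placeholders for your actual paths) prints the number of videos found per label so you can confirm the layout before training:
from pathlib import Path

# Count .avi/.mp4 files per label folder to verify the expected layout.
for split_root in ["train_root", "val_root"]:  # replace with your actual paths
    for label_dir in sorted(Path(split_root).iterdir()):
        if label_dir.is_dir():
            num_videos = sum(1 for p in label_dir.iterdir() if p.suffix in {".avi", ".mp4"})
            print(f"{split_root}/{label_dir.name}: {num_videos} videos")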
  • Fine-tune CvT (from HuggingFace) + Transformer-based video classifier:
from torch.optim import AdamW
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import TransformerNeck
from video_transformers.trainer import trainer_factory

backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0))
neck = TransformerNeck(
    num_features=backbone.num_features,
    num_timesteps=8,
    transformer_enc_num_heads=4,
    transformer_enc_num_layers=2,
    dropout_p=0.1,
)

datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    num_timesteps=8,
    preprocess_input_size=224,
    preprocess_means=backbone.mean,
    preprocess_stds=backbone.std,
    preprocess_min_short_side_scale=256,
    preprocess_max_short_side_scale=320,
    preprocess_horizontal_flip_p=0.5,
)

head = LinearHead(hidden_size=neck.num_features, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)

# Create the optimizer after the model so that model.parameters() is defined.
optimizer = AdamW(model.parameters(), lr=1e-4)

Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
    optimizer=optimizer
)

trainer.fit()
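After training, a dummy forward pass can serve as a smoke test. The sketch below assumes the model accepts a (batch, channels, num_timesteps, height, width) tensor matching the num_timesteps=8 and preprocess_input_size=224 values used above; check the library's documentation for the exact input layout:
import torch

# Illustrative smoke test; the (B, C, T, H, W) input layout is an assumption.
model.eval()
with torch.no_grad():
    dummy_clip = torch.randn(1, 3, 8, 224, 224)
    logits = model(dummy_clip)
print(logits.shape)  # expected: (1, datamodule.num_classes)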
  • Fine-tune MobileViTV2 (from Timm) + GRU-based video classifier:
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.timm import TimmBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import GRUNeck
from video_transformers.trainer import trainer_factory

backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0))
neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True)

datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    num_timesteps=8,
    preprocess_input_size=224,
    preprocess_means=backbone.mean,
    preprocess_stds=backbone.std,
    preprocess_min_short_side_scale=256,
    preprocess_max_short_side_scale=320,
    preprocess_horizontal_flip_p=0.5,
)

head = LinearHead(hidden_size=neck.hidden_size, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)

Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
)

trainer.fit()

🤗 Full HuggingFace Integration

  • Push your fine-tuned model to the hub:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")

model.push_to_hub('model_name')
  • Load any pretrained video-transformers model from the hub:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained('account_name/model_name')
  • (Incoming feature) Automatically deploy a Gradio app to a HuggingFace Space:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.push_to_space('account_name/app_name')

📈 Multiple tracker support

  • Tensorboard tracker is enabled by default.

  • To add Neptune, W&B, or other trackers:

from video_transformers.tracking import NeptuneTracker
from accelerate.tracking import WandBTracker

trackers = [
    NeptuneTracker(EXPERIMENT_NAME, api_token=NEPTUNE_API_TOKEN, project=NEPTUNE_PROJECT),
    WandBTracker(project_name=WANDB_PROJECT)
]

trainer = Trainer(
    datamodule,
    model,
    trackers=trackers
)

🕸️ ONNX support

  • Convert your trained models into ONNX format for deployment:
from video_transformers import VideoClassificationModel

model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.to_onnx(quantize=False, opset_version=12, export_dir="runs/exports/", export_filename="model.onnx")
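The exported file can then be run with onnxruntime. The snippet below is a sketch, assuming the same (batch, channels, num_timesteps, height, width) input layout as during training:
import numpy as np
import onnxruntime as ort

# Load the exported model and run a dummy clip through it.
session = ort.InferenceSession("runs/exports/model.onnx")
input_name = session.get_inputs()[0].name
dummy_clip = np.random.randn(1, 3, 8, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy_clip})
print(outputs[0].shape)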
