Easiest way of fine-tuning HuggingFace video classification models.
🚀 Features
video-transformers uses:
- 🤗 accelerate for distributed training
- 🤗 evaluate for evaluation
- pytorchvideo for data loading

and supports:
- creating and fine-tuning video models using transformers and timm vision models
- experiment tracking with Layer, Neptune, TensorBoard and other trackers
- exporting fine-tuned models in ONNX format
- pushing fine-tuned models to the HuggingFace Hub
- loading pretrained models from the HuggingFace Hub
⌛ Incoming Features
- Automated Gradio app and Space creation
- Layer Hub support
🏁 Installation
- Install PyTorch:
conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch
- Install video-transformers:
pip install video-transformers
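The installation step pins torch 1.11.0 and torchvision 0.12.0. A small standard-library check can confirm which versions actually ended up in the environment. This is a minimal sketch, assuming each package was installed with standard Python distribution metadata (pip installs always have it); `installed_version` and `PACKAGES` are illustrative names, not part of video-transformers.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Distribution names to inspect ("torch" is PyTorch's distribution name).
PACKAGES = ["torch", "torchvision", "video-transformers"]


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in PACKAGES:
        print(f"{name}: {installed_version(name) or 'not installed'}")
```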
🔥 Usage
- Prepare a video classification dataset in the following folder structure (.avi and .mp4 extensions are supported):
train_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
val_root
    label_1
        video_1
        video_2
        ...
    label_2
        video_1
        video_2
        ...
    ...
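Because labels come from the folder names in this layout, it can be worth checking before training that both roots contain video files and share the same label folders. A minimal standard-library sketch; `count_videos_per_label` and `labels_missing_from_a_split` are hypothetical helpers, not part of video-transformers, and they only look one level below each label folder, as in the tree above.

```python
from pathlib import Path
from typing import Dict, Set

# Extensions listed as supported above (matched case-sensitively in this sketch).
VIDEO_EXTENSIONS = {".avi", ".mp4"}


def count_videos_per_label(root: str) -> Dict[str, int]:
    """Map each label folder directly under `root` to its number of .avi/.mp4 files."""
    counts: Dict[str, int] = {}
    for label_dir in sorted(Path(root).iterdir()):
        if not label_dir.is_dir():
            continue  # stray files at the top level are not labels
        counts[label_dir.name] = sum(
            1 for f in label_dir.iterdir() if f.is_file() and f.suffix in VIDEO_EXTENSIONS
        )
    return counts


def labels_missing_from_a_split(train_root: str, val_root: str) -> Set[str]:
    """Return labels that appear in only one of the two splits."""
    train_labels = set(count_videos_per_label(train_root))
    val_labels = set(count_videos_per_label(val_root))
    return train_labels ^ val_labels  # symmetric difference
```

A label with a count of 0, or a non-empty result from `labels_missing_from_a_split`, points at a folder worth fixing before building the datamodule.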
- Fine-tune a CvT (from HuggingFace) + Transformer-based video classifier:
from torch.optim import AdamW
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.transformers import TransformersBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import TransformerNeck
from video_transformers.trainer import trainer_factory
backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0))
neck = TransformerNeck(
    num_features=backbone.num_features,
    num_timesteps=8,
    transformer_enc_num_heads=4,
    transformer_enc_num_layers=2,
    dropout_p=0.1,
)
datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    num_timesteps=8,
    preprocess_input_size=224,
    preprocess_means=backbone.mean,
    preprocess_stds=backbone.std,
    preprocess_min_short_side_scale=256,
    preprocess_max_short_side_scale=320,
    preprocess_horizontal_flip_p=0.5,
)
head = LinearHead(hidden_size=neck.num_features, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)
# Create the optimizer only after the model exists, so model.parameters() is defined.
optimizer = AdamW(model.parameters(), lr=1e-4)
Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
    optimizer=optimizer,
)
trainer.fit()
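TimeDistributed is what lets an image backbone such as CvT or MobileViT consume video. A common way to build such a wrapper, and what this sketch assumes (it is not the library's code), is to fold the time axis into the batch axis, run the backbone once over every frame, and unfold the result so that a neck can then mix information across time. The NumPy version below uses a stand-in backbone (global average pooling) and assumes a (batch, time, channels, height, width) layout; `fake_backbone` and `time_distributed` are hypothetical names.

```python
import numpy as np


def fake_backbone(frames: np.ndarray) -> np.ndarray:
    """Stand-in image backbone: (N, C, H, W) -> (N, C) by global average pooling."""
    return frames.mean(axis=(2, 3))


def time_distributed(backbone, video: np.ndarray) -> np.ndarray:
    """Apply a per-frame backbone to a (B, T, C, H, W) array and return (B, T, F)."""
    b, t = video.shape[:2]
    frames = video.reshape(b * t, *video.shape[2:])  # fold time into batch: (B*T, C, H, W)
    features = backbone(frames)  # one backbone call over all frames: (B*T, F)
    return features.reshape(b, t, *features.shape[1:])  # unfold time: (B, T, F)


# batch_size=4 and num_timesteps=8 as in the example; 3-channel 224x224 frames.
rng = np.random.default_rng(0)
video = rng.random((4, 8, 3, 224, 224), dtype=np.float32)
features = time_distributed(fake_backbone, video)
print(features.shape)  # (4, 8, 3)
```

Here the per-frame feature size equals the channel count only because of the stand-in; a real backbone produces `backbone.num_features` values per frame, which is presumably why the necks are built with `num_features=backbone.num_features`.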
- Fine-tune a MobileViT (from timm) + GRU-based video classifier:
from video_transformers import TimeDistributed, VideoClassificationModel
from video_transformers.backbones.timm import TimmBackbone
from video_transformers.data import VideoDataModule
from video_transformers.heads import LinearHead
from video_transformers.necks import GRUNeck
from video_transformers.trainer import trainer_factory
backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0))
neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True)
datamodule = VideoDataModule(
    train_root=".../ucf6/train",
    val_root=".../ucf6/val",
    clip_duration=2,
    train_dataset_multiplier=1,
    batch_size=4,
    num_workers=4,
    num_timesteps=8,
    preprocess_input_size=224,
    preprocess_means=backbone.mean,
    preprocess_stds=backbone.std,
    preprocess_min_short_side_scale=256,
    preprocess_max_short_side_scale=320,
    preprocess_horizontal_flip_p=0.5,
)
head = LinearHead(hidden_size=neck.hidden_size, num_classes=datamodule.num_classes)
model = VideoClassificationModel(backbone, head, neck)
Trainer = trainer_factory("single_label_classification")
trainer = Trainer(
    datamodule,
    model,
)
trainer.fit()
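Both examples pass clip_duration=2 and num_timesteps=8 to VideoDataModule, which suggests each sample is a 2-second clip reduced to 8 frames (in the CvT example this also matches TransformerNeck(num_timesteps=8)). Assuming the reduction is uniform temporal subsampling, which pytorchvideo provides as UniformTemporalSubsample (whether VideoDataModule uses exactly that is an assumption), and assuming a 30 fps source, the selected frame indices can be sketched as follows. `uniform_frame_indices` is a hypothetical helper.

```python
from typing import List


def uniform_frame_indices(num_frames: int, num_samples: int) -> List[int]:
    """Return `num_samples` evenly spaced indices into a clip of `num_frames` frames.

    Integer-arithmetic version of linspace(0, num_frames - 1, num_samples) followed by
    truncation to int, which approximates pytorchvideo's UniformTemporalSubsample.
    """
    if num_frames < 1 or num_samples < 1:
        raise ValueError("num_frames and num_samples must be positive")
    if num_samples == 1:
        return [0]
    return [i * (num_frames - 1) // (num_samples - 1) for i in range(num_samples)]


fps = 30  # assumption: source videos recorded at 30 frames per second
clip_duration = 2  # seconds, as in VideoDataModule(clip_duration=2)
num_timesteps = 8  # as in VideoDataModule(num_timesteps=8)
print(uniform_frame_indices(fps * clip_duration, num_timesteps))
# [0, 8, 16, 25, 33, 42, 50, 59]
```

Under these assumptions, lowering fps or raising clip_duration only changes the spacing between sampled frames; the model always receives num_timesteps frames per clip.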
🤗 Full HuggingFace Integration
- Push your fine-tuned model to the hub:
from video_transformers import VideoClassificationModel
model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.push_to_hub('model_name')
- Load any pretrained video-transformer model from the hub:
from video_transformers import VideoClassificationModel
model = VideoClassificationModel.from_pretrained('account_name/model_name')
- (Incoming feature) Automatically create a Gradio app and push it to a HuggingFace Space:
from video_transformers import VideoClassificationModel
model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.push_to_space('account_name/app_name')
📈 Multiple tracker support
- TensorBoard tracking is enabled by default.
- To add Neptune, Layer or other trackers (here, Neptune and Weights & Biases):
from video_transformers.tracking import NeptuneTracker
from accelerate.tracking import WandBTracker
trackers = [
    NeptuneTracker(EXPERIMENT_NAME, api_token=NEPTUNE_API_TOKEN, project=NEPTUNE_PROJECT),
    # accelerate's WandBTracker takes the W&B project name as its first argument (run_name)
    WandBTracker(WANDB_PROJECT),
]
trainer = Trainer(
    datamodule,
    model,
    trackers=trackers,
)
🕸️ ONNX support
- Convert your trained models into ONNX format for deployment:
from video_transformers import VideoClassificationModel
model = VideoClassificationModel.from_pretrained("runs/exp/checkpoint")
model.to_onnx(quantize=False, opset_version=12, export_dir="runs/exports/", export_filename="model.onnx")
File details
Details for the file video-transformers-0.0.5.tar.gz.
File metadata
- Download URL: video-transformers-0.0.5.tar.gz
- Upload date:
- Size: 24.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 6529b8abe23ab2fbb010c0a840b1fa2d1b0a295fd72719a61d4fbd2d401d5d4f
MD5 | 9ed037b62db9d1656ab83ae30a6995aa
BLAKE2b-256 | 061f803b4f84f9594b311b97975470b32258490d48dd5e2b62e0ef2d3eb39acd
File details
Details for the file video_transformers-0.0.5-py3-none-any.whl.
File metadata
- Download URL: video_transformers-0.0.5-py3-none-any.whl
- Upload date:
- Size: 28.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4ba13b720444137b7ded2daad64d52f553763a79d5cea5504b02ea7f29e8c7f5
MD5 | 0aaba94b2737c4d7247afb25b0bd05f3
BLAKE2b-256 | 8c572b812b4962dd2ce7f01596e3554f31cc2dc94b589ba7c83e1d66cc087dd1
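The SHA256 digests listed for the two files can be used to check a downloaded archive before installing it. A minimal standard-library sketch; `sha256_of_file`, `matches_digest` and `PUBLISHED_SHA256` are illustrative names, not part of any packaging tool.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks by default."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path: str, expected_hex: str) -> bool:
    """Compare a file's SHA256 digest with a published hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected_hex.strip().lower()


# SHA256 digests copied from the file details for version 0.0.5.
PUBLISHED_SHA256 = {
    "video-transformers-0.0.5.tar.gz": "6529b8abe23ab2fbb010c0a840b1fa2d1b0a295fd72719a61d4fbd2d401d5d4f",
    "video_transformers-0.0.5-py3-none-any.whl": "4ba13b720444137b7ded2daad64d52f553763a79d5cea5504b02ea7f29e8c7f5",
}

# Example, after downloading the wheel into the current directory:
# matches_digest("video_transformers-0.0.5-py3-none-any.whl",
#                PUBLISHED_SHA256["video_transformers-0.0.5-py3-none-any.whl"])
```

pip can enforce the same check at install time through hash-checking mode, using a requirements line such as `video-transformers==0.0.5 --hash=sha256:<digest>`; in that mode every dependency must also be pinned with a hash.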