
TransNetV2 PyTorch implementation for video scene detection

Project description

TransNet V2: Shot Boundary Detection Neural Network (PyTorch)

This repository contains a PyTorch implementation of TransNet V2: An effective deep network architecture for fast shot transition detection.

This is a PyTorch reimplementation of the TransNetV2 model that produces results identical to those of the original TensorFlow version. The code is for inference only.

Performance

F1 scores of TransNet V2 alongside our reevaluation of other publicly available state-of-the-art shot boundary detection methods:

Model                                    ClipShots   BBC Planet Earth   RAI
TransNet V2                              77.9        96.2               93.9
TransNet (github)                        73.5        92.9               94.3
Hassanien et al. (github)                75.9        92.6               93.9
Tang et al., ResNet baseline (github)    76.1        89.3               92.8
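
The scores above are F1 values in percent, i.e. the harmonic mean of precision and recall over detected transitions. The sketch below shows only that arithmetic, assuming the true-positive, false-positive and false-negative counts are already known; how a predicted transition is matched to a ground-truth one is defined by the original evaluation code and is not reproduced here. The function name is hypothetical.

```python
def f1_score(tp: int, fp: int, fn: int) -> float:
    """F1 in percent from true-positive, false-positive and false-negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0  # no correct detections at all
    # Harmonic mean of precision and recall, scaled to match the table's units.
    return 100.0 * 2 * precision * recall / (precision + recall)
```

For example, 90 correct detections with 10 false alarms and 10 misses gives precision = recall = 0.9 and therefore an F1 of about 90.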

Installation

pip install transnetv2-pytorch

Or install from source:

git clone https://github.com/allenday/transnetv2_pytorch.git
cd transnetv2_pytorch
pip install -e .

Usage

Command Line Interface

The package can be run either as a direct command or as a Python module:

# Direct command
transnetv2_pytorch path/to/video.mp4

# Python module execution
python -m transnetv2_pytorch path/to/video.mp4

CLI Arguments

# Basic usage
transnetv2_pytorch path/to/video.mp4

# Specify output file
transnetv2_pytorch path/to/video.mp4 --output predictions.txt

# Use specific device
transnetv2_pytorch path/to/video.mp4 --device cuda

# Get help for all options
transnetv2_pytorch --help
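
To drive the CLI from a Python script or pipeline, one option is to call it through subprocess. The helpers below are a hypothetical sketch, not part of the package: they use only the module invocation and the --output and --device flags documented above, and run the module with the current interpreter (sys.executable) so the same environment is used. The format of the file written by --output is not described here, so the sketch does not parse it.

```python
import subprocess
import sys
from typing import List, Optional


def build_cli_command(video_path: str, output: Optional[str] = None,
                      device: Optional[str] = None) -> List[str]:
    """Build argv for `python -m transnetv2_pytorch` using only the documented flags."""
    cmd = [sys.executable, "-m", "transnetv2_pytorch", video_path]
    if output is not None:
        cmd += ["--output", output]
    if device is not None:
        cmd += ["--device", device]
    return cmd


def run_cli(video_path: str, output: Optional[str] = None,
            device: Optional[str] = None) -> subprocess.CompletedProcess:
    """Run the CLI; check=True raises CalledProcessError on a non-zero exit status."""
    return subprocess.run(build_cli_command(video_path, output, device),
                          check=True, capture_output=True, text=True)
```

Usage: run_cli("path/to/video.mp4", output="predictions.txt", device="cuda"). Building the argument list separately keeps it easy to log or test without launching the model.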

Python API

Basic Usage

import torch
from transnetv2_pytorch import TransNetV2

# Initialize model with automatic device detection
model = TransNetV2(device='auto')  # Automatically selects best device
model.eval()

# Load weights
state_dict = torch.load("transnetv2-pytorch-weights.pth", map_location=model.device)
model.load_state_dict(state_dict)

with torch.no_grad():
    # Basic prediction (returns raw data)
    video_frames, single_frame_pred, all_frame_pred = model.predict_video("video.mp4")
    
    # Enhanced prediction with rich scene metadata
    results = model.predict_video_with_scenes("video.mp4", threshold=0.5)
    
    print(f"Video FPS: {results['fps']}")
    print(f"Total scenes: {results['total_scenes']}")
    
    # Access rich scene data with timestamps and shot IDs
    for scene in results['scenes'][:3]:
        print(f"Shot {scene['shot_id']}: "
              f"frames {scene['start_frame']}-{scene['end_frame']} "
              f"({scene['start_time']}s-{scene['end_time']}s) "
              f"probability={scene['probability']:.4f}")
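
Because each scene is a plain dict, the results can be handed to other tools without extra dependencies. A minimal sketch using the standard csv module, assuming the scene keys used in the loop above (shot_id, start_frame, end_frame, start_time, end_time, probability); the helper names are hypothetical, not part of the package. Other keys are ignored, and a missing key (for example the timestamps when no FPS is known) is written as an empty cell.

```python
import csv
import io
from typing import Dict, Iterable, TextIO

# Scene keys from the Basic Usage example.
SCENE_FIELDS = ["shot_id", "start_frame", "end_frame", "start_time", "end_time", "probability"]


def _write_scenes(scenes: Iterable[Dict], stream: TextIO) -> None:
    # extrasaction="ignore" skips unknown keys; missing keys become empty cells.
    writer = csv.DictWriter(stream, fieldnames=SCENE_FIELDS, extrasaction="ignore")
    writer.writeheader()
    writer.writerows(scenes)


def scenes_to_csv_string(scenes: Iterable[Dict]) -> str:
    """Render scenes as CSV text (header row first)."""
    buf = io.StringIO()
    _write_scenes(scenes, buf)
    return buf.getvalue()


def write_scenes_csv(scenes: Iterable[Dict], path: str) -> None:
    # newline="" lets the csv module control line endings, as its docs recommend.
    with open(path, "w", newline="") as f:
        _write_scenes(scenes, f)
```

Usage: write_scenes_csv(results["scenes"], "scenes.csv").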

Enhanced Features for Application Developers

The TransNetV2 class also exposes functionality that was previously available only through the CLI:

from transnetv2_pytorch import TransNetV2

# Automatic device detection
model = TransNetV2(device='auto')  # Chooses CUDA > MPS > CPU automatically

# Extract video FPS
fps = model.get_video_fps("video.mp4")

# Convert frame numbers to timestamps
timestamp = TransNetV2.frame_to_timestamp(frame_number=150, fps=25.0)
print(f"Frame 150 = {timestamp}s")

# Get structured scene data with metadata.
# `predictions` holds per-frame transition probabilities,
# e.g. single_frame_pred returned by model.predict_video()
scenes = model.predictions_to_scenes_with_data(
    predictions,  # numpy array or torch tensor
    fps=25.0,     # optional; FPS can also be extracted from the video
    threshold=0.5
)

# Each scene contains:
# - shot_id: Scene number (1-indexed)
# - start_frame, end_frame: Frame boundaries
# - start_time, end_time: Timestamps (if FPS available)
# - probability: Maximum probability in the scene
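
predictions_to_scenes_with_data() is the package's own helper, and its internals are not documented here. To illustrate what the fields above mean, the sketch below thresholds per-frame transition probabilities with the same rule as predictions_to_scenes() in the original TransNetV2 repository (a frame counts as a transition if its probability exceeds the threshold; a scene ends at the first frame of a transition run, and the next scene starts right after the run), then adds 1-indexed shot IDs, timestamps and the maximum probability. The function names are hypothetical, and the timestamp rule (frame index divided by FPS) is an assumption rather than confirmed package behavior.

```python
from typing import Dict, List, Optional, Sequence


def frame_to_seconds(frame: int, fps: float) -> float:
    """Assumed timestamp rule: frame index divided by frames per second."""
    return frame / fps


def scenes_from_predictions(predictions: Sequence[float], threshold: float = 0.5,
                            fps: Optional[float] = None) -> List[Dict]:
    """Split a video into scenes from per-frame transition probabilities.

    Boundary rule modeled on predictions_to_scenes() in the original TransNetV2 code.
    """
    flags = [1 if p > threshold else 0 for p in predictions]
    if not flags:
        return []

    ranges = []
    start, prev = 0, 0
    for i, flag in enumerate(flags):
        if prev == 1 and flag == 0:
            start = i                      # first frame after a transition run
        if prev == 0 and flag == 1 and i != 0:
            ranges.append((start, i))      # scene closes on the first transition frame
        prev = flag
    if flags[-1] == 0:
        ranges.append((start, len(flags) - 1))  # close the final scene
    if not ranges:
        ranges.append((0, len(flags) - 1))      # every frame flagged: treat as one scene

    scenes = []
    for shot_id, (s, e) in enumerate(ranges, start=1):  # shot IDs are 1-indexed
        scene = {
            "shot_id": shot_id,
            "start_frame": s,
            "end_frame": e,
            "probability": max(predictions[s:e + 1]),  # highest probability in the scene
        }
        if fps:
            scene["start_time"] = frame_to_seconds(s, fps)
            scene["end_time"] = frame_to_seconds(e, fps)
        scenes.append(scene)
    return scenes
```

Under this rule, frame 150 at 25 FPS maps to 6.0 s. If the package's helper follows the same boundary rule, its frame ranges for a 1-D prediction array should agree with this sketch; comparing both on a real video is a cheap way to confirm that before relying on it.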

Advanced Usage

import numpy as np
import torch
from transnetv2_pytorch import TransNetV2

# Custom device handling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransNetV2(device=device)

# Working with existing predictions
predictions = np.array([...])  # Placeholder: your existing per-frame predictions

# Convert to rich scene data
scenes = model.predictions_to_scenes_with_data(
    predictions, 
    video_path="video.mp4",  # Will auto-extract FPS
    threshold=0.5
)

# Or provide FPS directly
scenes = model.predictions_to_scenes_with_data(
    predictions, 
    fps=29.97,
    threshold=0.5
)

# Comprehensive video analysis
results = model.predict_video_with_scenes("video.mp4")
# Returns a dict with video_frames, predictions, fps, scenes and total_scenes

Device Support

This implementation supports:

  • CPU: Works on all systems
  • CUDA: For NVIDIA GPUs
  • MPS: For Apple Silicon Macs

With device='auto', the model selects the best available device in the order CUDA, MPS, CPU. On MPS, operations that MPS does not support (such as 3D convolutions) automatically fall back to the CPU.
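
The CUDA > MPS > CPU priority can be written as a small pure-Python function. This is a hypothetical sketch, not the package's code: availability is passed in as booleans so the logic can be checked without a GPU, and rejecting an unavailable explicit device is an illustrative choice. PyTorch itself offers CPU fallback for unsupported MPS operations through the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable; whether this package relies on that mechanism is not stated here.

```python
def resolve_device(requested: str, cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical resolver mirroring the documented 'auto' priority: CUDA > MPS > CPU."""
    if requested == "auto":
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    # An explicit request is honored only if that backend is actually present.
    available = {"cpu": True, "cuda": cuda_available, "mps": mps_available}
    if not available.get(requested, False):
        raise ValueError(f"device {requested!r} is not available")
    return requested
```

In a real program the flags would come from PyTorch, e.g. resolve_device("auto", torch.cuda.is_available(), torch.backends.mps.is_available()).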

Original Work & Training

This PyTorch implementation is based on the original TensorFlow version. For the following, please visit the original repository, soCzech/TransNetV2 (https://github.com/soCzech/TransNetV2):

  • Training code and datasets
  • TensorFlow implementation
  • Weight conversion utilities
  • Research replication

Credits

Original Work

This PyTorch implementation is based on the original TensorFlow TransNet V2 by Tomáš Souček and Jakub Lokoč.

If you find this work useful, please cite the original paper:

@article{soucek2020transnetv2,
    title={TransNet V2: An effective deep network architecture for fast shot transition detection},
    author={Sou{\v{c}}ek, Tom{\'a}{\v{s}} and Loko{\v{c}}, Jakub},
    year={2020},
    journal={arXiv preprint arXiv:2008.04838},
}

PyTorch Implementation

This PyTorch package adds the following to the original work:

  • Complete PyTorch reimplementation for inference
  • Cross-platform device support (CPU, CUDA, MPS)
  • Command-line interface
  • Package distribution and installation
  • Comprehensive testing and error handling

Related Papers

License

MIT License

Original work Copyright (c) 2020 Tomáš Souček, Jakub Lokoč
PyTorch implementation Copyright (c) 2025 Allen Day

See the original TransNetV2 repository for the original license.

Download files

Download the file for your platform.

Source Distribution

transnetv2_pytorch-1.0.1.tar.gz (31.3 MB)

Uploaded Source

Built Distribution


transnetv2_pytorch-1.0.1-py3-none-any.whl (31.3 MB)

Uploaded Python 3

File details

Details for the file transnetv2_pytorch-1.0.1.tar.gz.

File metadata

  • Download URL: transnetv2_pytorch-1.0.1.tar.gz
  • Upload date:
  • Size: 31.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.17

File hashes

Hashes for transnetv2_pytorch-1.0.1.tar.gz
Algorithm Hash digest
SHA256 f2b8abce464353289e61a28d89e602c7b1873c308e6d9a03266ede8e23d239e6
MD5 32d6913b04d80f38c65e8c843356591d
BLAKE2b-256 5e2cd64bf104b97f9eb0e5f29a3dcf2b11b377af74d19a7286bf9dbe793856e9
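
These digests can be used to check a downloaded file before installing it. A minimal sketch with the standard hashlib module, reading the file in 1 MiB chunks so the 31.3 MB archive is never held in memory at once; the helper names are hypothetical. The same check applies to the wheel, using its SHA256 digest listed below.

```python
import hashlib


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in chunk_size pieces."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """Compare against a published digest, ignoring case and surrounding whitespace."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

Usage: verify_sha256("transnetv2_pytorch-1.0.1.tar.gz", "f2b8abce464353289e61a28d89e602c7b1873c308e6d9a03266ede8e23d239e6") should return True for an intact download.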


File details

Details for the file transnetv2_pytorch-1.0.1-py3-none-any.whl.

File metadata

File hashes

Hashes for transnetv2_pytorch-1.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 8c2e2f671e1310c2af12491d0f3c108eaae77c42f75675fea793c0d2df272479
MD5 deb8f7d53417b91cf42d9aa7cf6d34c5
BLAKE2b-256 10aaa29fe9325dd0ba01a3f62d0a185ef9d1d67b31beb60a5f0f26435f77b86d

