I3D in Jax
I3D-Jax
A JAX/Flax port of the original Kinetics-400 I3D network from TensorFlow.
Installation
pip install i3d-jax
Usage
For convenience, we provide a wrapper that runs inference on input videos:
import i3d_jax
import numpy as np
video = np.random.uniform(-1, 1, size=(1, 16, 224, 224, 3)).astype(np.float32)  # B x T x H x W x C, values in [-1, 1]
i3d = i3d_jax.I3DWrapper(replicate=False)  # set to True to use pmap automatically
# out is a tuple of:
# 1) logits
# 2) a dictionary mapping endpoint names to the features at each endpoint
out = i3d(video)
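The wrapper takes a float video of shape B x T x H x W x C with values in [-1, 1] and returns the logits first. The sketch below shows one way to go from decoded uint8 RGB frames to that layout, and from logits to the top-k class indices. Both helper names are hypothetical and not part of `i3d_jax`. It assumes the frames were already resized by another tool (as we understand it, the original TF I3D pipeline resized the shorter side to 256 before a central 224 x 224 crop) and that the logits have shape (B, 400), one score per Kinetics-400 class.

```python
import numpy as np


def preprocess_frames(frames: np.ndarray, crop_size: int = 224) -> np.ndarray:
    """Hypothetical helper: uint8 RGB frames (T, H, W, 3) -> float32 (1, T, crop, crop, 3) in [-1, 1].

    Assumes the frames were already resized so that H and W are >= crop_size.
    """
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(f"expected (T, H, W, 3) frames, got shape {frames.shape}")
    _, h, w, _ = frames.shape
    if h < crop_size or w < crop_size:
        raise ValueError(f"frames {frames.shape} are smaller than crop size {crop_size}")
    # Central spatial crop.
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    cropped = frames[:, top:top + crop_size, left:left + crop_size, :]
    # Rescale pixel values from [0, 255] to [-1, 1].
    scaled = cropped.astype(np.float32) / 127.5 - 1.0
    # Prepend the batch axis (B = 1).
    return scaled[np.newaxis]


def top_k_predictions(logits, k: int = 5):
    """Hypothetical helper: softmax over the last axis, then the k most likely class indices."""
    logits = np.asarray(logits, dtype=np.float32)
    # Subtract the per-row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Class indices sorted by descending probability, truncated to k.
    top_idx = np.argsort(-probs, axis=-1)[..., :k]
    top_probs = np.take_along_axis(probs, top_idx, axis=-1)
    return top_idx, top_probs
```

For example, `logits, endpoints = i3d(preprocess_frames(raw_frames))` followed by `top_k_predictions(logits)`, where `raw_frames` is a decoded clip of your own. Turning indices into action names needs a Kinetics-400 label list, which is not part of this description.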
You can also get the model and its variables separately:
import i3d_jax
# Load model
i3d_model = i3d_jax.InceptionI3d()
# Load variables (params + batch_stats)
variables = i3d_jax.load_variables(replicate=False)
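When working with the model and variables directly, it can help to inspect what `load_variables` returned. The helper below is a hypothetical sketch, not part of `i3d_jax`. It assumes the variables follow the usual Flax layout, a nested mapping with top-level `params` and `batch_stats` collections whose leaves are arrays, and simply counts the array elements in each collection.

```python
from collections.abc import Mapping

import numpy as np


def count_elements(tree) -> int:
    """Recursively count array elements in a nested mapping of arrays."""
    if isinstance(tree, Mapping):
        return sum(count_elements(value) for value in tree.values())
    # Leaves are assumed to be array-like (NumPy or JAX arrays).
    return int(np.size(tree))


def summarize_variables(variables) -> dict:
    """Hypothetical helper: element count per top-level collection, e.g. 'params', 'batch_stats'."""
    return {name: count_elements(tree) for name, tree in variables.items()}
```

Calling `summarize_variables(variables)` gives one count per collection. If the variables were loaded with `replicate=True`, each array presumably carries an extra leading device axis, so the counts would be multiplied by the number of devices.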
Download files
Source Distribution
i3d_jax-0.0.1.tar.gz (47.2 MB)
File details
Details for the file i3d_jax-0.0.1.tar.gz.
File metadata
- Download URL: i3d_jax-0.0.1.tar.gz
- Upload date:
- Size: 47.2 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2e8e4b796d6d54fcbd8fb2aae71687db1be0f6ba6c364c75c560e9731f842867
MD5 | b35e50e2f9e51ecbd0f261d0275bfeb4
BLAKE2b-256 | bcd0cf1c881e3d30fe319092c398f9d4663f0822b36e2ec5ee25c538dd953a96
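To check a downloaded copy of the source distribution against the SHA256 digest listed above, you can hash it with Python's standard library. This is a minimal sketch; the path you pass is wherever you saved the archive.

```python
import hashlib

# SHA256 digest listed above for i3d_jax-0.0.1.tar.gz.
EXPECTED_SHA256 = "2e8e4b796d6d54fcbd8fb2aae71687db1be0f6ba6c364c75c560e9731f842867"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in chunks so a large archive needs little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's digest matches `expected` (compared case-insensitively)."""
    return sha256_of(path) == expected.strip().lower()
```

For an intact download, `verify_sha256("i3d_jax-0.0.1.tar.gz")` should return `True`. Alternatively, `pip hash i3d_jax-0.0.1.tar.gz` prints the SHA256 digest of a local file.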