A small example package
DisJR Networks: Disjointed Representation Learning for Better Fall Recognition
DisJR (Disjointing Representation) is a simple and effective computational unit that disjoints (separates) the human from unwanted elements (e.g., the background) in a video scene without any hints about the human region. Our proposed DisJR operation is designed to capture relations between the human and various surrounding contexts from the data itself, not from preprocessed data. In contrast to existing methods that use preprocessed data to obtain the human region, the proposed DisJR operations do not rely on a fixed region. Instead, the proposed method learns how to separate the representations of the human region and of unwanted elements through explicit feature-level decomposition, i.e., DisJR. In this way, the model captures more general representations of the video scene.
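The description above does not give the DisJR equations, so the following is only a conceptual sketch of what an "explicit feature-level decomposition" can look like, not the authors' operation. It assumes that a learned soft spatial mask splits a feature map into a human part and a complementary context part. `decompose`, `w`, and `b` are hypothetical names, and NumPy stands in for PyTorch to keep the sketch dependency-light.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def decompose(features, w, b=0.0):
    """Hypothetical decomposition of a (C, H, W) feature map into two parts.

    w: (C,) weights of an assumed 1x1 projection that scores each location.
    Returns (human_part, context_part, mask). NOT the DisJR formulation.
    """
    # Score each spatial location by projecting over channels -> (H, W)
    scores = np.tensordot(w, features, axes=([0], [0])) + b
    # Soft mask in (0, 1): high ~ "human", low ~ "unwanted context"
    mask = sigmoid(scores)
    human = features * mask            # (H, W) mask broadcast over channels
    context = features * (1.0 - mask)  # complementary part
    return human, context, mask


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 7, 7))   # dummy feature map: C=4, H=W=7
w = rng.standard_normal(4)
human, context, mask = decompose(x, w)
print(np.allclose(human + context, x))  # True: the two parts sum back to x
```

Because the two parts use complementary weights `mask` and `1 - mask`, no information is discarded; a real model would additionally need a training signal that pushes the mask toward the human region.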
Model overview
Examples
Here is a code example for using DisJRNet:
```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from disjrnet.model.models import DisJRNet
from disjrnet.model.loss import compute_loss

alpha = 2.0               # hyperparameter (passed to the model as `margin`)
fusion_method = 'gating'  # candidates = 'gating' | 'gconv'

# 2D CNN
model = DisJRNet(num_classes=10,
                 base_model='resnet50',
                 dimension=2,
                 dropout=0.8,
                 margin=alpha,
                 fusion_method=fusion_method)

# # 3D CNN
# model = DisJRNet(num_classes=10,
#                  base_model='r2plus1d_18',
#                  dimension=3,
#                  dropout=0.8,
#                  margin=alpha,
#                  fusion_method=fusion_method)

# classification loss = CE
criterion = nn.CrossEntropyLoss()

# dummy data example: 10 RGB images of size 112x112, one target per sample
inps = torch.randn(10, 3, 112, 112)
tgts = torch.arange(10, dtype=torch.float32).view(10, -1)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=8)

# take the first batch (8 samples)
loader_iter = iter(loader)
inputs, target = next(loader_iter)

logits = model(inputs)
loss = compute_loss(model, criterion, logits, target)  # loss helper from disjrnet
pred = logits.argmax(1)                                # predicted class per sample
print(f"loss : {loss:.4f}, pred : {pred}, target : {target.view(-1)}")
```
Training scripts
Here are example scripts for training the models available in this project:
- DisJRNet

```bash
# FDD
python main.py --dataset FDD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --fusion_method gating --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --c 5.0 --arch DisJRNet --gpu_ids 0
# URFD
python main.py --dataset URFD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --fusion_method gating --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --c 2.0 --arch DisJRNet --gpu_ids 0
```

- Baseline

```bash
# FDD
python main.py --dataset FDD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --arch Baseline --gpu_ids 0
# URFD
python main.py --dataset URFD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --arch Baseline --gpu_ids 0
```
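The scripts pass `--monitor val_f1`, which presumably selects checkpoints by the validation F1 score; how `main.py` computes it is not shown here. For the binary fall / no-fall setting (`--num_classes 2`), the standard F1 score can be sketched in plain Python as below. The function name `f1_score` and the convention that label 1 marks the positive (fall) class are assumptions for illustration.

```python
def f1_score(y_true, y_pred, positive=1):
    """Binary F1: harmonic mean of precision and recall for `positive`.

    Assumes label 1 is the positive (fall) class by default.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    if tp == 0:
        # No true positives: precision or recall is zero (or undefined)
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# 2 true positives, 1 false positive, 1 false negative -> F1 = 2/3
print(round(f1_score([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0]), 4))  # 0.6667
```

F1 is a common choice for fall detection because falls are usually the minority class, so plain accuracy can look high even when most falls are missed.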
Results
Activation Map Visualization
Project details
Download files
Source Distributions
Built Distribution
File details
Details for the file disjrnet_pytorch-0.1.1-py3-none-any.whl
File metadata
- Download URL: disjrnet_pytorch-0.1.1-py3-none-any.whl
- Upload date:
- Size: 12.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.26.0 requests-toolbelt/0.9.1 urllib3/1.26.7 tqdm/4.62.3 importlib-metadata/4.8.1 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.4 CPython/3.6.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 47adaca84bfda4ecffa15fb21161f64f840430423aada482b206706e53cc5362
MD5 | eeedd4e8bc35f38ad35d79dbeb8666de
BLAKE2b-256 | 3fa1e74448069539f63aaf1d39318e7610cc3e59f7f30d51c1c0c637f984fc37