
A small example package

Project description

DisJR Networks: Disjointed Representation Learning for Better Fall Recognition

DisJR (Disjointing Representation) is a simple yet effective computational unit that separates (disjoints) the human from unwanted elements (e.g., the background) in a video scene without any hints about where the human is. The DisJR operation is designed to capture the relations between the human and the various surrounding contexts from the data itself rather than from preprocessed data. Unlike existing methods, which use preprocessed data to locate the human region, DisJR does not rely on a fixed region. Instead, the model learns to separate the representations of the human region and of unwanted elements through an explicit feature-level decomposition. As a result, the model learns more general representations of the video scene.
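The description above stays at a high level. As a hedged illustration of the general idea of a feature-level decomposition learned without region hints, the toy PyTorch module below predicts a soft spatial mask from the features themselves and splits the feature map into a "human" part and a complementary "context" part. It is not the DisJR operation from the paper or the package; the class name ToyDisjointing and the mask design are assumptions made for this sketch.

```python
import torch
import torch.nn as nn


class ToyDisjointing(nn.Module):
    """Hypothetical toy decomposition, NOT the actual DisJR operation.

    A 1x1 convolution predicts a soft spatial mask from the input features,
    so no external human-region annotation is used. The mask splits the
    feature map into a masked part and its complement.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.mask_head = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, feats: torch.Tensor):
        # feats: (batch, channels, height, width)
        mask = torch.sigmoid(self.mask_head(feats))  # (batch, 1, H, W), values in (0, 1)
        human = feats * mask                         # features weighted toward the learned region
        context = feats * (1.0 - mask)               # the complement (e.g., background)
        # Global average pooling gives one vector per branch: (batch, channels).
        return human.mean(dim=(2, 3)), context.mean(dim=(2, 3)), mask


if __name__ == "__main__":
    torch.manual_seed(0)
    block = ToyDisjointing(channels=64)
    x = torch.randn(2, 64, 7, 7)
    human_vec, context_vec, mask = block(x)
    print(human_vec.shape, context_vec.shape, mask.shape)
```

Because the two parts use complementary weights, their pooled vectors always sum to the pooled input. Any separation between them has to come from the training objective, which this sketch does not model.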

Model overview

[Figure: model overview]

Example

Here is a code example for using DisJRNet installed from PyPI (pip install disjrnet_pytorch):

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from disjrnet.model.models import DisJRNet
from disjrnet.model.loss import compute_loss

alpha           =   2.0        # margin hyperparameter, passed to DisJRNet as `margin`
fusion_method   =   'gating'   # candidates = 'gating' | 'gconv'

# 2D CNN
model = DisJRNet(num_classes    =   10,
                base_model      =   'resnet50',
                dimension       =   2,
                dropout         =   0.8,
                margin          =   alpha,
                fusion_method   =   fusion_method)

# # 3D CNN
# model = DisJRNet(num_classes    =   10,
#                 base_model      =   'r2plus1d_18',
#                 dimension       =   3,
#                 dropout         =   0.8,
#                 margin          =   alpha,
#                 fusion_method   =   fusion_method)

# classification loss = CE
criterion = nn.CrossEntropyLoss()

# dummy data example (2D model): 10 RGB inputs of size 112x112
inps = torch.randn(10, 3, 112, 112)
# integer class indices, as nn.CrossEntropyLoss expects; shape (10, 1)
tgts = torch.arange(10).view(10, -1)

dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=8)
loader_iter = iter(loader)

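# take the first batch (8 of the 10 samples)
inputs, target = next(loader_iter)

logits = model(inputs)          # (batch, num_classes)

loss = compute_loss(model, criterion, logits, target)
pred = logits.argmax(1)         # predicted class index per sample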

print(f"loss : {loss:.4f}, pred : {pred}, target : {target.view(-1)}")
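The example above selects fusion_method = 'gating' (the alternative is 'gconv'), but the README does not describe what either option does. As a generic illustration only, a gated fusion of two feature vectors usually learns a per-channel gate in (0, 1) and mixes the inputs as g * a + (1 - g) * b. The class ToyGatedFusion below is hypothetical and is not the package's implementation of 'gating'.

```python
import torch
import torch.nn as nn


class ToyGatedFusion(nn.Module):
    """Hypothetical generic gated fusion; NOT disjrnet's 'gating' module.

    A linear layer looks at both inputs and produces a per-channel gate g in
    (0, 1); the output is the convex combination g * a + (1 - g) * b.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(2 * dim, dim)

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # a, b: (batch, dim) feature vectors from two streams
        g = torch.sigmoid(self.gate(torch.cat([a, b], dim=1)))
        return g * a + (1.0 - g) * b


if __name__ == "__main__":
    torch.manual_seed(0)
    fuse = ToyGatedFusion(dim=16)
    a, b = torch.randn(4, 16), torch.randn(4, 16)
    print(fuse(a, b).shape)
```

Since the output is a convex combination, each fused element lies between the corresponding elements of the two inputs, and the gate lets the model decide per channel which stream to trust.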

Training scripts

First, change into the disjrnet directory:

cd disjrnet

Here are example commands for training each model available in this project:

  • DisJRNet
# FDD
python main.py --dataset FDD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --fusion_method gating --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --c 5.0 --arch DisJRNet

# URFD
python main.py --dataset URFD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --fusion_method gating --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --c 2.0 --arch DisJRNet
  • Baseline
# FDD
python main.py --dataset FDD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --arch Baseline

# URFD
python main.py --dataset URFD --root <dataset_root> --output_path <checkpoint_dir> --num_classes 2 --drop_rate 0.8 --base_model r2plus1d_18 --n_fold 5 --batch_size 8 --epochs 25 --sample_length 10 --num_workers 8 --monitor val_f1 --lr 1e-4 --arch Baseline
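The four commands above differ only in --dataset, --arch, and the DisJRNet-only flags --fusion_method and --c (5.0 for FDD, 2.0 for URFD). The hypothetical helper build_command below rebuilds each command from those differences, which can make it easier to script all runs. It reproduces the flags exactly as listed and assumes nothing about what main.py does with them.

```python
from typing import List

# Value passed as --c to DisJRNet in the example commands, per dataset.
C_BY_DATASET = {"FDD": "5.0", "URFD": "2.0"}


def build_command(dataset: str, arch: str) -> List[str]:
    """Rebuild one of the example training commands as an argv list (hypothetical helper)."""
    if arch not in ("DisJRNet", "Baseline"):
        raise ValueError(f"unknown arch: {arch}")
    disjr = arch == "DisJRNet"
    cmd = ["python", "main.py", "--dataset", dataset,
           "--root", "<dataset_root>", "--output_path", "<checkpoint_dir>",
           "--num_classes", "2", "--drop_rate", "0.8",
           "--base_model", "r2plus1d_18"]
    if disjr:
        # The examples pass --fusion_method only for DisJRNet.
        cmd += ["--fusion_method", "gating"]
    cmd += ["--n_fold", "5", "--batch_size", "8", "--epochs", "25",
            "--sample_length", "10", "--num_workers", "8",
            "--monitor", "val_f1", "--lr", "1e-4"]
    if disjr:
        # The examples pass --c only for DisJRNet; its value depends on the dataset.
        cmd += ["--c", C_BY_DATASET[dataset]]
    return cmd + ["--arch", arch]


if __name__ == "__main__":
    for arch in ("DisJRNet", "Baseline"):
        for dataset in ("FDD", "URFD"):
            # <dataset_root> and <checkpoint_dir> are placeholders to replace before running.
            print(" ".join(build_command(dataset, arch)))
```

To launch a run from Python, the returned list can be passed to subprocess.run(cmd, check=True) once the placeholders are replaced with real paths.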

Results

[Table: results]

Activation Map Visualization

[Figure: activation map visualization]



Download files


Source Distribution

disjrnet_pytorch-0.1.2.tar.gz (10.7 kB)

Built Distribution

disjrnet_pytorch-0.1.2-py3-none-any.whl (12.8 kB)

File details

Details for the file disjrnet_pytorch-0.1.2.tar.gz.

File metadata

  • Download URL: disjrnet_pytorch-0.1.2.tar.gz
  • Size: 10.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.26.0 requests-toolbelt/0.9.1 urllib3/1.26.7 tqdm/4.62.3 importlib-metadata/4.8.1 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.4 CPython/3.6.9

File hashes

Hashes for disjrnet_pytorch-0.1.2.tar.gz
Algorithm Hash digest
SHA256 271b39e73f27a05810e528b5181419f0137e982fba581adc5136a1081f236d49
MD5 88bf70acea978ec957662a7b9f44cbb5
BLAKE2b-256 d937711fc8cdea826995438abf40d3b49e9c775fe1341175875f3ea8a8c339b6


File details

Details for the file disjrnet_pytorch-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: disjrnet_pytorch-0.1.2-py3-none-any.whl
  • Size: 12.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.8.2 readme-renderer/34.0 requests/2.26.0 requests-toolbelt/0.9.1 urllib3/1.26.7 tqdm/4.62.3 importlib-metadata/4.8.1 keyring/23.4.1 rfc3986/1.5.0 colorama/0.4.4 CPython/3.6.9

File hashes

Hashes for disjrnet_pytorch-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 3d9c385f0a4c5b9fe66523cbdb74ef42032acf64d8037d8b07321779aaa7524d
MD5 ad56491f5e04612c156e2a34a81b030c
BLAKE2b-256 65971701276171acd008ab3834f5a3b3f0e53feab6a663d2b74c618933db8a03
