
Animal pose estimation using Vision Transformers and HRNet (VHR)

Project description

Welcome to our module!

AJA-pose

AJA-pose helps you train, validate and test your animal pose estimation model. Check out how we did it in Google Colab. We evaluated our model with PCK@0.05; its mean accuracy over the 23 keypoints is 93.073% on Protocol 1.
We recommend at least an NVIDIA V100 GPU for faster inference, although T4 GPUs also work.
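PCK@0.05 (Percentage of Correct Keypoints) counts a predicted keypoint as correct when it lies within 0.05 times a per-image normalizing length of its ground-truth location. This README does not say which normalizing length AJA-pose uses (bounding-box size and a body-segment length are common choices), so the minimal sketch below takes it as an input. The `pck` function is illustrative only and is not AJA-pose's evaluation code.

```python
import math
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float]


def pck(
    preds: Sequence[Sequence[Point]],
    gts: Sequence[Sequence[Point]],
    visible: Sequence[Sequence[bool]],
    normalizers: Sequence[float],
    alpha: float = 0.05,
) -> Tuple[List[Optional[float]], float]:
    """Percentage of Correct Keypoints at threshold alpha.

    A visible keypoint is correct when its prediction lies within
    alpha * normalizer pixels of the ground truth. Returns the per-keypoint
    accuracy (None if a keypoint is never visible) and the accuracy pooled
    over all visible keypoints.
    """
    num_kpts = len(gts[0])
    correct = [0] * num_kpts
    total = [0] * num_kpts
    for pred, gt, vis, norm in zip(preds, gts, visible, normalizers):
        threshold = alpha * norm
        for k in range(num_kpts):
            if not vis[k]:
                continue  # invisible keypoints are not scored
            total[k] += 1
            if math.dist(pred[k], gt[k]) <= threshold:
                correct[k] += 1
    per_kpt = [c / t if t else None for c, t in zip(correct, total)]
    overall = sum(correct) / sum(total) if sum(total) else 0.0
    return per_kpt, overall


# Two images, two keypoints each, normalizer 100 px -> threshold 5 px
per_kpt, overall = pck(
    preds=[[(10, 10), (50, 50)], [(0, 0), (5, 5)]],
    gts=[[(12, 10), (60, 50)], [(0, 4), (5, 5)]],
    visible=[[True, True], [True, False]],
    normalizers=[100.0, 100.0],
)
print(per_kpt, round(overall, 3))  # [1.0, 0.0] 0.667
```

Pooling over all visible keypoints weights each keypoint type by how often it is visible; a per-keypoint average would instead weight every type equally.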


Open In Colab

Getting Started

git clone https://github.com/Antony-gitau/AJA-pose.git
cd AJA-pose
pip install -e .

Get our model and the dataset. We have made our model public; it can be downloaded as follows:

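import urllib.request
import zipfile

# Download the dataset archive
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/dataset.zip"
destination = "dataset.zip"
urllib.request.urlretrieve(url, destination)

# Unzip the archive into the current directory
# (equivalent to running `!unzip dataset.zip` in a Colab/Jupyter cell)
with zipfile.ZipFile(destination) as archive:
    archive.extractall()

# Download the model checkpoint. The remote file is all_animals_no_pretrain_106.pth;
# it is saved locally as all_animals_no_pretrain_60.pth, the name used in the examples below.
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/all_animals_no_pretrain_106.pth"
destination = "all_animals_no_pretrain_60.pth"
urllib.request.urlretrieve(url, destination)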

Test our model
Testing requires the path to the test images directory and the test.json annotation file in MPII format.

from aja_pose import Model

# path to the images directory and annotation in mpii json format
images_directory = '' # Path to the images directory
mpii_json = '' # Path to the test.json file
model_file = 'all_animals_no_pretrain_60.pth' # Path to the model file 

# Initialize the class
model = Model()
# Test the model on Protocol 1
model.test(images_directory, protocol='P1', model=model_file)
# Test the model on Protocol 2
model.test(images_directory, protocol='P2', model=model_file)
# Test the model on birds class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='bird')
# Test the model on reptiles class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='reptile')
# Test the model on mammals class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='mammal')
# Test the model on fish class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='fish')
# Test the model on amphibian class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='amphibian')
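The test example expects test.json in MPII format, but this README does not spell out the schema. The minimal sketch below assumes the HRNet-style MPII layout (a JSON list of records with `image`, `center`, `scale`, `joints` and `joints_vis` keys); check that assumption against your own file. The helpers `load_mpii_annotations` and `missing_images` are hypothetical and not part of aja_pose; they only sanity-check the annotations and images before a long test run.

```python
import json
import os
from typing import Any, Dict, List

NUM_KEYPOINTS = 23  # keypoint count reported in this README
# Assumed HRNet-style MPII record keys (not confirmed for AJA-pose)
REQUIRED_KEYS = {"image", "center", "scale", "joints", "joints_vis"}


def load_mpii_annotations(path: str, num_keypoints: int = NUM_KEYPOINTS) -> List[Dict[str, Any]]:
    """Load an MPII-style JSON list and check each record's keys and keypoint count."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise ValueError("expected a JSON list of annotation records")
    for i, rec in enumerate(records):
        missing = REQUIRED_KEYS - set(rec)
        if missing:
            raise ValueError(f"record {i} is missing keys: {sorted(missing)}")
        if len(rec["joints"]) != num_keypoints or len(rec["joints_vis"]) != num_keypoints:
            raise ValueError(f"record {i} does not have {num_keypoints} keypoints")
    return records


def missing_images(records: List[Dict[str, Any]], images_directory: str) -> List[str]:
    """Return the image names referenced by the annotations that are not on disk."""
    return [
        rec["image"]
        for rec in records
        if not os.path.isfile(os.path.join(images_directory, rec["image"]))
    ]
```

For example, `records = load_mpii_annotations(mpii_json)` followed by `missing_images(records, images_directory)` should return an empty list when every annotated image is present.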

You can also train your own model from scratch or fine-tune on top of ours.

# Train a VHR model (reuses `model` and `images_directory` from the test example above)
train_json = '' # Labels for the training set (train.json)
valid_json = '' # Labels for the validation set (test.json)
model_file = '' # A PyTorch model file (.pth) to start training from
model.train(images_directory, train_json, valid_json, pretrained=model_file)

# Train a model on a particular animal class (e.g. amphibian) under Protocol 3
model.train(images_directory, protocol='P3', animal_class='amphibian', model=model_file)
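`train_json` and `valid_json` are separate label files. If you start from a single MPII-format annotation file, a minimal sketch like the one below can produce both with a reproducible random split. The helpers `split_annotations` and `write_split` are hypothetical and not part of aja_pose. They treat every record independently; if your images are frames from the same video, splitting by video avoids near-duplicate frames landing in both sets.

```python
import json
import random
from typing import Any, Dict, List, Sequence, Tuple

Record = Dict[str, Any]


def split_annotations(
    records: Sequence[Record], valid_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[Record], List[Record]]:
    """Shuffle records with a fixed seed and split them into (train, valid)."""
    if not 0.0 < valid_fraction < 1.0:
        raise ValueError("valid_fraction must be between 0 and 1")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)  # same seed -> same split every run
    n_valid = max(1, round(len(shuffled) * valid_fraction))
    return shuffled[n_valid:], shuffled[:n_valid]


def write_split(
    src_json: str, train_json: str, valid_json: str,
    valid_fraction: float = 0.2, seed: int = 0,
) -> Tuple[int, int]:
    """Read one annotation list, split it, and write the train/valid JSON files."""
    with open(src_json, encoding="utf-8") as f:
        records = json.load(f)
    train, valid = split_annotations(records, valid_fraction, seed)
    for path, subset in ((train_json, train), (valid_json, valid)):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(subset, f)
    return len(train), len(valid)
```

For example, `write_split("all.json", "train.json", "valid.json")` writes both files and returns their sizes; the resulting paths can then be used as `train_json` and `valid_json`.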

Results

A sanity check on our model.

Figure: ground-truth keypoints and the corresponding model predictions on sample images.

Performance

PCK@0.05 accuracy (%) of our model on each animal class (Protocol 3):

Animal Class  Samples  Head    Shoulder  Elbow   Wrist   Hip     Knee    Ankle   Mouth   Tail    Mean
Birds         1705     95.756  93.637    89.774  88.179  98.975  97.582  94.326  98.447  95.112  95.164
Reptiles      1209     91.538  85.291    84.662  85.587  90.457  88.097  85.239  96.723  83.925  89.553
Mammals       1496     90.641  89.269    88.509  89.927  90.263  88.655  89.535  93.622  82.161  90.038
Fish          918      96.468  96.249    98.643  96.058  98.403  96.743  95.775  97.564  98.256  96.467
Amphibian     1279     98.128  94.342    97.948  98.508  95.491  94.957  94.319  98.702  99.568  95.493

PCK@0.05 accuracy (%) of our model on Protocol 1 and Protocol 2:

Protocol  Samples  Head    Shoulder  Elbow   Wrist   Hip     Knee    Ankle   Mouth   Tail    Mean
P1        6620     94.230  91.054    90.806  90.920  94.414  93.233  92.094  96.867  92.346  93.073
P2        2883     88.683  75.815    80.223  81.136  85.568  83.840  82.028  94.799  72.506  83.711


Download files


Source Distribution

aja_pose-0.0.1.tar.gz (24.0 MB)

Uploaded Source

Built Distribution

aja_pose-0.0.1-py3-none-any.whl (25.5 MB)

Uploaded Python 3

File details

Details for the file aja_pose-0.0.1.tar.gz.

File metadata

  • Download URL: aja_pose-0.0.1.tar.gz
  • Upload date:
  • Size: 24.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for aja_pose-0.0.1.tar.gz
Algorithm Hash digest
SHA256 58778af9d50e380c56e04afdf4a1f45c1b3166331be0244e69460538f9815f79
MD5 f47586f4dcee9880722c37eebe80aadf
BLAKE2b-256 5750a0a56bf9ed73903072736729051d8fd0b56922717414692fb06e52c5c22e
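To check that a downloaded distribution matches the SHA256 digest published here, you can hash it locally. This is a minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `matches` are illustrative. pip can enforce the same check with its hash-checking mode (`--require-hashes`).

```python
import hashlib

# SHA256 digest listed for aja_pose-0.0.1.tar.gz
EXPECTED_SDIST_SHA256 = "58778af9d50e380c56e04afdf4a1f45c1b3166331be0244e69460538f9815f79"


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected: str) -> bool:
    """Compare a file's SHA256 with a published hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `matches("aja_pose-0.0.1.tar.gz", EXPECTED_SDIST_SHA256)` should return `True` for an intact download.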


File details

Details for the file aja_pose-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: aja_pose-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 25.5 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.10.12

File hashes

Hashes for aja_pose-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 970879e03e08cd49be6349f7641356d8a1fe077a770dfb3cd8e2cfb50fcf8d61
MD5 954d289518c296353fbad6ddae94844f
BLAKE2b-256 201e8de7cb6da6561a4f12d6342e4c703ca47304383ed1e058610a0983239874

