Animal pose estimation using Vision Transformers and HRNet (VHR)
Welcome to our module!
AJA-pose
AJA-pose helps you train, validate and test your animal pose estimation model.
Check out how we have done it in Google Colab.
We evaluated our model with PCK@0.05: its mean accuracy over the 23 keypoints is 93.073% on Protocol 1.
We recommend at least an NVIDIA V100 GPU for faster inference, but T4 GPUs also work.
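PCK@0.05 counts a predicted keypoint as correct when it lies within 0.05 times a normalization length of its ground-truth location. The NumPy sketch below computes per-keypoint PCK and a mean weighted by how often each keypoint is visible, which is how the MPII-style evaluation in HRNet's reference code aggregates its Mean. Which normalization length AJA-pose uses, and whether it weights the mean the same way, are assumptions here, so `norm` is left as an input; `pck` is an illustrative helper, not part of the aja_pose API.

```python
import numpy as np


def pck(pred, gt, visible, norm, alpha=0.05):
    """Per-keypoint PCK@alpha (in %) and a visibility-weighted mean.

    pred, gt : (N, K, 2) arrays of predicted / ground-truth (x, y) coordinates
    visible  : (N, K) boolean array; invisible keypoints are ignored
    norm     : (N,) per-sample normalization length (ASSUMPTION: e.g. a
               bounding-box size; the exact choice is not stated in this README)
    alpha    : threshold as a fraction of the normalization length
    """
    pred = np.asarray(pred, dtype=float)
    gt = np.asarray(gt, dtype=float)
    visible = np.asarray(visible, dtype=bool)
    norm = np.asarray(norm, dtype=float)

    # Euclidean error of every keypoint, scaled by its sample's normalization length
    dist = np.linalg.norm(pred - gt, axis=-1) / norm[:, None]
    correct = (dist <= alpha) & visible

    # Per-keypoint accuracy, counted over visible instances only
    visible_count = visible.sum(axis=0)
    per_keypoint = 100.0 * correct.sum(axis=0) / np.maximum(visible_count, 1)

    # Mean weighted by how often each keypoint is visible (MPII-style aggregation)
    weights = visible_count / visible_count.sum()
    mean = float(np.sum(per_keypoint * weights))
    return per_keypoint, mean
```

Because of the weighting, keypoints that are rarely visible contribute less to the mean than frequently visible ones.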
Getting Started
git clone https://github.com/Antony-gitau/AJA-pose.git
cd AJA-pose
pip install -e .
Get our model and the dataset. Our model is public and, together with the dataset, can be downloaded as follows.
import urllib.request
import zipfile
# Get the dataset
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/dataset.zip"
destination = "dataset.zip"
urllib.request.urlretrieve(url, destination)
# Unzip the file (same as `!unzip dataset.zip` in Colab, but also works outside notebooks)
with zipfile.ZipFile(destination) as zf:
    zf.extractall(".")
# Get the model file; it is saved under the file name used in the test example below
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/all_animals_no_pretrain_106.pth"
destination = "all_animals_no_pretrain_60.pth"
urllib.request.urlretrieve(url, destination)
Test our model
Testing requires the path to the test images directory and the test.json annotation file in MPII format.
from aja_pose import Model
# Paths to the images directory and the annotation file in MPII JSON format
images_directory = ''  # Path to the images directory
mpii_json = ''  # Path to the test.json file
model_file = 'all_animals_no_pretrain_60.pth'  # Path to the model file
# Initialize the class
model = Model()
# Test the model on Protocol 1
model.test(images_directory, protocol='P1', model=model_file)
# Test the model on Protocol 2
model.test(images_directory, protocol='P2', model=model_file)
# Test the model on each animal class under Protocol 3
for animal_class in ['bird', 'reptile', 'mammal', 'fish', 'amphibian']:
    model.test(images_directory, protocol='P3', model=model_file, animal_class=animal_class)
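Since testing expects annotations in MPII format, it can help to sanity-check test.json before running an evaluation. This is a minimal sketch that assumes the MPII-style JSON layout used by HRNet-based code (a list of entries with `image`, `center`, `scale`, `joints` and `joints_vis`); those field names and the helper `check_mpii_annotations` are assumptions, not part of the aja_pose API.

```python
import json

# ASSUMPTION: field names follow the MPII-style JSON used by HRNet-based
# pipelines; adjust them if your test.json differs.
REQUIRED_FIELDS = ("image", "center", "scale", "joints", "joints_vis")
NUM_KEYPOINTS = 23  # number of keypoints reported for this model


def check_mpii_annotations(json_path, num_keypoints=NUM_KEYPOINTS):
    """Return a list of human-readable problems found in an MPII-style file."""
    with open(json_path, "r", encoding="utf-8") as f:
        entries = json.load(f)
    if not isinstance(entries, list):
        return ["top-level JSON value should be a list of annotations"]
    problems = []
    for i, entry in enumerate(entries):
        missing = [k for k in REQUIRED_FIELDS if k not in entry]
        if missing:
            problems.append(f"entry {i}: missing fields {missing}")
            continue
        n_joints = len(entry["joints"])
        if n_joints != num_keypoints:
            problems.append(f"entry {i}: expected {num_keypoints} joints, got {n_joints}")
        if len(entry["joints_vis"]) != n_joints:
            problems.append(f"entry {i}: joints and joints_vis lengths differ")
    return problems
```

For example, `check_mpii_annotations(mpii_json)` returns an empty list when every entry has the expected fields and 23 keypoints.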
You can also train your own model, or fine-tune starting from our weights.
# Train a VHR model
train_json = ''  # Labels for the training set (train.json)
valid_json = ''  # Labels for the validation set (test.json)
model_file = ''  # A PyTorch model file to initialize training from (e.g. our model)
model.train(images_directory, train_json, valid_json, pretrained=model_file)
# Train a model on a particular class, e.g. amphibian (Protocol 3)
model.train(images_directory, protocol='P3', animal_class='amphibian', model=model_file)
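If you prefer to build a class-specific annotation file yourself, for instance to pass it as `train_json` or `valid_json` to the generic training call, a minimal sketch could filter the JSON entries by class. It assumes each entry has an `animal_parent_class` field (as in the Animal Kingdom pose annotations) with values such as "Bird"; that field name and the helper `split_by_class` are assumptions, not part of aja_pose.

```python
import json

# ASSUMPTION: each annotation entry carries its animal class in this field.
CLASS_FIELD = "animal_parent_class"


def split_by_class(json_path, animal_class, out_path, field=CLASS_FIELD):
    """Write only the entries whose class matches animal_class (case-insensitive).

    Returns the number of entries written.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        entries = json.load(f)
    wanted = animal_class.lower()
    # Entries without the class field are skipped
    subset = [e for e in entries if str(e.get(field, "")).lower() == wanted]
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(subset, f)
    return len(subset)
```

The comparison is case-insensitive so that lowercase names such as 'amphibian', as used in the calls above, still match capitalized labels.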
Results
A sanity check on our model (figure: Ground Truth vs. Predictions).
Performance
PCK@0.05 (%) of our model on each animal class (Protocol 3):
Animal Class | Samples | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mouth | Tail | Mean
---|---|---|---|---|---|---|---|---|---|---|---
Birds | 1705 | 95.756 | 93.637 | 89.774 | 88.179 | 98.975 | 97.582 | 94.326 | 98.447 | 95.112 | 95.164
Reptiles | 1209 | 91.538 | 85.291 | 84.662 | 85.587 | 90.457 | 88.097 | 85.239 | 96.723 | 83.925 | 89.553
Mammals | 1496 | 90.641 | 89.269 | 88.509 | 89.927 | 90.263 | 88.655 | 89.535 | 93.622 | 82.161 | 90.038
Fish | 918 | 96.468 | 96.249 | 98.643 | 96.058 | 98.403 | 96.743 | 95.775 | 97.564 | 98.256 | 96.467
Amphibian | 1279 | 98.128 | 94.342 | 97.948 | 98.508 | 95.491 | 94.957 | 94.319 | 98.702 | 99.568 | 95.493
PCK@0.05 (%) of our model on Protocol 1 and Protocol 2:
Protocol | Samples | Head | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Mouth | Tail | Mean
---|---|---|---|---|---|---|---|---|---|---|---
P1 | 6620 | 94.230 | 91.054 | 90.806 | 90.920 | 94.414 | 93.233 | 92.094 | 96.867 | 92.346 | 93.073
P2 | 2883 | 88.683 | 75.815 | 80.223 | 81.136 | 85.568 | 83.840 | 82.028 | 94.799 | 72.506 | 83.711
Download files
Source distribution: aja_pose-0.0.1.tar.gz
Built distribution: aja_pose-0.0.1-py3-none-any.whl
File details
Details for the file aja_pose-0.0.1.tar.gz
File metadata
- Download URL: aja_pose-0.0.1.tar.gz
- Upload date:
- Size: 24.0 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 58778af9d50e380c56e04afdf4a1f45c1b3166331be0244e69460538f9815f79
MD5 | f47586f4dcee9880722c37eebe80aadf
BLAKE2b-256 | 5750a0a56bf9ed73903072736729051d8fd0b56922717414692fb06e52c5c22e
File details
Details for the file aja_pose-0.0.1-py3-none-any.whl
File metadata
- Download URL: aja_pose-0.0.1-py3-none-any.whl
- Upload date:
- Size: 25.5 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | 970879e03e08cd49be6349f7641356d8a1fe077a770dfb3cd8e2cfb50fcf8d61
MD5 | 954d289518c296353fbad6ddae94844f
BLAKE2b-256 | 201e8de7cb6da6561a4f12d6342e4c703ca47304383ed1e058610a0983239874
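To confirm that a downloaded distribution matches the SHA256 digests listed in the file details, you can hash it with Python's standard `hashlib`. This is a minimal sketch; `sha256_of` and `verify` are illustrative helper names, not part of aja_pose.

```python
import hashlib

# SHA256 digests as listed in the file details for aja_pose 0.0.1
EXPECTED_SHA256 = {
    "aja_pose-0.0.1.tar.gz": "58778af9d50e380c56e04afdf4a1f45c1b3166331be0244e69460538f9815f79",
    "aja_pose-0.0.1-py3-none-any.whl": "970879e03e08cd49be6349f7641356d8a1fe077a770dfb3cd8e2cfb50fcf8d61",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hash the file in chunks so large archives are not read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """Return True if the file's SHA256 digest matches the expected hex string."""
    return sha256_of(path) == expected.lower()
```

For example, after fetching the wheel with `pip download aja-pose==0.0.1 --no-deps`, `verify("aja_pose-0.0.1-py3-none-any.whl", EXPECTED_SHA256["aja_pose-0.0.1-py3-none-any.whl"])` should return True.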