Saliency Detection library (models, loss, utils) with PyTorch
saldet
Saliency Detection (saldet) is a collection of models and tools for performing saliency detection with PyTorch on CUDA, MPS, and other devices.
Models
List of saliency detection models supported by saldet:
- U2Net - https://arxiv.org/abs/2005.09007v3 (U2Net repo)
- PGNet - https://arxiv.org/abs/2204.05041 (follow training instructions from PGNet repo)
- PFAN - https://arxiv.org/pdf/1903.00179v2.pdf (PFAN repo)
Weights
- PGNet -> PGNet repo weights converted to the saldet format, available from here
- U2Net Lite -> weights from here (U2Net repository)
- U2Net Full -> weights from here (U2Net repository)
- U2Net Full - Portrait -> weights for portrait images from here (U2Net repository)
- U2Net Full - Human Segmentation -> weights for segmenting humans from here (U2Net repository)
- PFAN -> PFAN repo weights converted to the saldet format, available from here
To load pre-trained weights:
from saldet import create_model
model = create_model("pgnet", checkpoint_path="PATH/TO/pgnet.pth")
Train
Easy Mode
The library makes training easy through its PyTorch Lightning integration.
from saldet.experiment import train
train(
data_dir=...,
config_path="config/u2net_lite.yaml", # check the config folder with some configurations
output_dir=...,
resume_from=...,
seed=42
)
Once training finishes, the configuration file and checkpoints are saved to output_dir.
[WARNING] The dataset must be structured as follows:
dataset
├── train
│   ├── images
│   │   ├── img_1.jpg
│   │   └── img_2.jpg
│   └── masks
│       ├── img_1.png
│       └── img_2.png
└── val
    ├── images
    │   ├── img_10.jpg
    │   └── img_11.jpg
    └── masks
        ├── img_10.png
        └── img_11.png
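Because training expects this exact layout, a quick pre-flight check can catch missing masks early. The stdlib sketch below is not part of saldet: `check_split` and `check_dataset` are hypothetical helpers, and they assume each mask shares its image's file stem and uses the `.png` extension, as in the tree above.

```python
from pathlib import Path


def check_split(split_dir):
    """Return (images without a mask, masks without an image) for one split.

    Hypothetical helper, not part of saldet. Assumes the layout
    <split>/images/<stem>.<ext> and <split>/masks/<stem>.png.
    """
    split_dir = Path(split_dir)
    image_stems = {p.stem for p in (split_dir / "images").iterdir() if p.is_file()}
    mask_stems = {
        p.stem
        for p in (split_dir / "masks").iterdir()
        if p.is_file() and p.suffix.lower() == ".png"
    }
    # Sorted lists keep the report deterministic.
    missing_masks = sorted(image_stems - mask_stems)
    orphan_masks = sorted(mask_stems - image_stems)
    return missing_masks, orphan_masks


def check_dataset(root_dir, splits=("train", "val")):
    """Raise ValueError if any split has unmatched images or masks."""
    problems = []
    for split in splits:
        missing, orphans = check_split(Path(root_dir) / split)
        if missing:
            problems.append(f"{split}: images without masks: {missing}")
        if orphans:
            problems.append(f"{split}: masks without images: {orphans}")
    if problems:
        raise ValueError("\n".join(problems))
```

For example, `check_dataset("PATH/TO/dataset")` raises a ValueError listing any unmatched files and returns silently when every image has a mask.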
PyTorch Lightning Mode
The library provides PyTorch Lightning modules for both the model (SaliencyPLModel) and the data (SaliencyPLDataModule).
import pytorch_lightning as pl
import yaml
from saldet import create_model
from saldet.pl import SaliencyPLDataModule, SaliencyPLModel
from saldet.transform import SaliencyTransform
# assumption: config is one of the YAML files from the config folder,
# with "transform" and "datamodule" sections
data_dir = "PATH/TO/dataset"
with open("config/u2net_lite.yaml") as f:
    config = yaml.safe_load(f)
# datamodule
datamodule = SaliencyPLDataModule(
    root_dir=data_dir,
    train_transform=SaliencyTransform(train=True, **config["transform"]),
    val_transform=SaliencyTransform(train=False, **config["transform"]),
    **config["datamodule"],
)
# model, loss, optimizer and learning rate scheduler
model = create_model(...)
criterion = ...
optimizer = ...
lr_scheduler = ...
pl_model = SaliencyPLModel(
    model=model, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler
)
trainer = pl.Trainer(...)
# fit
print("Launching training...")
trainer.fit(model=pl_model, datamodule=datamodule)
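The snippet above leaves `criterion` open. As an illustration only, and not necessarily a loss that saldet provides, saliency models are often trained with binary cross-entropy combined with an IoU-style term. The pure-Python sketch below shows that math on flattened per-pixel lists; `bce_iou_loss` is a hypothetical name, and a real criterion would operate on torch tensors (for example as a `torch.nn.Module`).

```python
import math


def _sigmoid(z):
    # Numerically stable logistic function (avoids overflow for large |z|).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_iou_loss(logits, targets, eps=1e-7):
    """Mean BCE-with-logits plus soft IoU loss over flattened pixels.

    Illustrative sketch, not saldet's implementation. `logits` are raw model
    outputs and `targets` are mask values in [0, 1]; both are flat lists of
    equal length.
    """
    if not logits or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and of equal length")
    bce = inter = union = 0.0
    for z, y in zip(logits, targets):
        # Stable BCE with logits: max(z, 0) - z * y + log(1 + exp(-|z|)).
        bce += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        p = _sigmoid(z)
        inter += p * y
        union += p + y - p * y  # soft union: |A| + |B| - intersection
    bce /= len(logits)
    iou_loss = 1.0 - (inter + eps) / (union + eps)
    return bce + iou_loss
```

Weighting the two terms equally is an assumption; in practice they are often weighted differently or computed per image and averaged over the batch.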
PyTorch Mode
Alternatively, you can write your own training loop and use the create_model() utility to instantiate the model of your choice.
Inference
The library also makes it easy to generate saliency maps for a folder of images.
from saldet.experiment import inference
inference(
images_dir=...,
ckpt=..., # path to ckpt/pth model file
config_path=..., # path to configuration file from saldet train
output_dir=..., # where to save saliency maps
sigmoid=..., # whether to apply sigmoid to predicted masks
)
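The `sigmoid` flag matters because a model may output raw logits rather than probabilities. The stdlib sketch below shows what that post-processing conceptually looks like when writing an 8-bit grayscale saliency map; `to_saliency_map` is a hypothetical helper, not saldet's implementation, and the clamp-and-scale to 0-255 is an assumption.

```python
import math


def _sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def to_saliency_map(pred, sigmoid=True):
    """Convert a 2-D prediction (list of rows) to 8-bit grayscale values.

    Hypothetical helper. With sigmoid=True, `pred` holds raw logits;
    with sigmoid=False, it is assumed to already hold probabilities.
    """
    out = []
    for row in pred:
        new_row = []
        for v in row:
            p = _sigmoid(v) if sigmoid else v
            p = min(max(p, 0.0), 1.0)  # clamp out-of-range values into [0, 1]
            new_row.append(round(p * 255))
        out.append(new_row)
    return out
```

If the loaded model already applies a sigmoid in its forward pass, applying it again would squash the output into roughly 0.5 to 0.73, so `sigmoid=False` would be the right choice in that case.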
To-Dos
[ ] Improve code coverage
[ ] ReadTheDocs documentation
Download files
Source Distribution
Built Distribution
File details
Details for the file saldet-0.6.1.tar.gz.
File metadata
- Download URL: saldet-0.6.1.tar.gz
- Upload date:
- Size: 28.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.2 CPython/3.9.6 Darwin/22.5.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 62afa0246dd90cf0cb8d43051e199abe2a13b3f0656312ec9c7b8601c1133d3b
MD5 | b0cd3037aeabd505ee1449b9bc310e26
BLAKE2b-256 | 382471838b72210002901cd5c0702a27cf923c48048a3a8d6a7a4d9ed71d5f9e
File details
Details for the file saldet-0.6.1-py3-none-any.whl.
File metadata
- Download URL: saldet-0.6.1-py3-none-any.whl
- Upload date:
- Size: 37.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.4.2 CPython/3.9.6 Darwin/22.5.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 87939ac87fe1042665562cb6d50060798d1ce4b6a7d372133498fb902a778206
MD5 | c0d540c2638fd32512c835d4b280db13
BLAKE2b-256 | bb09f07fbc6c5ff8f0ef6e22639155219c007e900d090d865ab35a86d8e47927