
Saliency Detection library (models, loss, utils) with PyTorch


saldet

Saliency Detection (saldet) is a collection of models and tools for performing saliency detection with PyTorch (CUDA, MPS, etc.).


Models

List of saliency detection models supported by saldet:

Weights

  • PGNet -> weights from the PGNet repo, converted to the saldet version, available here
  • U2Net Lite -> weights from here (U2Net repository)
  • U2Net Full -> weights from here (U2Net repository)
  • U2Net Full - Portrait -> weights for portrait images from here (U2Net repository)
  • U2Net Full - Human Segmentation -> weights for segmenting humans from here (U2Net repository)
  • PFAN -> weights from the PFAN repo, converted to the saldet version, available here

To load pre-trained weights:

from saldet import create_model
model = create_model("pgnet", checkpoint_path="PATH/TO/pgnet.pth")

Train

Easy Mode

The library makes it easy to train models, thanks to its PyTorch Lightning integration.

from saldet.experiment import train

train(
    data_dir=...,
    config_path="config/u2net_lite.yaml", # see the config folder for example configurations
    output_dir=...,
    resume_from=...,
    seed=42
)

Once training is over, the configuration file and the checkpoints are saved in the output directory.
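To feed a finished run into inference, you need the path of a checkpoint and of the saved configuration. The sketch below is a hypothetical helper (find_latest_artifacts is not part of saldet) that picks the most recently modified checkpoint and YAML file anywhere under the output directory. It assumes checkpoints end in .ckpt or .pth and the configuration in .yaml or .yml; check the actual file names your run produces.

```python
from pathlib import Path
from typing import Iterable, Optional, Tuple


def find_latest_artifacts(output_dir: str) -> Tuple[Optional[Path], Optional[Path]]:
    """Return (newest checkpoint, newest YAML config) found under output_dir.

    Hypothetical helper: assumes checkpoints end in .ckpt or .pth and the
    saved configuration ends in .yaml or .yml. Either item is None if absent.
    """
    root = Path(output_dir)

    def newest(patterns: Iterable[str]) -> Optional[Path]:
        # collect matching files recursively, keep the most recently modified one
        files = [p for pattern in patterns for p in root.rglob(pattern) if p.is_file()]
        return max(files, key=lambda p: p.stat().st_mtime, default=None)

    return newest(["*.ckpt", "*.pth"]), newest(["*.yaml", "*.yml"])
```

For example, ckpt, config = find_latest_artifacts("runs/u2net_lite") (a hypothetical output path) gives values that can be passed as ckpt=str(ckpt) and config_path=str(config) to the inference function described below.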

[WARNING] The dataset must be structured as follows:

dataset
├── train
│   ├── images
│   │   ├── img_1.jpg
│   │   └── img_2.jpg
│   └── masks
│       ├── img_1.png
│       └── img_2.png
└── val
    ├── images
    │   ├── img_10.jpg
    │   └── img_11.jpg
    └── masks
        ├── img_10.png
        └── img_11.png
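In the layout above, every image is paired with a PNG mask that has the same file name stem (img_1.jpg with img_1.png). A quick pre-flight check, sketched below, can catch missing masks before launching a long training run. check_dataset is a hypothetical helper, not part of saldet, and the set of accepted image extensions is an assumption.

```python
from pathlib import Path
from typing import Dict, List

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumption: accepted image formats


def check_dataset(root: str) -> Dict[str, List[str]]:
    """Report images that have no matching mask, per split.

    Assumes <root>/<split>/images/<name>.<ext> is paired with
    <root>/<split>/masks/<name>.png for the splits "train" and "val".
    Returns {split: [image file names missing a mask]}.
    """
    report: Dict[str, List[str]] = {}
    for split in ("train", "val"):
        images_dir = Path(root) / split / "images"
        masks_dir = Path(root) / split / "masks"
        if not images_dir.is_dir() or not masks_dir.is_dir():
            raise FileNotFoundError(f"missing {images_dir} or {masks_dir}")
        # masks are matched to images by file name stem
        mask_stems = {p.stem for p in masks_dir.glob("*.png")}
        report[split] = sorted(
            p.name
            for p in images_dir.iterdir()
            if p.suffix.lower() in IMAGE_EXTS and p.stem not in mask_stems
        )
    return report
```

An empty list for every split means each image has a mask; a FileNotFoundError means a required folder is missing.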

PyTorch Lightning Mode

The library provides PyTorch Lightning wrappers for both the model (SaliencyPLModel) and the data (SaliencyPLDataModule).

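import pytorch_lightning as pl
import yaml
from saldet import create_model
from saldet.pl import SaliencyPLDataModule, SaliencyPLModel
from saldet.transform import SaliencyTransform

# load a training configuration (e.g. one of the files in the config folder)
with open("config/u2net_lite.yaml") as f:
    config = yaml.safe_load(f)

data_dir = ...  # dataset root, structured as shown above

# datamodule
datamodule = SaliencyPLDataModule(
    root_dir=data_dir,
    train_transform=SaliencyTransform(train=True, **config["transform"]),
    val_transform=SaliencyTransform(train=False, **config["transform"]),
    **config["datamodule"],
)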

model = create_model(...)
criterion = ...
optimizer = ...
lr_scheduler = ...

pl_model = SaliencyPLModel(
    model=model, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler
)

trainer = pl.Trainer(...)

# fit
print("Launching training...")
trainer.fit(model=pl_model, datamodule=datamodule)

PyTorch Mode

Alternatively, you can write your own training loop and use the create_model() util to instantiate the model of your choice.

Inference

The library also makes it easy to predict saliency maps for a folder of images.

from saldet.experiment import inference

inference(
    images_dir=...,
    ckpt=..., # path to ckpt/pth model file
    config_path=..., # path to configuration file from saldet train
    output_dir=..., # where to save saliency maps
    sigmoid=..., # whether to apply sigmoid to predicted masks
)
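The sigmoid flag matters when the model returns raw scores (logits) rather than probabilities: the logistic sigmoid squashes each value into [0, 1], which can then be scaled to an 8-bit grayscale saliency map. The plain-Python sketch below only illustrates that post-processing on a small 2-D list; to_saliency_map is hypothetical and is not how saldet implements it internally.

```python
import math
from typing import List


def to_saliency_map(pred: List[List[float]], apply_sigmoid: bool = True) -> List[List[int]]:
    """Convert a 2-D prediction into 8-bit grayscale values (0-255).

    If apply_sigmoid is True, values are treated as raw logits and mapped to
    [0, 1] with the logistic sigmoid; otherwise they are assumed to already be
    probabilities. Values are clamped to [0, 1] before scaling.
    """
    def sigmoid(x: float) -> float:
        # numerically stable logistic function (avoids overflow in exp)
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    out = []
    for row in pred:
        probs = [sigmoid(v) if apply_sigmoid else v for v in row]
        out.append([round(255 * min(max(p, 0.0), 1.0)) for p in probs])
    return out
```

Applying the sigmoid to outputs that are already probabilities would compress them into roughly [0.5, 0.73], producing washed-out maps, so the flag should match what the model actually outputs.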

To-Dos

[ ] Improve code coverage

[ ] ReadTheDocs documentation

Download files

Source Distribution

saldet-0.6.1.tar.gz (28.5 kB)

Built Distribution

saldet-0.6.1-py3-none-any.whl (37.3 kB)
