
Neural network library for image classification tasks.

Project description

An open neural network library for image classification.


Table of contents

  1. Quick start
  2. Warnings
  3. Encoders
  4. Decoders
  5. Pretrained
  6. Datasets
  7. Losses
  8. Metrics
  9. Optimizers
  10. Schedulers
  11. Examples

Quick start

1. Direct install.

1.1 Install torch with CUDA support.
pip install -U torch --extra-index-url https://download.pytorch.org/whl/cu113
1.2 Install opennn_pytorch.
pip install -U opennn_pytorch

2. Dockerfile.

cd docker/
docker build -t opennn:latest .
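
After building the image, a container can be started roughly as follows (a sketch: the --gpus flag assumes the NVIDIA Container Toolkit is installed; omit it for CPU-only runs):

docker run --gpus all -it opennn:latest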

Warnings

  1. CUDA is supported only on NVIDIA graphics cards.
  2. The AlexNet decoder does not support the BCE loss family.
  3. Some dataset/encoder/decoder/loss combinations can give poor results; try other combinations.
  4. The custom cross-entropy loss supports only the case where predictions have shape (n, c) and labels have shape (n).
  5. Not all options in transform.yaml and config.yaml are required.
  6. The mean and std listed in the Datasets section must be reused in transform.yaml, for example mean=[0.2859], std=[0.3530] -> normalize: [[0.2859], [0.3530]] (see the sketch below).
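
Based on the mapping in warning 6, a minimal transform.yaml entry for FASHION-MNIST might look like the following sketch; only the normalize key is documented above, any other transform options stay as your config requires:

# transform.yaml (sketch) - values are the FASHION-MNIST mean/std from the Datasets section
normalize: [[0.2859], [0.3530]]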

Encoders

Decoders

Pretrained

LeNet

| Encoder | Decoder | Dataset       | Weights    | Configs           | Logs     |
|---------|---------|---------------|------------|-------------------|----------|
| LeNet   | LeNet   | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| LeNet   | LeNet   | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| LeNet   | Linear  | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| LeNet   | Linear  | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| LeNet   | AlexNet | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| LeNet   | AlexNet | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

AlexNet

| Encoder | Decoder | Dataset       | Weights    | Configs           | Logs     |
|---------|---------|---------------|------------|-------------------|----------|
| AlexNet | LeNet   | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| AlexNet | LeNet   | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| AlexNet | Linear  | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| AlexNet | Linear  | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| AlexNet | AlexNet | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| AlexNet | AlexNet | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

GoogleNet

| Encoder   | Decoder | Dataset       | Weights    | Configs           | Logs     |
|-----------|---------|---------------|------------|-------------------|----------|
| GoogleNet | Linear  | MNIST         | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| GoogleNet | Linear  | FASHION-MNIST | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

ResNet

| Encoder   | Decoder | Dataset | Weights    | Configs           | Logs     |
|-----------|---------|---------|------------|-------------------|----------|
| ResNet18  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| ResNet34  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| ResNet50  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| ResNet101 | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| ResNet152 | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

MobileNet

| Encoder   | Decoder | Dataset | Weights    | Configs           | Logs     |
|-----------|---------|---------|------------|-------------------|----------|
| MobileNet | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

VGG

| Encoder | Decoder | Dataset | Weights    | Configs           | Logs     |
|---------|---------|---------|------------|-------------------|----------|
| VGG-11  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| VGG-16  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |
| VGG-19  | Linear  | MNIST   | BEST, PLAN | CONFIG, TRANSFORM | TRAINVAL |

Datasets

  • MNIST [files] [code] [classes=10] [mean=[0.1307], std=[0.3801]]
  • FASHION-MNIST [files] [code] [classes=10] [mean=[0.2859], std=[0.3530]]
  • CIFAR-10 [files] [code] [classes=10] [mean=[0.491, 0.482, 0.446], std=[0.247, 0.243, 0.261]]
  • CIFAR-100 [files] [code] [classes=100] [mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]]
  • GTSRB [files] [code] [classes=43] [mean=unknown, std=unknown]
  • CUSTOM [docs] [code] [example] [classes=nc] [mean=unknown, std=unknown]

Losses

Metrics

Optimizers

Schedulers

Examples

  1. Run from yaml config.
import opennn_pytorch
  
config = 'path to yaml config'  # check configs folder
opennn_pytorch.run(config)
  2. Get encoder and decoder.
import opennn_pytorch
  
encoder_name = 'resnet18'
decoder_name = 'alexnet'
decoder_mode = 'decoder'
input_channels = 1
number_classes = 10
device = 'cuda'

encoder = opennn_pytorch.encoders.get_encoder(encoder_name, input_channels).to(device)
model = opennn_pytorch.decoders.get_decoder(decoder_name, encoder, number_classes, decoder_mode, device).to(device)
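
As a quick sanity check, the composed model can be run on a dummy batch. This is a sketch, not part of the documented API: the 28x28 input size is assumed for MNIST-like data.

import torch

# dummy grayscale batch of shape (batch, channels, height, width); 28x28 is an assumption
x = torch.randn(4, input_channels, 28, 28).to(device)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([4, number_classes]) if the decoder outputs class scores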

3.1 Get dataset.

import opennn_pytorch
from torchvision import transforms

transform_config = 'path to transform yaml config'
dataset_name = 'mnist'
datafiles = None
train_part = 0.7
valid_part = 0.2

transform_lst = opennn_pytorch.transforms_lst(transform_config)
transform = transforms.Compose(transform_lst)
  
train_data, valid_data, test_data = opennn_pytorch.datasets.get_dataset(dataset_name, train_part, valid_part, transform, datafiles)
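
The remaining fraction (here 1 - 0.7 - 0.2 = 0.1) is presumably used for the test split; a quick check of the split sizes, assuming the returned objects are standard PyTorch datasets supporting len():

print(len(train_data), len(valid_data), len(test_data))  # roughly 70% / 20% / 10% of the full dataset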

3.2 Get custom dataset.

import opennn_pytorch
from torchvision import transforms

transform_config = 'path to transform yaml config'
dataset_name = 'custom'
images = 'path to folder with images'
annotation = 'path to annotation yaml file with image: class structure'
datafiles = (images, annotation)
train_part = 0.7
valid_part = 0.2

transform_lst = opennn_pytorch.transforms_lst(transform_config)
transform = transforms.Compose(transform_lst)
  
train_data, valid_data, test_data = opennn_pytorch.datasets.get_dataset(dataset_name, train_part, valid_part, transform, datafiles)
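
The exact annotation format is not shown in this README; based on the "image: class" description above, a hypothetical annotation.yaml might look like the following sketch (file names and labels are made up for illustration):

# annotation.yaml (hypothetical example): one "image: class" pair per entry
img_0001.png: 3
img_0002.png: 7
img_0003.png: 0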
  4. Get optimizer.
import opennn_pytorch

optim_name = 'adam'
lr = 1e-3
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 1e-6
optimizer = opennn_pytorch.optimizers.get_optimizer(optim_name, model, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
  5. Get scheduler.
import opennn_pytorch

scheduler_name = 'steplr'
step = 10
gamma = 0.5
scheduler = opennn_pytorch.schedulers.get_scheduler(scheduler_name, optimizer, step=step, gamma=gamma, milestones=None)
  6. Get loss function.
import opennn_pytorch

loss_name = 'custom_mse'
loss_fn, one_hot = opennn_pytorch.losses.get_loss(loss_name)
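
The second return value, one_hot, indicates whether labels should be one-hot encoded for this loss (it is also passed to the train/test routines below). A sketch of what that flag implies, assuming integer class labels; this is an assumption, not documented library behaviour:

import torch
import torch.nn.functional as F

# assumption: when one_hot is True, integer labels of shape (n) are expanded to shape (n, c)
labels = torch.tensor([1, 4, 9])
if one_hot:
    labels = F.one_hot(labels, num_classes=10).float()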
  7. Get metrics functions.
import opennn_pytorch

metrics_names = ['accuracy', 'precision', 'recall', 'f1_score']
number_classes = 10
metrics_fn = opennn_pytorch.metrics.get_metrics(metrics_names, nc=number_classes)
  8. Train/Test.
import os
import torch
import opennn_pytorch

algorithm = 'train'
batch_size = 16
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
number_classes = 10
save_every = 5
epochs = 20
checkpoints = 'path to folder for checkpoints'
logs = 'path to folder for logs'
viz = True  # visualize predictions after testing

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

if algorithm == 'train':
  opennn_pytorch.algo.train(train_dataloader, valid_dataloader, model, optimizer, scheduler, loss_fn, metrics_fn, epochs, checkpoints, logs, device, save_every, one_hot, number_classes)
elif algorithm == 'test':
  test_logs = opennn_pytorch.algo.test(test_dataloader, model, loss_fn, metrics_fn, logs, device, one_hot, number_classes)
  if viz:
    os.mkdir(test_logs + '/vizualize', 0o777)
    for i in range(number_classes):
      os.mkdir(test_logs + f'/vizualize/{i}', 0o777)
      opennn_pytorch.algo.vizualize(valid_data, model, device, {i: class_names[i] for i in range(number_classes)}, test_logs + f'/vizualize/{i}')
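
Once training has produced a checkpoint, single-image inference might look like the following sketch. It assumes the saved weights are a standard PyTorch state_dict and that the dataset returns (tensor image, integer label); neither is documented here.

import torch

# assumption: the checkpoint is a plain state_dict saved with torch.save
state = torch.load('path to checkpoint', map_location=device)
model.load_state_dict(state)
model.eval()

image, label = test_data[0]  # assumes (tensor image, integer label) samples
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device))
pred = logits.argmax(dim=1).item()
print(f'predicted: {class_names[pred]}, actual: {class_names[label]}')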

Citation

Project citation.

License

The project is distributed under the MIT License.

