
deep learning utility library

Project description

CRAI-Nets 20200907_104224


The CRAI-Nets Project

This is just another model zoo and utility library combined for developing deep learning models. The main reasons for this project to exist are to avoid boilerplate code across projects, to let others tap into your work, and to make benchmarking and experimenting easy and fast while sticking to readability and reproducibility. The goal of the project is to include as many useful models as possible, along with smart customized metrics and loss functions. For now the project is aimed at computer vision, although contributions within NLP or RL are more than welcome.

Getting started

0. Requirements

The library is platform agnostic, although we strongly suggest using Linux or macOS for ML development. We also suggest using poetry for dependency management (and pyenv for managing Python versions), unless you are on Windows, where Conda is the de facto standard (godspeed). Make sure you have Python 3.8 or later installed.
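If you want to fail fast on an older interpreter, a small guard like the following can check the Python 3.8 minimum before anything else runs. It is a sketch of our own (meets_requirement is a hypothetical helper), not something CRAI-Nets ships.

```python
import sys
from typing import Sequence

# Minimum version stated in the requirements above.
MIN_PYTHON = (3, 8)


def meets_requirement(version: Sequence[int] = tuple(sys.version_info[:3])) -> bool:
    """Return True if `version` is at least MIN_PYTHON (tuples compare element-wise)."""
    return tuple(version[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    if not meets_requirement():
        sys.exit(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ required, found {sys.version.split()[0]}")
    print("Python version OK")
```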

1. Install the package

As recommended, use poetry to install the package by running:

$ poetry add crainets
2. What you need to consider

The Trainer class, which you can use for simple benchmarking or fast experimenting, mainly expects the following:

  1. A model configuration dict containing hyperparameters
  2. Your loss functions, given as (weight, loss function) pairs and wrapped in MultiLoss
  3. A dict containing your metrics (you can specify multiple)
  4. Train and test data, prepared as PyTorch DataLoaders over a dataset class that inherits from torch.utils.data.Dataset
  5. The model architecture, imported from the crainets model zoo

We suggest writing your code modularly, so that the configuration comes from a config.py script and the data loaders come from a data_loader.py script.

3. Example
  1. Let's write two data loaders that apply our transforms lazily at runtime, as the data is batched for training. CIFAR-10 is used in this example purely for brevity.
import torch
import torchvision

import config  # config.py from the next step (DATA_PATH, batch sizes)

# Training transforms: random horizontal flip for augmentation, then tensor
# conversion and per-channel normalization to roughly [-1, 1].
transform = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Test transforms: same normalization, no augmentation.
transform_test = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train = torchvision.datasets.CIFAR10(
    config.DATA_PATH, train=True, download=True,
    transform=transform)

test = torchvision.datasets.CIFAR10(
    config.DATA_PATH, train=False, download=True,
    transform=transform_test)

train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=config.batch_size_train,
    shuffle=True
)
# The evaluation set does not need shuffling.
test_loader = torch.utils.data.DataLoader(
    test,
    batch_size=config.batch_size_test,
    shuffle=False
)
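To make "lazy" concrete: a map-style PyTorch dataset only needs __len__ and __getitem__, and the loader calls __getitem__ per index as each batch is assembled. (torchvision's CIFAR10 keeps the raw arrays in memory and applies the transforms per item in __getitem__.) The stdlib-only sketch below mimics that contract with a hypothetical LazyFileDataset that loads an item only when it is requested, plus a toy batches generator standing in for DataLoader. It illustrates the protocol and is not how torch or CRAI-Nets implement it.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence


class LazyFileDataset:
    """Hypothetical map-style dataset: stores only keys, loads an item on access."""

    def __init__(self, keys: Sequence[str], loader: Callable[[str], object],
                 transform: Optional[Callable[[object], object]] = None):
        self.keys = list(keys)      # cheap: just identifiers (e.g. file paths)
        self.loader = loader        # the expensive work happens here, per item
        self.transform = transform
        self.load_count = 0         # how many items have actually been loaded

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, idx: int):
        self.load_count += 1
        item = self.loader(self.keys[idx])
        if self.transform is not None:
            item = self.transform(item)
        return item


def batches(dataset, batch_size: int, shuffle: bool = False,
            seed: Optional[int] = None) -> Iterator[List[object]]:
    """Toy stand-in for DataLoader: yields lists of items, fetched only when needed."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


if __name__ == "__main__":
    ds = LazyFileDataset([f"img_{i}.png" for i in range(5)],
                         loader=lambda key: key.upper(),
                         transform=lambda x: x + "!")
    first = next(batches(ds, batch_size=2))
    print(first, ds.load_count)  # prints ['IMG_0.PNG!', 'IMG_1.PNG!'] 2
```

Only the two items of the first batch are loaded; the remaining three are untouched until their batch is requested.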
  2. Now that we have our data, let's write a config script (config.py) with the settings our network will use.
import os
import torch

ROOT = os.getcwd()
DATA_PATH = os.path.join(ROOT, 'data')        # CIFAR-10 is downloaded here
CHCKPT = os.path.join(ROOT, 'checkpoints')    # checkpoints are saved here
batch_size_train = 100
batch_size_test = 50

TRAIN_CONFIG = {
    "n_gpu": 1,
    "optimizer": {
        "type": "Adam",
        "args": {
            "lr": 1e-3,
            "weight_decay": 0,
            "amsgrad": True
        }
    },
    "loss": "nll_loss",
    "metrics": [
        "accuracy", "top_k_acc"
    ],
    "lr_scheduler": {
        "type": "StepLR",
        "args": {
            "step_size": 500,
            "gamma": 0.1
        }
    },
    "trainer": {
        "epochs": 2,
        "iterative": False,
        "iterations": 5,
        "images_pr_iteration": 100,
        "val_images_pr_iteration": 10,
        "save_dir": CHCKPT,
        "save_period": 5,
        "early_stop": 1
    }
}

METRICS = {
    'CrossEntropy': torch.nn.CrossEntropyLoss()
}
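The lr_scheduler entry names PyTorch's StepLR, which multiplies the learning rate by gamma once every step_size scheduler steps, so lr(t) = lr0 * gamma^floor(t / step_size). The sketch below (step_lr is our own helper, not part of CRAI-Nets) reproduces that arithmetic for the values in TRAIN_CONFIG. The README does not say whether the Trainer steps the scheduler per epoch or per batch, so read t as scheduler steps; if it is once per epoch, the two-epoch example never reaches the first decay.

```python
def step_lr(base_lr: float, step: int, step_size: int, gamma: float) -> float:
    """Closed-form learning rate of a StepLR schedule after `step` scheduler steps."""
    return base_lr * gamma ** (step // step_size)


# Values copied from TRAIN_CONFIG above.
BASE_LR = 1e-3
STEP_SIZE, GAMMA = 500, 0.1

for t in (0, 499, 500, 1000):
    # The rate stays at 1e-3 until t reaches 500, then drops tenfold per 500 steps.
    print(t, step_lr(BASE_LR, t, STEP_SIZE, GAMMA))
```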

Note that we also define METRICS in the config script. The dict can hold many more metrics than the single one used in this example.
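The config lists accuracy and top_k_acc, and METRICS maps names to callables. To illustrate evaluating a whole dict of metrics on the same predictions, here is a stdlib-only sketch; evaluate_metrics, accuracy and top_k_acc are hypothetical helpers working on plain lists of class scores, not the CRAI-Nets MultiMetric or metric implementations.

```python
from typing import Callable, Dict, Sequence

Scores = Sequence[Sequence[float]]  # one row of class scores per sample


def top_k_acc(outputs: Scores, targets: Sequence[int], k: int = 3) -> float:
    """Fraction of samples whose true class is among the k highest scores."""
    hits = 0
    for row, target in zip(outputs, targets):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


def accuracy(outputs: Scores, targets: Sequence[int]) -> float:
    """Top-1 accuracy, i.e. top-k accuracy with k = 1."""
    return top_k_acc(outputs, targets, k=1)


def evaluate_metrics(metrics: Dict[str, Callable[[Scores, Sequence[int]], float]],
                     outputs: Scores, targets: Sequence[int]) -> Dict[str, float]:
    """Apply every metric in the dict to the same predictions."""
    return {name: fn(outputs, targets) for name, fn in metrics.items()}


outputs = [[0.1, 0.7, 0.2],   # predicts class 1
           [0.5, 0.3, 0.2],   # predicts class 0
           [0.2, 0.3, 0.5]]   # predicts class 2
targets = [1, 1, 0]
print(evaluate_metrics({"accuracy": accuracy,
                        "top_2_acc": lambda o, t: top_k_acc(o, t, k=2)},
                       outputs, targets))
```

For these three samples, accuracy comes out as 1/3 and top_2_acc as 2/3: the second sample's true class has the runner-up score, so it only counts once k = 2.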

  3. Now let's tie it all together in a controller script that runs the training. We use EfficientNet in this example.
# Third-party imports
import torch

# Internal imports
import config
from data_loader import train_loader, test_loader

# CRAI-Nets imports
from crainets.trainer.trainer import Trainer
from crainets.models.efficientnet import EfficientNet
from crainets.essentials.multi_loss import MultiLoss
from crainets.essentials.multi_metric import MultiMetric

# Specify the model: EfficientNet-B0 for 3-channel input and 10 classes (CIFAR-10)
model = EfficientNet.from_name(in_channels=3, num_classes=10, model_name='efficientnet-b0')

# Losses are given as (weight, loss function) pairs
loss = [(1, torch.nn.CrossEntropyLoss())]
loss = MultiLoss(losses=loss)

# Add metrics in the METRICS dict of the config file
metrics = MultiMetric(config.METRICS)

# Instantiate the trainer
trainer = Trainer(
    model=model,
    loss_function=loss,
    metric_ftns=metrics,
    config=config.TRAIN_CONFIG,
    data_loader=train_loader,
    valid_data_loader=test_loader,
    seed=666,
    accumulative_metrics=True
)

# Good! Now run the training and smile!
trainer.train()
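The loss list of (1, torch.nn.CrossEntropyLoss()) pairs suggests that MultiLoss combines several weighted losses. Assuming it is a weighted sum (the README does not spell this out), the idea looks like this in plain Python; weighted_loss, squared_error and absolute_error are hypothetical stand-ins for PyTorch loss modules, not CRAI-Nets code.

```python
from typing import Callable, List, Sequence, Tuple

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def squared_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def absolute_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error over paired values."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def weighted_loss(losses: List[Tuple[float, LossFn]],
                  pred: Sequence[float], target: Sequence[float]) -> float:
    """Assumed MultiLoss semantics: sum of weight * loss over all (weight, loss) pairs."""
    return sum(weight * fn(pred, target) for weight, fn in losses)


pred, target = [1.0, 2.0, 4.0], [1.0, 3.0, 2.0]
print(weighted_loss([(1, squared_error), (0.5, absolute_error)], pred, target))
```

Here the total is 5/3 (squared error) plus 0.5 * 1 (absolute error). With a single (1, loss) pair, as in the controller script, the combination reduces to the loss itself.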
The project is mainly developed and maintained by CRAI at Oslo University Hospital.



Download files


Source Distribution

crainets-0.1.1b0.tar.gz (29.6 kB)

Uploaded Source

Built Distribution

crainets-0.1.1b0-py3-none-any.whl (41.2 kB)

Uploaded Python 3

File details

Details for the file crainets-0.1.1b0.tar.gz.

File metadata

  • Download URL: crainets-0.1.1b0.tar.gz
  • Size: 29.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.3 CPython/3.8.6 Darwin/20.6.0

File hashes

Hashes for crainets-0.1.1b0.tar.gz
Algorithm Hash digest
SHA256 20a883c95eba71b471cca18624181161f3a0817e8a883999ee2ecff92a1ff133
MD5 139779b8e49da90e4f0b04db09cd3a02
BLAKE2b-256 e7ae42a34063d579656353c792acd41bdc5107f560103024220558ecdaff84de


File details

Details for the file crainets-0.1.1b0-py3-none-any.whl.

File metadata

  • Download URL: crainets-0.1.1b0-py3-none-any.whl
  • Size: 41.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.1.3 CPython/3.8.6 Darwin/20.6.0

File hashes

Hashes for crainets-0.1.1b0-py3-none-any.whl
Algorithm Hash digest
SHA256 8e328745a833d9aa4daeaada73d8f3b3a32177821f697da60d45e897a62b43d5
MD5 53044185128f83d550aec5a0078cfc14
BLAKE2b-256 53810741b63b49808dba564c92ddc66a2f66a6cc448c1b0da1a89b9e79e588e3

