
Cool package for robust AI

Project description

Installation

Robbytorch requires PyTorch, but PyTorch is not listed in its dependencies. We recommend installing PyTorch manually via conda first, and only then installing Robbytorch via pip.

Use an existing conda environment or create a new one:

conda create --name <ENV NAME> python=3.8 pip
conda activate <ENV NAME>

Install PyTorch. If you have older GPU drivers, you may need an older CUDA version, e.g.:

conda install pytorch torchvision torchaudio cudatoolkit=10.1 -c pytorch -c conda-forge

or even an older PyTorch version:

conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch

Then run:

pip install robbytorch

Usage

See the Jupyter notebooks in ipython/ for complete examples. For a step-by-step introduction, continue reading this file.

Prepare Dataset

Place your data in a root directory of your choice, e.g. "/dysk1/approx/robby".

You can subclass robbytorch.datasets.DictDataset and implement two methods - for more info, please read that class's docstring. Here's an example implementation:

import pandas as pd
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

from robbytorch.datasets import DictDataset


class CreationsDataset(DictDataset):

    def load_data(self, idx):
        file_name = f"{self.metadata.iloc[idx]['creation_id']}.png"
        return self.load_image(file_name)
    
    def load_target_dict(self, idx):
        record = self.metadata.iloc[idx].to_dict()
        
        return {col: torch.tensor(record[col]).float() 
                for col in ['label', 'CR']
               }

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224))
])
data_root = "/dysk1/approx/robby"
metadata = pd.DataFrame(data=[(3141, 0, 0.01), (2137, 1, 0.012)], columns=["creation_id", "label", "CR"])
dataset = CreationsDataset(data_root, metadata, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)

Now, whenever you iterate over the dataloader, you get a dict of batched tensors (with .shape[0] == 128, except possibly for the last batch):

{
    "data": batched_tensor_data,
    "label": batched_tensor_label,
    "CR": batched_CR_of_tensors
}
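This dict structure comes from PyTorch's default collation, which stacks each key across the samples in a batch. A minimal sketch of what happens under the hood (using plain torch.utils.data machinery, independent of Robbytorch; the sample values are made up for illustration):

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two samples shaped like the dicts CreationsDataset returns:
# "data" holds the image tensor, the rest come from load_target_dict.
samples = [
    {"data": torch.zeros(3, 224, 224), "label": torch.tensor(0.0), "CR": torch.tensor(0.010)},
    {"data": torch.ones(3, 224, 224), "label": torch.tensor(1.0), "CR": torch.tensor(0.012)},
]

# DataLoader applies this to each list of samples it draws.
batch = default_collate(samples)

print(batch["data"].shape)  # torch.Size([2, 3, 224, 224]) - stacked along a new batch dim
print(batch["label"])       # tensor([0., 1.])
```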

You can use this structure however you like during training/evaluation.
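For example, a plain PyTorch training step over these dict batches might look like the sketch below (the model, optimizer, and loss are hypothetical stand-ins, not part of Robbytorch):

```python
import torch
import torch.nn as nn

# Hypothetical model: flatten 3x224x224 images and predict a single logit.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

def train_step(batch):
    """One optimization step on a dict batch from the dataloader."""
    logits = model(batch["data"]).squeeze(1)
    loss = loss_fn(logits, batch["label"])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# A synthetic batch with the same structure as the dataloader's output.
batch = {
    "data": torch.randn(4, 3, 224, 224),
    "label": torch.randint(0, 2, (4,)).float(),
    "CR": torch.rand(4),
}
loss_value = train_step(batch)
```

The extra keys (here "CR") ride along in the same dict, so auxiliary losses can read them alongside the labels.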

TODO - further explanations:

  • training: 3x forward
  • configs from lib2
  • adding auxiliary losses via magic hooks
  • explain Writers, livelossplot and mlflow
  • loading robust networks
  • description of utilities - notebook, visualization

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

robbytorch-0.1.3.tar.gz (32.3 kB)

Uploaded Source

Built Distribution

robbytorch-0.1.3-py3-none-any.whl (43.7 kB)

Uploaded Python 3

File details

Details for the file robbytorch-0.1.3.tar.gz.

File metadata

  • Download URL: robbytorch-0.1.3.tar.gz
  • Upload date:
  • Size: 32.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.5.0.1 requests/2.21.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3

File hashes

Hashes for robbytorch-0.1.3.tar.gz

  • SHA256: 210afda28c00e9d72260c71c82368e408578edf7031a955fcf981dc123ea7212
  • MD5: 3eaa2176460e70795a590510fad1732f
  • BLAKE2b-256: 9ffa3a2d18d7c4c3bc03def206102d41067caf577ba3f34a9ee56ecdcd3dc26e


File details

Details for the file robbytorch-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: robbytorch-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 43.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.5.0 pkginfo/1.5.0.1 requests/2.21.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3

File hashes

Hashes for robbytorch-0.1.3-py3-none-any.whl

  • SHA256: 2c76b674799833a31dbf5945d6243dd6291792bee4331c14b4de1a358bf85aa8
  • MD5: a1b546d1b09682c426babd8a2e61dc8c
  • BLAKE2b-256: a6f58188597825c13077f170117629993550b7646988d6280073615a85ef7741

