scikit-learn compatible neural network library for PyTorch

A scikit-learn compatible neural network library that wraps PyTorch.

Examples

To see more elaborate examples, look here.

import numpy as np
from sklearn.datasets import make_classification
from torch import nn

from skorch import NeuralNetClassifier


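# 1000 samples, 20 features (10 of them informative), 2 classes by default
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
# PyTorch expects float32 inputs and int64 class labels
X = X.astype(np.float32)
y = y.astype(np.int64)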

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)
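MyModule ends in nn.Softmax, so it returns class probabilities, which is what predict_proba gives back. The stdlib-only sketch below illustrates the arithmetic involved, using the standard definitions: softmax turns raw scores into probabilities, the negative log-likelihood of the true class scores a prediction, and the predicted class is the argmax of the probabilities. The helper names (softmax, nll, predict) are hypothetical and not part of skorch. As far as we know, NeuralNetClassifier uses torch.nn.NLLLoss by default and takes the log of the module's probability output before applying it, which is why returning probabilities from forward works; check the documentation of your skorch version to confirm.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Turn raw scores into probabilities that sum to 1."""
    # Subtracting the max score avoids overflow and does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def nll(probas: List[float], target: int) -> float:
    """Negative log-likelihood of the true class under the predicted probabilities."""
    return -math.log(probas[target])


def predict(probas: List[float]) -> int:
    """Predicted class: index of the highest probability (first one on ties)."""
    return max(range(len(probas)), key=lambda i: probas[i])


# Two output units, as in MyModule's final nn.Linear(num_units, 2).
probas = softmax([2.0, 0.0])
print([round(p, 3) for p in probas])    # [0.881, 0.119]
print(predict(probas))                  # 0
print(round(nll(probas, target=0), 3))  # 0.127
```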

In an sklearn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
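Parameter names such as iterator_train__shuffle and module__num_units follow the scikit-learn convention of using a double underscore to address a parameter of a sub-component. In a Pipeline, the same convention reaches into a named step, e.g. pipe.set_params(net__max_epochs=20), or one level deeper with net__module__num_units. The sketch below only illustrates how such prefixed names can be grouped by prefix; split_prefixed is a hypothetical helper, not skorch's or scikit-learn's actual implementation.

```python
from typing import Any, Dict


def split_prefixed(params: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group 'prefix__name' keys by prefix (hypothetical helper).

    Keys without a double underscore are collected under the empty prefix ''.
    Only the first '__' is split, so nested names keep their remaining path.
    """
    grouped = {}
    for key, value in params.items():
        prefix, sep, rest = key.partition("__")
        if not sep:
            # No '__': a top-level parameter such as 'lr' or 'max_epochs'.
            prefix, rest = "", key
        grouped.setdefault(prefix, {})[rest] = value
    return grouped


params = {
    "lr": 0.1,
    "max_epochs": 10,
    "iterator_train__shuffle": True,
    "module__num_units": 20,
}
for prefix, sub in split_prefixed(params).items():
    print(repr(prefix), sub)

# Seen from a Pipeline with a step named 'net', the remaining path is kept:
print(split_prefixed({"net__module__num_units": 20}))
```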

With grid search:

from sklearn.model_selection import GridSearchCV


# deactivate skorch-internal train-valid split and verbose logging
net.set_params(train_split=False, verbose=0)
params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))
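GridSearchCV tries every combination of the candidate values in params, and with cv=3 it trains the net once per combination and fold. The stdlib sketch below enumerates the grid with itertools.product, which mirrors what scikit-learn's ParameterGrid does for a single dict (the order of the combinations may differ). It shows that this grid has 2 × 2 × 2 = 8 candidates and therefore requires 24 fits; because refit=False, no extra fit on the full data follows. iter_grid is a hypothetical helper name.

```python
from itertools import product
from typing import Any, Dict, Iterator, List


def iter_grid(grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one parameter dict per combination of the candidate values."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
cv = 3

candidates = list(iter_grid(params))
print(len(candidates))       # 8
print(len(candidates) * cv)  # 24 fits in total
print(candidates[0])         # {'lr': 0.01, 'max_epochs': 10, 'module__num_units': 10}
```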

skorch also provides many convenient features, among others:

Installation

skorch requires Python 3.5 or higher.
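A script can check this requirement before importing anything else. A minimal stdlib sketch; the (3, 5) bound comes from the sentence above, and check_python is a hypothetical helper name.

```python
import sys

MIN_PYTHON = (3, 5)  # minimum version stated above


def check_python(version_info=None) -> bool:
    """Return True if (major, minor) of version_info is at least MIN_PYTHON.

    Defaults to the running interpreter's sys.version_info.
    """
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) >= MIN_PYTHON


if not check_python():
    raise RuntimeError("skorch requires Python 3.5 or higher")
```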

conda installation

You need a working conda installation. Get the correct Miniconda for your system from here.

To install skorch, you need to use the conda-forge channel:

conda install -c conda-forge skorch

We recommend using a conda virtual environment.

Note: The conda channel is not managed by the skorch maintainers. More information is available here.

pip installation

To install with pip, run:

pip install -U skorch

Again, we recommend using a virtual environment for this.

From source

If you would like to use the most recent additions to skorch or help development, you should install skorch from source.

Using conda

To install skorch from source using conda, proceed as follows:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
pip install .

If you want to help with development, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

Using pip

For pip, follow these instructions instead:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install .

If you want to help with development, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install -r requirements-dev.txt
pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

PyTorch

PyTorch is not listed in the dependencies because the PyTorch version you need depends on your OS and device. For installation instructions, visit the PyTorch website. skorch officially supports the last four minor PyTorch versions, which currently are:

  • 1.3.1

  • 1.4.0

  • 1.5.1

  • 1.6.0

However, that doesn’t mean that older versions don’t work, just that they aren’t tested. Since skorch mostly relies on the stable part of the PyTorch API, older PyTorch versions should work fine.

In general, running this to install PyTorch should work (assuming CUDA 10.2):

# using conda:
conda install pytorch cudatoolkit==10.2 -c pytorch
# using pip:
pip install torch
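Once PyTorch is installed, you can check whether its version falls within the tested minor versions listed above. A minimal sketch; it assumes that any patch release of a listed minor version counts as tested, and minor_version and is_tested are hypothetical helper names. Versions outside the list are reported as untested, not as broken.

```python
from typing import Tuple

# Minor releases listed above as officially supported (tested) at the time of writing.
TESTED_MINORS = {(1, 3), (1, 4), (1, 5), (1, 6)}


def minor_version(version: str) -> Tuple[int, int]:
    """Parse 'MAJOR.MINOR[.PATCH][+local]' into a (MAJOR, MINOR) tuple of ints."""
    public = version.split("+", 1)[0]  # drop local suffixes such as '+cu102'
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


def is_tested(version: str) -> bool:
    """True if the version belongs to one of the tested minor releases."""
    return minor_version(version) in TESTED_MINORS


print(is_tested("1.6.0+cu102"))  # True
print(is_tested("1.2.0"))        # False: untested, though it may still work

# With PyTorch installed, check the real version like this:
# import torch
# print(is_tested(torch.__version__))
```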


Download files

Source Distribution

skorch-0.9.0.tar.gz (113.2 kB)

Uploaded Source

Built Distributions

skorch-0.9.0-py3.7.egg (308.7 kB)

Uploaded Source

skorch-0.9.0-py3-none-any.whl (125.8 kB)

Uploaded Python 3
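The wheel's file name encodes its compatibility tags. Following the wheel naming convention of PEP 427, {distribution}-{version}-{python tag}-{abi tag}-{platform tag}.whl (with an optional build tag after the version), py3-none-any means the wheel runs on any Python 3 interpreter, needs no specific ABI, and works on any platform, so a single pure-Python wheel serves all systems. A minimal parsing sketch; parse_wheel_name is a hypothetical helper that deliberately rejects names with a build tag.

```python
from typing import NamedTuple


class WheelName(NamedTuple):
    distribution: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_name(filename: str) -> WheelName:
    """Split a wheel file name into its PEP 427 components (no build tag support)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel: " + filename)
    parts = filename[: -len(".whl")].split("-")
    if len(parts) != 5:
        raise ValueError("unexpected wheel name (build tags unsupported): " + filename)
    return WheelName(*parts)


print(parse_wheel_name("skorch-0.9.0-py3-none-any.whl"))
```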

File details

Details for the file skorch-0.9.0.tar.gz.

File metadata

  • Download URL: skorch-0.9.0.tar.gz
  • Upload date:
  • Size: 113.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.24.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.7

File hashes

Hashes for skorch-0.9.0.tar.gz
Algorithm Hash digest
SHA256 bdce9370153fd80c5c4ec499a639f55eef0620e45d4b15fbf7d7ff2a225a3d40
MD5 e099bd2fc0eb688df0875659f600533b
BLAKE2b-256 19c336d3305afd5f44663799c0d6ff962ba8e27e41c388599ac5c43e210332c5
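To check a downloaded file against the SHA256 digest listed above, compute the hash locally and compare the two hex strings. A minimal sketch using Python's hashlib; sha256_of and matches are hypothetical helper names, and the path is wherever you saved the download.

```python
import hashlib

# SHA256 digest listed above for skorch-0.9.0.tar.gz.
EXPECTED_SHA256 = "bdce9370153fd80c5c4ec499a639f55eef0620e45d4b15fbf7d7ff2a225a3d40"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: str, expected: str) -> bool:
    """Compare the file's digest with an expected hex string (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Usage, after downloading the source distribution into the current directory:
# print(matches("skorch-0.9.0.tar.gz", EXPECTED_SHA256))
```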

File details

Details for the file skorch-0.9.0-py3.7.egg.

File metadata

  • Download URL: skorch-0.9.0-py3.7.egg
  • Upload date:
  • Size: 308.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.24.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.7

File hashes

Hashes for skorch-0.9.0-py3.7.egg
Algorithm Hash digest
SHA256 12bb80276719cdbd114bc5042f4d0b395ce1ebe5dbc29aba5d4ea2f1792f9705
MD5 71cbffba2363334318a929008cd1b5b6
BLAKE2b-256 80c8f63f35a88cfbf999706eaf22069be08f28658236fe168a400703be9e8c7d

File details

Details for the file skorch-0.9.0-py3-none-any.whl.

File metadata

  • Download URL: skorch-0.9.0-py3-none-any.whl
  • Upload date:
  • Size: 125.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.5.0.1 requests/2.24.0 setuptools/46.1.3.post20200330 requests-toolbelt/0.9.1 tqdm/4.44.1 CPython/3.7.7

File hashes

Hashes for skorch-0.9.0-py3-none-any.whl
Algorithm Hash digest
SHA256 26317da14837f372fdeb8fb4eee9199c2cc0b0db1056fc4ab69696402e17e135
MD5 224a4af819037ba9156e9cb544c35015
BLAKE2b-256 18c72f6434f9360c91a4bf14ae85f634758e5dacd3539cca4266a60be9f881ae
