
scikit-learn compatible neural network library for pytorch


.. image:: https://github.com/dnouri/skorch/blob/master/assets/skorch.svg
   :width: 30%

------------

|build| |coverage| |docs| |powered|

A scikit-learn compatible neural network library that wraps PyTorch.

.. |build| image:: https://travis-ci.org/dnouri/skorch.svg?branch=master
   :alt: Build Status
   :scale: 100%
   :target: https://travis-ci.org/dnouri/skorch?branch=master

.. |coverage| image:: https://github.com/dnouri/skorch/blob/master/assets/coverage.svg
   :alt: Test Coverage
   :scale: 100%

.. |docs| image:: https://readthedocs.org/projects/skorch/badge/?version=latest
   :alt: Documentation Status
   :scale: 100%
   :target: https://skorch.readthedocs.io/en/latest/?badge=latest

.. |powered| image:: https://github.com/dnouri/skorch/blob/master/assets/powered.svg
   :alt: Powered by
   :scale: 100%
   :target: https://github.com/ottogroup/

=========
Resources
=========

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/dnouri/skorch/>`_

========
Examples
========

To see more elaborate examples, look `here
<https://github.com/dnouri/skorch/tree/master/notebooks/README.md>`__.

.. code:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from torch import nn
    import torch.nn.functional as F

    from skorch import NeuralNetClassifier


    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    # PyTorch layers expect float32 inputs; class labels must be int64
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class MyModule(nn.Module):
        def __init__(self, num_units=10, nonlin=F.relu):
            super(MyModule, self).__init__()

            self.dense0 = nn.Linear(20, num_units)
            self.nonlin = nonlin
            self.dropout = nn.Dropout(0.5)
            self.dense1 = nn.Linear(num_units, 10)
            self.output = nn.Linear(10, 2)

        def forward(self, X, **kwargs):
            X = self.nonlin(self.dense0(X))
            X = self.dropout(X)
            X = F.relu(self.dense1(X))
            # Return class probabilities (softmax over the 2 outputs)
            X = F.softmax(self.output(X), dim=-1)
            return X


    net = NeuralNetClassifier(
        MyModule,
        max_epochs=10,
        lr=0.1,
        # Shuffle training data on each epoch
        iterator_train__shuffle=True,
    )

    net.fit(X, y)
    y_proba = net.predict_proba(X)
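
``predict_proba`` returns one row per sample and one column per class (two here, matching the module's output layer). To get hard labels, ``net.predict(X)`` is the direct route. The NumPy sketch below only illustrates how probabilities map to labels by taking the most probable class, assuming ``y_proba`` has shape ``(n_samples, n_classes)``. The arrays are made-up toy values, not output of the network above.

.. code:: python

    import numpy as np

    # Toy stand-ins for net.predict_proba(X) and the true labels y
    # (illustrative values only, not produced by the network above).
    y_proba = np.array([
        [0.9, 0.1],
        [0.2, 0.8],
        [0.6, 0.4],
        [0.3, 0.7],
    ], dtype=np.float32)
    y_true = np.array([0, 1, 1, 1], dtype=np.int64)

    # Each row is a probability distribution over the two classes.
    assert np.allclose(y_proba.sum(axis=1), 1.0)

    # The predicted class is the column with the highest probability.
    y_pred = y_proba.argmax(axis=1)

    # Accuracy: fraction of samples whose predicted class matches the label.
    accuracy = (y_pred == y_true).mean()
    print(y_pred, accuracy)
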

In an sklearn Pipeline:

.. code:: python

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler


    pipe = Pipeline([
        ('scale', StandardScaler()),
        ('net', net),
    ])

    pipe.fit(X, y)
    y_proba = pipe.predict_proba(X)
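
Inside a ``Pipeline``, scikit-learn addresses a step's hyperparameters through the step name plus a double underscore, so ``lr`` becomes ``net__lr`` and ``module__num_units`` becomes ``net__module__num_units``. A small sketch of rewriting a net-level grid for the pipeline above; ``prefix_grid`` is a hypothetical helper, not part of skorch or scikit-learn.

.. code:: python

    def prefix_grid(param_grid, step_name):
        """Hypothetical helper: prefix every key with '<step_name>__'.

        scikit-learn's Pipeline routes a parameter named '<step>__<param>'
        to the step registered under '<step>'.
        """
        # str.format keeps this compatible with Python 3.5
        return {'{}__{}'.format(step_name, key): values
                for key, values in param_grid.items()}


    net_params = {
        'lr': [0.01, 0.02],
        'module__num_units': [10, 20],
    }
    pipe_params = prefix_grid(net_params, 'net')
    print(pipe_params)

The resulting dictionary can then be passed to ``GridSearchCV(pipe, pipe_params, ...)`` in place of the bare net and its grid.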

With grid search:

.. code:: python

    from sklearn.model_selection import GridSearchCV


    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }
    gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')

    gs.fit(X, y)
    print(gs.best_score_, gs.best_params_)
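
``GridSearchCV`` tries every combination in ``params``: 2 × 2 × 2 = 8 candidates, each trained on 3 folds, so 24 fits in total (with ``refit=False`` no final fit on the full data is added). Keys prefixed with ``module__`` are not arguments of the net itself; skorch forwards them to the module's constructor, which is how ``num_units`` reaches ``MyModule.__init__``. The sketch below enumerates the grid with ``itertools.product`` and splits each candidate into net-level and module-level arguments. ``split_candidate`` is a hypothetical helper that mimics the naming convention, not skorch's internal code.

.. code:: python

    import itertools

    # Same grid as in the GridSearchCV example above.
    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }
    cv = 3  # number of cross-validation folds


    def expand_grid(grid):
        """Yield one dict per combination of the listed values."""
        keys = sorted(grid)
        for values in itertools.product(*(grid[key] for key in keys)):
            yield dict(zip(keys, values))


    def split_candidate(candidate, prefix='module'):
        """Hypothetical helper: separate '<prefix>__' keys from the rest.

        Returns (own, routed): the net's own arguments and the arguments
        meant for the prefixed component, with the prefix stripped.
        """
        marker = prefix + '__'
        own = {k: v for k, v in candidate.items() if not k.startswith(marker)}
        routed = {k[len(marker):]: v for k, v in candidate.items()
                  if k.startswith(marker)}
        return own, routed


    candidates = list(expand_grid(params))
    n_fits = len(candidates) * cv  # refit=False: no extra fit on the full data

    for candidate in candidates:
        net_kwargs, module_kwargs = split_candidate(candidate)
        print('net:', net_kwargs, 'module:', module_kwargs)

    print(len(candidates), 'candidates,', n_fits, 'fits')
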

skorch also provides many convenient features, among others:

- `Learning rate schedulers <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.LRScheduler>`_ (warm restarts, cyclic LR, and many more)
- `Scoring using sklearn (and custom) scoring functions <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EpochScoring>`_
- `Early stopping <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EarlyStopping>`_
- `Checkpointing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Checkpoint>`_
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for CLI as well as Jupyter)
- `Automatic inference of CLI parameters <https://github.com/dnouri/skorch/tree/master/examples/cli>`_
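
Early stopping halts training once a monitored score stops improving. In skorch it is enabled by passing a callback to the net, e.g. ``callbacks=[EarlyStopping(patience=5)]`` with ``EarlyStopping`` imported from ``skorch.callbacks``; by default it watches ``valid_loss``. The following is a simplified, framework-free sketch of patience-based stopping over a list of per-epoch validation losses. It ignores the improvement threshold the real callback applies, and ``stopping_epoch`` and the loss values are hypothetical, for illustration only.

.. code:: python

    def stopping_epoch(losses, patience=5):
        """Hypothetical helper: 1-based epoch at which training stops, or None.

        Training stops once `patience` consecutive epochs fail to improve on
        the best loss seen so far (lower is better).
        """
        best = float('inf')
        epochs_without_improvement = 0
        for epoch, loss in enumerate(losses, start=1):
            if loss < best:
                best = loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
        return None


    # Illustrative validation losses: improvement stalls after epoch 3.
    valid_losses = [0.70, 0.55, 0.48, 0.49, 0.50, 0.48, 0.51]
    print(stopping_epoch(valid_losses, patience=3))

With ``patience=3`` the sketch stops after epoch 6; with the default of 5 the seven epochs run to completion.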

============
Installation
============

skorch requires Python 3.5 or higher.

pip installation
================

To install with pip, run:

.. code:: bash

pip install -U skorch

We recommend using a virtual environment for this.

From source
===========

If you would like to use the most recent additions to skorch or
help with development, you should install skorch from source.

Using conda
===========

You need a working conda installation. Get the correct miniconda for
your system from `here <https://conda.io/miniconda.html>`__.

If you just want to use skorch, use:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    conda install -c conda-forge --file requirements-dev.txt
    python setup.py develop

    py.test  # unit tests
    pylint skorch  # static code checks

Using pip
=========

If you just want to use skorch, use:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    pip install -r requirements-dev.txt
    python setup.py develop

    py.test  # unit tests
    pylint skorch  # static code checks

PyTorch
=======

PyTorch is not included in skorch's dependencies, since the PyTorch
build you need depends on your system. For installation instructions
for PyTorch, visit the `PyTorch website <http://pytorch.org/>`__. The
current version of skorch assumes PyTorch >= 1.1.0.

In general, this should work (assuming CUDA 9):

.. code:: bash

    # using conda:
    conda install pytorch -c pytorch
    # using pip:
    pip install torch
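
After installing, you can check that the installed PyTorch meets the >= 1.1.0 requirement stated above. A minimal sketch that parses ``torch.__version__`` with the standard library only; ``meets_minimum`` is a hypothetical helper, and it deliberately ignores local build suffixes such as ``+cu92`` and pre-release tags.

.. code:: python

    import re


    def meets_minimum(version, minimum=(1, 1, 0)):
        """Hypothetical helper: True if `version` is at least `minimum`.

        Only the leading MAJOR.MINOR[.PATCH] numbers are compared, so local
        suffixes like '+cu92' and pre-release tags like 'a0' are ignored.
        """
        match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
        if match is None:
            raise ValueError('unrecognized version string: {!r}'.format(version))
        parts = tuple(int(p) if p is not None else 0 for p in match.groups())
        return parts >= tuple(minimum)


    try:
        import torch
    except ImportError:
        print('PyTorch is not installed; see http://pytorch.org/')
    else:
        status = 'OK' if meets_minimum(torch.__version__) else 'too old'
        print(torch.__version__, status)
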

=============
Communication
=============

- `GitHub issues <https://github.com/dnouri/skorch/issues>`_: bug
  reports, feature requests, install issues, RFCs, thoughts, etc.

- Slack: We run the #skorch channel on the `PyTorch Slack server
  <https://pytorch.slack.com/>`_. If you need an invite, send an
  email to daniel.nouri@gmail.com.

