
scikit-learn compatible neural network library for pytorch


.. image:: https://github.com/dnouri/skorch/blob/master/assets/skorch.svg
    :width: 30%

------------

|build| |coverage| |docs| |powered|

A scikit-learn compatible neural network library that wraps PyTorch.

.. |build| image:: https://travis-ci.org/dnouri/skorch.svg?branch=master
    :alt: Build Status
    :scale: 100%
    :target: https://travis-ci.org/dnouri/skorch?branch=master

.. |coverage| image:: https://github.com/dnouri/skorch/blob/master/assets/coverage.svg
    :alt: Test Coverage
    :scale: 100%

.. |docs| image:: https://readthedocs.org/projects/skorch/badge/?version=latest
    :alt: Documentation Status
    :scale: 100%
    :target: https://skorch.readthedocs.io/en/latest/?badge=latest

.. |powered| image:: https://github.com/dnouri/skorch/blob/master/assets/powered.svg
    :alt: Powered by
    :scale: 100%
    :target: https://github.com/ottogroup/

=========
Resources
=========

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/dnouri/skorch/>`_

========
Examples
========

To see more elaborate examples, look `here
<https://github.com/dnouri/skorch/tree/master/notebooks/README.md>`__.

.. code:: python

    import numpy as np
    from sklearn.datasets import make_classification
    from torch import nn
    import torch.nn.functional as F

    from skorch import NeuralNetClassifier


    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    # PyTorch layers expect float32 inputs; class labels must be int64
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class MyModule(nn.Module):
        def __init__(self, num_units=10, nonlin=F.relu):
            super(MyModule, self).__init__()

            self.dense0 = nn.Linear(20, num_units)
            self.nonlin = nonlin
            self.dropout = nn.Dropout(0.5)
            self.dense1 = nn.Linear(num_units, 10)
            self.output = nn.Linear(10, 2)

        def forward(self, X, **kwargs):
            X = self.nonlin(self.dense0(X))
            X = self.dropout(X)
            X = F.relu(self.dense1(X))
            # softmax output: the module returns class probabilities
            X = F.softmax(self.output(X), dim=-1)
            return X


    # pass the module class; skorch instantiates it when the net is initialized
    net = NeuralNetClassifier(
        MyModule,
        max_epochs=10,
        lr=0.1,
    )

    net.fit(X, y)
    y_proba = net.predict_proba(X)
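
The module above ends in a softmax, so it returns class probabilities
rather than raw scores. To our understanding, ``NeuralNetClassifier``
uses ``torch.nn.NLLLoss`` as its default criterion and takes the log of
these probabilities before computing the loss, which works out to the
usual cross-entropy. The following plain-Python sketch (not skorch code;
the helper names and the sample scores are hypothetical) checks that
arithmetic for a single sample with two classes:

.. code:: python

    import math


    def softmax(logits):
        """Turn raw scores into probabilities that sum to 1."""
        # Subtracting the maximum avoids overflow and leaves the result unchanged.
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]


    def nll_of_probs(probs, target):
        """Negative log-likelihood of the target class, given probabilities."""
        return -math.log(probs[target])


    def cross_entropy_from_logits(logits, target):
        """Cross-entropy computed directly from raw scores via log-sum-exp."""
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        return log_sum_exp - logits[target]


    logits = [2.0, 0.5]  # hypothetical output-layer scores for one sample
    target = 0           # hypothetical true class

    probs = softmax(logits)
    print(probs)
    print(nll_of_probs(probs, target))
    print(cross_entropy_from_logits(logits, target))

Both loss values agree up to floating-point rounding: taking the log of
softmax probabilities and applying NLL gives the same loss as
cross-entropy computed on the raw scores.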

In an sklearn Pipeline:

.. code:: python

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler


    pipe = Pipeline([
        ('scale', StandardScaler()),
        ('net', net),
    ])

    pipe.fit(X, y)
    y_proba = pipe.predict_proba(X)
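
In this pipeline, ``StandardScaler`` learns each feature's mean and
population standard deviation during ``fit`` and rescales the data to
zero mean and unit variance before it reaches the net. A plain-Python
sketch of that per-feature transformation, for illustration only (the
function names and toy data are hypothetical; like scikit-learn, it
leaves constant features unscaled by using a scale of 1):

.. code:: python

    import statistics


    def fit_standard_scaler(rows):
        """Learn per-column mean and population standard deviation."""
        columns = list(zip(*rows))
        means = [statistics.mean(col) for col in columns]
        stds = [statistics.pstdev(col) for col in columns]
        # A zero standard deviation (constant column) is replaced by 1.
        stds = [s if s > 0 else 1.0 for s in stds]
        return means, stds


    def transform(rows, means, stds):
        """Apply (x - mean) / std column by column."""
        return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


    X_toy = [[1.0, 10.0], [2.0, 10.0], [3.0, 10.0]]  # hypothetical toy data
    means, stds = fit_standard_scaler(X_toy)
    X_scaled = transform(X_toy, means, stds)
    print(X_scaled)

Because the scaler is fitted inside the pipeline, the same learned
statistics are reused when ``pipe.predict_proba`` is called.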

With grid search:

.. code:: python

    from sklearn.model_selection import GridSearchCV


    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }
    gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')

    gs.fit(X, y)
    print(gs.best_score_, gs.best_params_)
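
The grid above has 2 × 2 × 2 = 8 candidate settings, and with ``cv=3``
each is trained once per fold, giving 24 fits (no final refit, since
``refit=False``). Keys prefixed with ``module__`` are routed by skorch to
the module's constructor, so ``module__num_units`` sets ``MyModule``'s
``num_units`` argument. The plain-Python sketch below illustrates both
points; ``split_module_params`` is a hypothetical helper, not part of
skorch's API:

.. code:: python

    from itertools import product

    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }

    # Expand the grid: one dict per combination of values.
    keys = sorted(params)
    candidates = [dict(zip(keys, values))
                  for values in product(*(params[k] for k in keys))]
    n_folds = 3
    n_fits = len(candidates) * n_folds


    def split_module_params(candidate, prefix='module__'):
        """Separate net-level settings from those meant for the module."""
        net_kwargs, module_kwargs = {}, {}
        for key, value in candidate.items():
            if key.startswith(prefix):
                module_kwargs[key[len(prefix):]] = value  # strip the prefix
            else:
                net_kwargs[key] = value
        return net_kwargs, module_kwargs


    print(len(candidates), n_fits)
    print(split_module_params(candidates[0]))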

skorch also provides many convenient features, including:

- `Learning rate schedulers <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.LRScheduler>`_ (Warm restarts, cyclic LR and many more)
- `Scoring using sklearn (and custom) scoring functions <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EpochScoring>`_
- `Early stopping <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.EarlyStopping>`_
- `Checkpointing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Checkpoint>`_
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for the CLI as well as Jupyter)
- `Automatic inference of CLI parameters <https://github.com/dnouri/skorch/tree/master/examples/cli>`_
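
The idea behind early stopping fits in a few lines: training halts once
the monitored validation loss has failed to improve for a fixed number
of consecutive epochs (the patience). This is plain Python with a
hypothetical helper and a simplified improvement rule, not skorch's
implementation; the ``EarlyStopping`` callback linked above has its own
options, so consult its documentation for the exact semantics.

.. code:: python

    def early_stopping_epoch(valid_losses, patience=5, min_delta=0.0):
        """Return the 1-based epoch at which training would stop, or None.

        Training stops once `patience` consecutive epochs fail to beat the
        best loss seen so far by more than `min_delta`.
        """
        best = float('inf')
        misses = 0
        for epoch, loss in enumerate(valid_losses, start=1):
            if loss < best - min_delta:
                best = loss   # new best: reset the counter
                misses = 0
            else:
                misses += 1   # no improvement this epoch
                if misses >= patience:
                    return epoch
        return None


    # Hypothetical validation losses: improvement stalls after epoch 3.
    losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.60, 0.65]
    print(early_stopping_epoch(losses, patience=3))

With ``patience=3``, training stops after epoch 6, since epochs 4, 5 and
6 all fail to beat the best loss of 0.60 reached in epoch 3.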

============
Installation
============

skorch requires Python 3.5 or higher.

pip installation
================

To install with pip, run:

.. code:: bash

    pip install -U skorch

We recommend using a virtual environment for this.

From source
===========

If you would like to use the most recent additions to skorch or
help with its development, you should install skorch from source.

Using conda
===========

You need a working conda installation. Get the correct miniconda for
your system from `here <https://conda.io/miniconda.html>`__.

If you only want to use skorch, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    conda install --file requirements-dev.txt
    python setup.py develop

    py.test # unit tests
    pylint skorch # static code checks

Using pip
=========

If you only want to use skorch, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    pip install -r requirements-dev.txt
    python setup.py develop

    py.test # unit tests
    pylint skorch # static code checks

PyTorch
=======

PyTorch is not listed among skorch's dependencies, since the PyTorch
build you need depends on your system. For PyTorch installation
instructions, visit the `PyTorch website
<http://pytorch.org/>`__.

In general, one of the following should work (assuming CUDA 9):

.. code:: bash

    # using conda:
    conda install pytorch -c pytorch
    # using pip
    pip install torch
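
To confirm what the active environment can import after installation,
and whether PyTorch sees a CUDA device, a short script along these lines
may help. It is a sketch: ``environment_report`` is a hypothetical name,
and it imports ``torch`` only when ``torch`` is actually installed;
``torch.version.cuda`` and ``torch.cuda.is_available()`` are standard
PyTorch attributes.

.. code:: python

    import importlib.util
    import sys


    def environment_report():
        """Summarize the Python version and skorch's main dependencies."""
        report = {
            'python_ok': sys.version_info >= (3, 5),  # skorch needs Python 3.5+
        }
        for name in ('torch', 'sklearn', 'skorch'):
            # find_spec returns None when the package is not installed
            report[name] = importlib.util.find_spec(name) is not None
        if report['torch']:
            import torch  # only imported when it is installed
            report['torch_cuda_build'] = torch.version.cuda  # None on CPU-only builds
            report['cuda_available'] = torch.cuda.is_available()
        return report


    print(environment_report())

A ``torch_cuda_build`` of ``None`` indicates a CPU-only PyTorch build.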

=============
Communication
=============

- `GitHub issues <https://github.com/dnouri/skorch/issues>`_: bug
  reports, feature requests, install issues, RFCs, thoughts, etc.

- Slack: We run the #skorch channel on the `PyTorch Slack server
  <https://pytorch.slack.com/>`_. If you need an invite, send an
  email to daniel.nouri@gmail.com.


Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

skorch-0.5.0.tar.gz (86.0 kB view details)

Uploaded Source

Built Distributions

skorch-0.5.0-py3.6.egg (245.3 kB view details)

Uploaded Source

skorch-0.5.0-py3-none-any.whl (99.7 kB view details)

Uploaded Python 3

