scikit-learn compatible neural network library for pytorch

.. image:: https://github.com/dnouri/skorch/blob/master/assets/skorch.svg
   :width: 30%

------------

|build| |coverage| |docs| |powered|

A scikit-learn compatible neural network library that wraps PyTorch.

.. |build| image:: https://travis-ci.org/dnouri/skorch.svg?branch=master
    :alt: Build Status
    :scale: 100%
    :target: https://travis-ci.org/dnouri/skorch?branch=master

.. |coverage| image:: https://github.com/dnouri/skorch/blob/master/assets/coverage.svg
    :alt: Test Coverage
    :scale: 100%

.. |docs| image:: https://readthedocs.org/projects/skorch/badge/?version=latest
    :alt: Documentation Status
    :scale: 100%
    :target: https://skorch.readthedocs.io/en/latest/?badge=latest

.. |powered| image:: https://github.com/dnouri/skorch/blob/master/assets/powered.svg
    :alt: Powered by
    :scale: 100%
    :target: https://github.com/ottogroup/

=========
Resources
=========

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/dnouri/skorch/>`_

=======
Example
=======

To see a more elaborate example, look `here
<https://github.com/dnouri/skorch/tree/master/notebooks/README.md>`__.

.. code:: python

    import numpy as np
    from sklearn.datasets import make_classification
    import torch
    from torch import nn
    import torch.nn.functional as F

    from skorch.net import NeuralNetClassifier


    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X = X.astype(np.float32)
    y = y.astype(np.int64)

    class MyModule(nn.Module):
        def __init__(self, num_units=10, nonlin=F.relu):
            super(MyModule, self).__init__()

            self.dense0 = nn.Linear(20, num_units)
            self.nonlin = nonlin
            self.dropout = nn.Dropout(0.5)
            self.dense1 = nn.Linear(num_units, 10)
            self.output = nn.Linear(10, 2)

        def forward(self, X, **kwargs):
            X = self.nonlin(self.dense0(X))
            X = self.dropout(X)
            X = F.relu(self.dense1(X))
            X = F.softmax(self.output(X), dim=-1)
            return X


    net = NeuralNetClassifier(
        MyModule,
        max_epochs=10,
        lr=0.1,
    )

    net.fit(X, y)
    y_proba = net.predict_proba(X)

In an sklearn Pipeline:

.. code:: python

    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler


    pipe = Pipeline([
        ('scale', StandardScaler()),
        ('net', net),
    ])

    pipe.fit(X, y)
    y_proba = pipe.predict_proba(X)

With grid search:

.. code:: python

    from sklearn.model_selection import GridSearchCV


    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }
    gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')

    gs.fit(X, y)
    print(gs.best_score_, gs.best_params_)
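
When the net is a step inside a ``Pipeline``, scikit-learn expects each
grid search parameter name to carry the step name as a prefix, so ``lr``
becomes ``net__lr`` and ``module__num_units`` becomes
``net__module__num_units``. The following is a minimal sketch of that
renaming; ``prefix_params`` is a hypothetical helper written for this
illustration and is not part of skorch or scikit-learn.

.. code:: python

    def prefix_params(params, step_name):
        """Prefix every key with ``<step_name>__`` (hypothetical helper)."""
        return {
            '{}__{}'.format(step_name, name): values
            for name, values in params.items()
        }


    # Same grid as in the grid search example above.
    params = {
        'lr': [0.01, 0.02],
        'max_epochs': [10, 20],
        'module__num_units': [10, 20],
    }

    # 'net' is the step name used in the Pipeline example above.
    pipe_params = prefix_params(params, 'net')
    print(sorted(pipe_params))
    # ['net__lr', 'net__max_epochs', 'net__module__num_units']

The resulting dictionary could then be passed to ``GridSearchCV``
together with ``pipe`` instead of ``net``.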

============
Installation
============

pip installation
================

To install with pip, run:

.. code:: bash

    pip install -U skorch

We recommend using a virtual environment for this.

From source
===========

If you would like to use the most recent additions to skorch or
help with development, you should install skorch from source.

Using conda
===========

You need a working conda installation. Get the correct miniconda for
your system from `here <https://conda.io/miniconda.html>`__.

If you just want to use skorch, use:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    conda env create
    source activate skorch
    # install pytorch version for your system (see below)
    conda install --file requirements-dev.txt
    python setup.py develop

    py.test  # unit tests
    pylint skorch  # static code checks

Using pip
=========

If you just want to use skorch, use:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    python setup.py install

If you want to help with development, run:

.. code:: bash

    git clone https://github.com/dnouri/skorch.git
    cd skorch
    # create and activate a virtual environment
    pip install -r requirements.txt
    # install pytorch version for your system (see below)
    pip install -r requirements-dev.txt
    python setup.py develop

    py.test  # unit tests
    pylint skorch  # static code checks

PyTorch
=======

PyTorch is not covered by the dependencies, since the PyTorch
version you need depends on your system. For installation
instructions, visit the `PyTorch website
<http://pytorch.org/>`__.

In general, this should work (assuming CUDA 9):

.. code:: bash

    # using conda:
    conda install pytorch cuda90 -c pytorch
    # using pip
    pip install http://download.pytorch.org/whl/cu90/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
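
After installing, it can help to confirm which PyTorch build Python
actually picks up. The snippet below is a small sketch rather than part
of skorch; ``describe_torch`` is a hypothetical helper name, and the
function only reports what it finds without changing anything.

.. code:: python

    import importlib.util


    def describe_torch():
        """Return a one-line summary of the local PyTorch installation."""
        # find_spec returns None when no importable 'torch' package exists.
        if importlib.util.find_spec('torch') is None:
            return 'PyTorch is not installed'
        import torch
        return 'PyTorch {} (CUDA available: {})'.format(
            torch.__version__, torch.cuda.is_available())


    print(describe_torch())

If the summary reports that CUDA is not available on a machine with a
GPU, the installed build probably does not match the system's CUDA
version.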

=============
Communication
=============

- `GitHub issues <https://github.com/dnouri/skorch/issues>`_: bug
  reports, feature requests, install issues, RFCs, thoughts, etc.

- Slack: We run the #skorch channel on the `PyTorch Slack server
  <https://pytorch.slack.com/>`_. If you need an invite, send an
  email to daniel.nouri@gmail.com.

