
scikit-learn compatible neural network library for PyTorch

Project description


A scikit-learn compatible neural network library that wraps PyTorch.

Resources

Examples

To see more elaborate examples, look here.

import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
# PyTorch layers expect float32 inputs; the default loss (NLLLoss) expects int64 class labels
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
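        # Return class probabilities; with the default NLLLoss criterion,
        # NeuralNetClassifier takes their log when computing the loss
        X = self.softmax(self.output(X))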
        return X

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)
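The double-underscore names used here (`iterator_train__shuffle`, and later `module__num_units`) follow a prefix convention: the part before the first `__` names the component that receives the parameter, and the rest is passed on to it. For example, `module__num_units=20` ends up as `MyModule(num_units=20)` when the module is initialized. The sketch below is a hypothetical, pure-Python illustration of that routing; `route_params` is not part of skorch, whose own handling is more involved.

```python
def route_params(params, prefixes):
    """Group 'prefix__name' keys by prefix; all other keys stay top-level.

    Hypothetical helper that only illustrates the naming convention.
    """
    routed = {prefix: {} for prefix in prefixes}
    top_level = {}
    for key, value in params.items():
        # Split on the FIRST double underscore only
        prefix, sep, rest = key.partition("__")
        if sep and prefix in routed:
            routed[prefix][rest] = value
        else:
            top_level[key] = value
    return top_level, routed


top, routed = route_params(
    {"lr": 0.1, "max_epochs": 10, "iterator_train__shuffle": True, "module__num_units": 20},
    prefixes=["module", "iterator_train", "iterator_valid"],
)
print(top)     # {'lr': 0.1, 'max_epochs': 10}
print(routed)  # {'module': {'num_units': 20}, 'iterator_train': {'shuffle': True}, 'iterator_valid': {}}
```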

In an sklearn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
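Conceptually, the pipeline fits each transformer on the output of the previous step and hands the transformed data to the final estimator (here, the net); at prediction time it replays the already-fitted transforms before calling the final step. The following is a deliberately simplified, pure-Python sketch of that control flow. `MiniPipeline`, `Standardize`, and `RecordingEstimator` are hypothetical stand-ins, not scikit-learn classes, and the sketch uses `predict` rather than `predict_proba` for brevity.

```python
import statistics


class Standardize:
    """Hypothetical 1-D stand-in for StandardScaler."""

    def fit(self, xs, y=None):
        self.mean_ = statistics.fmean(xs)
        self.std_ = statistics.pstdev(xs) or 1.0  # avoid dividing by zero on constant input
        return self

    def transform(self, xs):
        return [(x - self.mean_) / self.std_ for x in xs]


class RecordingEstimator:
    """Hypothetical final step that remembers what it was fitted on."""

    def fit(self, xs, y):
        self.seen_ = list(xs)
        return self

    def predict(self, xs):
        return [1 if x > 0 else 0 for x in xs]


class MiniPipeline:
    def __init__(self, steps):
        self.steps = steps  # list of (name, object) pairs, as in sklearn's Pipeline

    def fit(self, xs, y):
        for _, transformer in self.steps[:-1]:
            xs = transformer.fit(xs, y).transform(xs)  # fit on the data, then transform it
        self.steps[-1][1].fit(xs, y)
        return self

    def predict(self, xs):
        for _, transformer in self.steps[:-1]:
            xs = transformer.transform(xs)  # reuse statistics learned during fit
        return self.steps[-1][1].predict(xs)


pipe = MiniPipeline([("scale", Standardize()), ("net", RecordingEstimator())])
pipe.fit([1.0, 2.0, 3.0], [0, 0, 1])
print(pipe.predict([2.5, 1.5]))  # [1, 0]
```

Because the scaler's mean and standard deviation are learned only during `fit`, new data at prediction time is scaled with the training statistics, which is exactly why the scaler belongs inside the pipeline rather than being applied to the full dataset up front.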

With grid search:

from sklearn.model_selection import GridSearchCV

# deactivate skorch-internal train-valid split and verbose logging
net.set_params(train_split=False, verbose=0)
params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))
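The grid above is the Cartesian product of the listed values, so `GridSearchCV` evaluates 2 × 2 × 2 = 8 candidate settings, each on `cv=3` folds, for 24 training runs in total (with `refit=False`, no final refit on the full data follows). The standard-library sketch below enumerates the same combinations, possibly in a different order than scikit-learn's `ParameterGrid`; it is an illustration, not that class's implementation.

```python
import itertools

params = {
    "lr": [0.01, 0.02],
    "max_epochs": [10, 20],
    "module__num_units": [10, 20],
}

# Every combination of one value per key, keeping the key order of `params`
keys = list(params)
candidates = [dict(zip(keys, values)) for values in itertools.product(*params.values())]

n_folds = 3  # matches cv=3 above
print(len(candidates))            # 8
print(len(candidates) * n_folds)  # 24
print(candidates[0])              # {'lr': 0.01, 'max_epochs': 10, 'module__num_units': 10}
```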

skorch also provides many convenient features, among others:

Installation

skorch requires Python 3.9 or higher.
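A quick way to confirm that the active interpreter meets this requirement before installing; `python_is_supported` is a hypothetical helper, and the minimum version is copied from the sentence above.

```python
import sys

MIN_PYTHON = (3, 9)  # minimum version stated above


def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    found = sys.version.split()[0]
    if python_is_supported():
        print(f"Python {found} is new enough for skorch")
    else:
        print(f"skorch needs Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+, found {found}")
```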

conda installation

You need a working conda installation. Get the correct miniconda for your system from here.

To install skorch, you need to use the conda-forge channel:

conda install -c conda-forge skorch

We recommend using a conda virtual environment.

Note: The conda channel is not managed by the skorch maintainers. More information is available here.

pip installation

To install with pip, run:

python -m pip install -U skorch

Again, we recommend using a virtual environment for this.

From source

If you would like to use the most recent additions to skorch or help development, you should install skorch from source.

Using conda

To install skorch from source using conda, proceed as follows:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.12
conda activate skorch-env
python -m pip install torch
python -m pip install .

If you want to help with development, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.12
conda activate skorch-env
python -m pip install torch
python -m pip install '.[test,docs,dev,extended]'

py.test  # unit tests
pylint skorch  # static code checks

You may adjust the Python version to any of the supported Python versions.

Using pip

For pip, follow these instructions instead:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
# install pytorch version for your system (see below)
python -m pip install .

If you want to help with development, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
# install pytorch version for your system (see below)
python -m pip install -e '.[test,docs,dev,extended]'

py.test  # unit tests
pylint skorch  # static code checks

PyTorch

PyTorch is not listed among skorch's dependencies, since the PyTorch version you need depends on your OS and device. For installation instructions, visit the PyTorch website. skorch officially supports the last four minor PyTorch versions, which currently are:

  • 2.6.0

  • 2.7.1

  • 2.8.0

  • 2.9.0

However, that doesn’t mean that older versions don’t work, just that they aren’t tested. Since skorch mostly relies on the stable part of the PyTorch API, older PyTorch versions should work fine.

In general, running this to install PyTorch should work:

python -m pip install torch
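To check whether an installed PyTorch release falls into the tested set, its version string (for example `torch.__version__`, which may carry a local suffix such as `+cpu`) can be reduced to its major.minor part. The helper below is a hypothetical sketch; `TESTED_MINOR_VERSIONS` simply copies the list above and will go out of date as new skorch releases change the tested versions.

```python
import re

# Copied from the list above; update when skorch's tested versions change
TESTED_MINOR_VERSIONS = ((2, 6), (2, 7), (2, 8), (2, 9))


def minor_version(version_string):
    """Extract (major, minor) from strings like '2.9.0' or '2.8.0+cu126'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def is_tested(version_string):
    return minor_version(version_string) in TESTED_MINOR_VERSIONS


# With PyTorch installed:
#   import torch
#   print(is_tested(torch.__version__))
print(is_tested("2.9.0+cpu"))  # True
print(is_tested("2.5.1"))      # False: older and untested, though likely to work
```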

External resources

  • @jakubczakon: blog post “8 Creators and Core Contributors Talk About Their Model Training Libraries From PyTorch Ecosystem” 2020

  • @BenjaminBossan: talk 1 “skorch: A scikit-learn compatible neural network library” at PyCon/PyData 2019

  • @githubnemo: poster for the PyTorch developer conference 2019

  • @thomasjpfan: talk 2 “Skorch: A Union of Scikit learn and PyTorch” at SciPy 2019

  • @thomasjpfan: talk 3 “Skorch - A Union of Scikit-learn and PyTorch” at PyData 2018

  • @BenjaminBossan: talk 4 “Extend your scikit-learn workflow with Hugging Face and skorch” at PyData Amsterdam 2023 (slides 4)

Communication

Download files


Source Distribution

skorch-1.3.1.tar.gz (249.2 kB)

Uploaded Source

Built Distribution


skorch-1.3.1-py3-none-any.whl (268.5 kB)

Uploaded Python 3

File details

Details for the file skorch-1.3.1.tar.gz.

File metadata

  • Download URL: skorch-1.3.1.tar.gz
  • Upload date:
  • Size: 249.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.12.1.2 readme-renderer/44.0 requests/2.32.5 requests-toolbelt/1.0.0 urllib3/2.6.2 tqdm/4.67.1 importlib-metadata/8.7.1 keyring/25.7.0 rfc3986/2.0.0 colorama/0.4.6 CPython/3.13.11

File hashes

Hashes for skorch-1.3.1.tar.gz
Algorithm Hash digest
SHA256 7081a0c9ab2361d524826f90c84b04a74cf55338c2b2028fa59a2e39a9019e43
MD5 3ed336e68714273016aa04b09352be32
BLAKE2b-256 211290d072b197bef5033c694ceca3fc5714edda122d8a5ef003d8d03febcb1e


File details

Details for the file skorch-1.3.1-py3-none-any.whl.

File metadata

  • Download URL: skorch-1.3.1-py3-none-any.whl
  • Upload date:
  • Size: 268.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.8.0 pkginfo/1.12.1.2 readme-renderer/44.0 requests/2.32.5 requests-toolbelt/1.0.0 urllib3/2.6.2 tqdm/4.67.1 importlib-metadata/8.7.1 keyring/25.7.0 rfc3986/2.0.0 colorama/0.4.6 CPython/3.13.11

File hashes

Hashes for skorch-1.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 bb06c65a15d0bfc765928a0b3fadf569222e7ec772f81b21d422603d52b4ad32
MD5 e41e275fa07c943a60d8381c0f3aeb99
BLAKE2b-256 564d6fbe78427fa6b5c54cbdbd3b9bdacd214a5bcf9bf4dd247fc9537fba1644
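To check a downloaded file against its published SHA256 digest, hash the file in chunks and compare. A minimal standard-library sketch; `sha256_of` and `matches` are hypothetical helper names, and the usage comment assumes the wheel was saved to the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_hex):
    """Compare against a published hex digest, ignoring case and surrounding whitespace."""
    return sha256_of(path) == expected_hex.strip().lower()


# SHA256 digest published above for skorch-1.3.1-py3-none-any.whl
WHEEL_SHA256 = "bb06c65a15d0bfc765928a0b3fadf569222e7ec772f81b21d422603d52b4ad32"
# Usage: matches("skorch-1.3.1-py3-none-any.whl", WHEEL_SHA256)
```

pip can enforce the same check itself in hash-checking mode, by listing `--hash=sha256:...` entries in a requirements file and installing with `--require-hashes`.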

