跳转到主要内容

适用于PyTorch的scikit-learn兼容神经网络库

项目描述

https://github.com/skorch-dev/skorch/blob/master/assets/skorch_bordered.svg

Test Status Test Coverage Documentation Status Hugging Face Integration Powered by

一个与scikit-learn兼容的神经网络库,它封装了PyTorch。

资源

示例

要查看更详细的示例,请点击这里

import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        X = self.softmax(self.output(X))
        return X

net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)

sklearn Pipeline

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)

使用网格搜索

from sklearn.model_selection import GridSearchCV

# deactivate skorch-internal train-valid split and verbose logging
net.set_params(train_split=False, verbose=0)
params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))

skorch还提供了许多便利的功能,其中包括

安装

skorch需要Python 3.8或更高版本。

conda安装

您需要一个有效的conda安装。从这里获取适用于您的系统的正确miniconda。

要安装skorch,您需要使用conda-forge频道

conda install -c conda-forge skorch

我们建议使用conda虚拟环境

注意:conda频道不由skorch维护者管理。更多信息请参见这里

pip安装

要使用pip安装,请运行

python -m pip install -U skorch

再次强调,我们建议使用虚拟环境来安装。

从源代码安装

如果您想使用skorch的最新功能或帮助开发,您应该从源代码安装skorch。

使用conda

要使用conda从源代码安装skorch,请按照以下步骤操作

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.10
conda activate skorch-env
conda install -c pytorch pytorch
python -m pip install -r requirements.txt
python -m pip install .

如果您想帮助开发,请运行

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.10
conda activate skorch-env
conda install -c pytorch pytorch
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
python -m pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

您可以调整Python版本为任何受支持的Python版本。

使用pip

对于pip,请按照以下说明操作

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install .

如果您想帮助开发,请运行

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install -r requirements-dev.txt
python -m pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

PyTorch

由于您需要的PyTorch版本取决于您的操作系统和设备,因此PyTorch不是依赖项。有关PyTorch的安装说明,请访问PyTorch网站。skorch官方支持最后四个次要版本的PyTorch,目前是

  • 2.0.1

  • 2.1.2

  • 2.2.2

  • 2.3.0

但这并不意味着旧版本不工作,只是它们没有经过测试。由于skorch主要依赖于PyTorch API的稳定部分,因此旧版本的PyTorch应该可以正常工作。

通常,运行以下命令安装PyTorch应该可以正常工作

# using conda:
conda install pytorch pytorch-cuda -c pytorch
# using pip
python -m pip install torch

外部资源

  • @jakubczakon: 博文 “8 Creators and Core Contributors Talk About Their Model Training Libraries From PyTorch Ecosystem” 2020

  • @BenjaminBossan: 演讲 1 “skorch: A scikit-learn compatible neural network library” at PyCon/PyData 2019

  • @githubnemo: 海报 for the PyTorch developer conference 2019

  • @thomasjpfan: 演讲 2 “Skorch: A Union of Scikit learn and PyTorch” at SciPy 2019

  • @thomasjpfan: 演讲 3 “Skorch - A Union of Scikit-learn and PyTorch” at PyData 2018

  • @BenjaminBossan: 演讲 4 “Extend your scikit-learn workflow with Hugging Face and skorch” at PyData Amsterdam 2023 (幻灯片 4)

沟通

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定要选择哪个,请了解更多关于安装包的信息。

源代码分发

skorch-1.0.0.tar.gz (218.4 kB 查看哈希值)

上传时间 源码

构建发行版

skorch-1.0.0-py3-none-any.whl (239.4 kB 查看哈希值)

上传时间 Python 3

由以下支持