适用于PyTorch的scikit-learn兼容神经网络库
项目描述
一个与scikit-learn兼容的神经网络库,它封装了PyTorch。
资源
示例
要查看更详细的示例,请点击这里。
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)
class MyModule(nn.Module):
def __init__(self, num_units=10, nonlin=nn.ReLU()):
super().__init__()
self.dense0 = nn.Linear(20, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(0.5)
self.dense1 = nn.Linear(num_units, num_units)
self.output = nn.Linear(num_units, 2)
self.softmax = nn.Softmax(dim=-1)
def forward(self, X, **kwargs):
X = self.nonlin(self.dense0(X))
X = self.dropout(X)
X = self.nonlin(self.dense1(X))
X = self.softmax(self.output(X))
return X
net = NeuralNetClassifier(
MyModule,
max_epochs=10,
lr=0.1,
# Shuffle training data on each epoch
iterator_train__shuffle=True,
)
net.fit(X, y)
y_proba = net.predict_proba(X)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
pipe = Pipeline([
('scale', StandardScaler()),
('net', net),
])
pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
使用网格搜索
from sklearn.model_selection import GridSearchCV
# deactivate skorch-internal train-valid split and verbose logging
net.set_params(train_split=False, verbose=0)
params = {
'lr': [0.01, 0.02],
'max_epochs': [10, 20],
'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)
gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))
skorch还提供了许多便利的功能,其中包括
安装
skorch需要Python 3.8或更高版本。
conda安装
您需要一个有效的conda安装。从这里获取适用于您的系统的正确miniconda。
要安装skorch,您需要使用conda-forge频道
conda install -c conda-forge skorch
我们建议使用conda虚拟环境。
注意:conda频道不由skorch维护者管理。更多信息请参见这里。
pip安装
要使用pip安装,请运行
python -m pip install -U skorch
再次强调,我们建议使用虚拟环境来安装。
从源代码安装
如果您想使用skorch的最新功能或帮助开发,您应该从源代码安装skorch。
使用conda
要使用conda从源代码安装skorch,请按照以下步骤操作
git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.10
conda activate skorch-env
conda install -c pytorch pytorch
python -m pip install -r requirements.txt
python -m pip install .
如果您想帮助开发,请运行
git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda create -n skorch-env python=3.10
conda activate skorch-env
conda install -c pytorch pytorch
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt
python -m pip install -e .
py.test # unit tests
pylint skorch # static code checks
您可以调整Python版本为任何受支持的Python版本。
使用pip
对于pip,请按照以下说明操作
git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install .
如果您想帮助开发,请运行
git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install -r requirements-dev.txt
python -m pip install -e .
py.test # unit tests
pylint skorch # static code checks
PyTorch
由于您需要的PyTorch版本取决于您的操作系统和设备,因此PyTorch不是依赖项。有关PyTorch的安装说明,请访问PyTorch网站。skorch官方支持最后四个次要版本的PyTorch,目前是
2.0.1
2.1.2
2.2.2
2.3.0
但这并不意味着旧版本不工作,只是它们没有经过测试。由于skorch主要依赖于PyTorch API的稳定部分,因此旧版本的PyTorch应该可以正常工作。
通常,运行以下命令安装PyTorch应该可以正常工作
# using conda:
conda install pytorch pytorch-cuda -c pytorch
# using pip
python -m pip install torch
外部资源
@jakubczakon: 博文 “8 Creators and Core Contributors Talk About Their Model Training Libraries From PyTorch Ecosystem” 2020
@BenjaminBossan: 演讲 1 “skorch: A scikit-learn compatible neural network library” at PyCon/PyData 2019
@githubnemo: 海报 for the PyTorch developer conference 2019
@thomasjpfan: 演讲 2 “Skorch: A Union of Scikit learn and PyTorch” at SciPy 2019
@thomasjpfan: 演讲 3 “Skorch - A Union of Scikit-learn and PyTorch” at PyData 2018
@BenjaminBossan: 演讲 4 “Extend your scikit-learn workflow with Hugging Face and skorch” at PyData Amsterdam 2023 (幻灯片 4)
沟通
GitHub讨论:用户问题、想法、安装问题、一般讨论。
GitHub问题:错误报告、功能请求、RFC等。
Slack:我们在PyTorch Slack服务器上运行#skorch频道,您可以在这里申请访问。
项目详情
下载文件
下载适用于您的平台的文件。如果您不确定要选择哪个,请了解更多关于安装包的信息。
源代码分发
构建发行版
skorch-1.0.0.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 25ca65c1a79894644025744d69cd586bf8605a81c4f8d5a3661094f5e692c914 |
|
MD5 | 63dc467a8629ad0044bb6c756550edce |
|
BLAKE2b-256 | 1da2ebb85f845e3fe319af724aa2a5f5a2faff9af51bbeb313775f8a481e354c |
skorch-1.0.0-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 949d5ce0643d7892e4fc104c2ad6a508c7218fdc34437b87c19a57b606cf672f |
|
MD5 | f73136dbb11bdc490d6ea485ff21fd86 |
|
BLAKE2b-256 | f32f2e5df7adb64d8457ea4104bcdfe4265251825d17bd3752f2de4d364723e8 |