跳转到主要内容

广义多类支持向量机

项目描述

GenSVM Python包

Build Status Documentation Status

这是Gerrit J.J. van den Burg和Patrick J.F. Groenen开发的GenSVM多类分类器的Python包。

有用链接

安装

安装GenSVM之前,需要一个有效的NumPy安装。因此,可以使用以下命令安装GenSVM:

$ pip install numpy && pip install gensvm

如果您遇到任何错误,请在GitHub上创建一个问题。不要犹豫,您正在帮助使这个项目变得更好!

引用

如果您在研究中使用了这个包,请引用这篇论文,例如使用以下BibTeX条目:

@article{JMLR:v17:14-526,
        author  = {{van den Burg}, G. J. J. and Groenen, P. J. F.},
        title   = {{GenSVM}: A Generalized Multiclass Support Vector Machine},
        journal = {Journal of Machine Learning Research},
        year    = {2016},
        volume  = {17},
        number  = {225},
        pages   = {1-42},
        url     = {http://jmlr.org/papers/v17/14-526.html}
}

使用

该包包含两个类来拟合GenSVM模型:GenSVMGenSVMGridSearchCV。这些类分别拟合单个GenSVM模型或拟合一系列模型进行参数网格搜索。这些类的接口与Scikit-Learn中的分类器接口相同,因此熟悉Scikit-Learn的用户应该没有使用这个包的问题。以下我们将展示一些在实际中使用GenSVM分类器和GenSVMGridSearchCV类的示例。

在示例中,我们假设已经按照以下方式从Scikit-Learn加载了iris数据集

>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.preprocessing import MaxAbsScaler
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y)
>>> scaler = MaxAbsScaler().fit(X_train)
>>> X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

请注意,我们使用MaxAbsScaler函数进行数据缩放。该函数将数据矩阵的列缩放到[-1, 1]范围内,而不破坏稀疏性。对数据集进行缩放可以显著影响GenSVM的计算时间,通常建议对SVM进行缩放。

示例1:拟合单个GenSVM模型

让我们先在训练数据上拟合最基础的GenSVM模型。

>>> from gensvm import GenSVM
>>> clf = GenSVM()
>>> clf.fit(X_train, y_train)
GenSVM(coef=0.0, degree=2.0, epsilon=1e-06, gamma='auto', kappa=0.0,
kernel='linear', kernel_eigen_cutoff=1e-08, lmd=1e-05,
max_iter=100000000.0, p=1.0, random_state=None, verbose=0,
weights='unit')

模型拟合后,我们可以预测测试数据集。

>>> y_pred = clf.predict(X_test)

接下来,我们可以计算预测的评分。GenSVM类有一个score方法,可以计算预测的准确率。在GenSVM论文中,经常使用调整后的Rand指数来比较性能。下面展示了这两种选项(具体结果可能因训练/测试集的划分而异)。

>>> clf.score(X_test, y_test)
1.0
>>> from sklearn.metrics import adjusted_rand_score
>>> adjusted_rand_score(clf.predict(X_test), y_test)
1.0

我们可以尝试通过更改模型参数再次执行此操作,例如,我们可以启用可打印详细信息的模式并使用GenSVM模型中的欧几里得范数,通过设置p = 2实现。

>>> clf2 = GenSVM(verbose=True, p=2)
>>> clf2.fit(X_train, y_train)
Starting main loop.
Dataset:
    n = 112
    m = 4
    K = 3
Parameters:
    kappa = 0.000000
    p = 2.000000
    lambda = 0.0000100000000000
    epsilon = 1e-06

iter = 0, L = 3.4499531579689533, Lbar = 7.3369415851139745, reldiff = 1.1266786095824437
...
Optimization finished, iter = 4046, loss = 0.0230726364692517, rel. diff. = 0.0000009998645783
Number of support vectors: 9
GenSVM(coef=0.0, degree=2.0, epsilon=1e-06, gamma='auto', kappa=0.0,
    kernel='linear', kernel_eigen_cutoff=1e-08, lmd=1e-05,
    max_iter=100000000.0, p=2, random_state=None, verbose=True,
    weights='unit')

有关可以在GenSVM模型中调整的其他参数,请参阅GenSVM

示例2:使用“预热启动”拟合GenSVM模型

GenSVM分类器的一个关键特性是,通过使用所谓的“预热启动”,可以加速训练过程。这样,优化可以从比随机起始位置更接近最终解决方案的位置开始。为此,GenSVM类的fit方法有一个可选的seed_V参数。下面我们将展示如何使用它。

我们从模型中的epsilon参数的相对较大值开始。这是决定优化持续时间的停止参数(因此也决定了拟合的精确度)。

>>> clf1 = GenSVM(epsilon=1e-3)
>>> clf1.fit(X_train, y_train)
...
>>> clf1.n_iter_
163

n_iter_属性告诉我们模型进行了多少次迭代。现在,我们可以使用此模型的解决方案来启动下一个模型的训练。

>>> clf2 = GenSVM(epsilon=1e-8)
>>> clf2.fit(X_train, y_train, seed_V=clf1.combined_coef_)
...
>>> clf2.n_iter_
3196

将此与具有相同停止参数但没有预热启动的模型进行比较。

>>> clf2.fit(X_train, y_train)
...
>>> clf2.n_iter_
3699

因此,我们节省了大约500次迭代!在大数据集和尝试许多参数配置时,这种效果将特别明显。因此,这种技术被整合到了GenSVMGridSearchCV类中,该类可用于进行参数网格搜索。

示例3:运行GenSVM网格搜索

当拟合机器学习模型,如GenSVM时,我们通常需要尝试多种参数配置,以确定哪个在我们的给定数据集上表现最佳。这通常与交叉验证结合使用,以避免过拟合。为了有效地执行此操作并利用预热启动,提供了GenSVMGridSearchCV类。此类的工作方式与GridSearchCV类相同,但使用了GenSVM C库以提高速度。

要进行网格搜索,我们首先必须定义我们想要变化的参数以及我们想要尝试的值。

>>> from gensvm import GenSVMGridSearchCV
>>> param_grid = {'p': [1.0, 2.0], 'lmd': [1e-8, 1e-6, 1e-4, 1e-2, 1.0], 'kappa': [-0.9, 0.0] }

对于参数网格中未变化的值,将使用默认值。这意味着如果您想更改特定值(例如,例如,更改epsilon),您可以将其添加到参数网格中作为具有单个尝试值的参数(例如,'epsilon': [1e-8])。

运行网格搜索现在变得简单。

>>> gg = GenSVMGridSearchCV(param_grid)
>>> gg.fit(X_train, y_train)
GenSVMGridSearchCV(cv=None, iid=True,
      param_grid={'p': [1.0, 2.0], 'lmd': [1e-06, 0.0001, 0.01, 1.0], 'kappa': [-0.9, 0.0]},
      refit=True, return_train_score=True, scoring=None, verbose=0)

请注意,如果我们已设置refit=True(默认值),则可以使用GenSVMGridSearchCV实例使用网格搜索中找到的最佳估计量进行预测或评分。

>>> y_pred = gg.predict(X_test)
>>> gg.score(X_test, y_test)
1.0

Scikit-Learn借用的一个不错的特点是,网格搜索的结果可以用作pandas DataFrame。

>>> from pandas import DataFrame
>>> df = DataFrame(gg.cv_results_)

这可以使探索网格搜索结果更加容易。

已知限制

以下为已知限制,这些限制将在未来版本中解决。如果您需要这些功能中的任何一项,请前往链接的GitHub问题页面进行投票(这可以帮助我们更快地添加它们!)。

  1. 稀疏矩阵支持。NumPy支持稀疏矩阵,GenSVM C库也是如此。使它们协同工作需要一些额外的工作。在此期间,如果您真的想使用GenSVM与稀疏数据(这可能导致显著加速!),请查看GenSVM C库。
  2. 指定类误分类权重。目前,将类A的对象错误分类到类C与将类B的对象错误分类到类C一样糟糕。根据应用场景,这可能不是期望的效果。添加类误分类权重可以解决这个问题。

问题和疑问

如果您在使用此软件包时有任何疑问或遇到任何问题,请在GitHub上提问。

许可证

本软件包根据GNU通用公共许可证第3版授权。

版权(c)G.J.J. van den Burg,不包括明确标记为来自Scikit-Learn的代码部分。

项目详情


下载文件

下载适合您平台的应用文件。如果您不确定选择哪个,请了解更多关于安装软件包的信息。

源代码发行版

gensvm-0.2.7.tar.gz (179.0 kB 查看哈希值)

上传时间 源代码

构建发行版

gensvm-0.2.7-cp38-cp38-manylinux2010_x86_64.whl (4.1 MB 查看哈希值)

上传时间 CPython 3.8 manylinux: glibc 2.12+ x86-64

gensvm-0.2.7-cp38-cp38-manylinux2010_i686.whl (3.5 MB 查看哈希值)

上传时间 CPython 3.8 manylinux: glibc 2.12+ i686

gensvm-0.2.7-cp38-cp38-macosx_10_14_x86_64.whl (134.7 kB 查看哈希值)

上传时间 CPython 3.8 macOS 10.14+ x86-64

gensvm-0.2.7-cp37-cp37m-manylinux2010_x86_64.whl (4.1 MB 查看哈希值)

上传时间 CPython 3.7m manylinux: glibc 2.12+ x86-64

gensvm-0.2.7-cp37-cp37m-manylinux2010_i686.whl (3.5 MB 查看哈希值)

上传时间: CPython 3.7m manylinux: glibc 2.12+ i686

gensvm-0.2.7-cp37-cp37m-macosx_10_14_intel.whl (215.0 kB 查看哈希值)

上传时间: CPython 3.7m macOS 10.14+ intel

gensvm-0.2.7-cp36-cp36m-manylinux2010_x86_64.whl (4.1 MB 查看哈希值)

上传时间: CPython 3.6m manylinux: glibc 2.12+ x86-64

gensvm-0.2.7-cp36-cp36m-manylinux2010_i686.whl (3.5 MB 查看哈希值)

上传时间: CPython 3.6m manylinux: glibc 2.12+ i686

gensvm-0.2.7-cp36-cp36m-macosx_10_14_intel.whl (134.8 kB 查看哈希值)

上传时间: CPython 3.6m macOS 10.14+ intel

由以下支持

AWSAWS 云计算和安全赞助商 DatadogDatadog 监控 FastlyFastly CDN GoogleGoogle 下载分析 MicrosoftMicrosoft PSF赞助商 PingdomPingdom 监控 SentrySentry 错误记录 StatusPageStatusPage 状态页面