跳转到主要内容

机器学习工具箱

项目描述

PyPI version License

机器学习工具箱

这是一个机器学习工具箱。一个旨在重用和扩展的有用机器学习工具集合。工具箱包含以下模块

  • hyperopt - Hyperopt工具用于保存和重新启动评估
  • keras - Keras (tf.keras)的各种指标和其他Keras工具的回调
  • lightgbm - LightGBM的指标工具函数
  • metrics - 几种指标实现
  • plot - 绘图和可视化工具
  • tools - 各种(如统计)工具

模块: hyperopt

此模块包含一个用于保存和重新启动Hyperopt评估的工具函数。这是通过保存和加载hyperopt.Trials对象来完成的。使用方式如下

from mltb.hyperopt import fmin
from hyperopt import tpe, hp, STATUS_OK


def objective(x):
    return {
        'loss': x ** 2,
        'status': STATUS_OK,
        'other_stuff': {'type': None, 'value': [0, 1, 2]},
        }


best, trials = fmin(objective,
    space=hp.uniform('x', -10, 10),
    algo=tpe.suggest,
    max_evals=100,
    filename='trials_file')

print('best:', best)
print('number of trials:', len(trials.trials))

第一次运行输出

No trials file "trials_file" found. Created new trials object.
100%|██████████| 100/100 [00:00<00:00, 338.61it/s, best loss: 0.0007185087453453681]
best: {'x': 0.026805013436769026}
number of trials: 100

第二次运行输出

100 evals loaded from trials file "trials_file".
100%|██████████| 100/100 [00:00<00:00, 219.65it/s, best loss: 0.00012259809712488858]
best: {'x': 0.011072402500130158}
number of trials: 200

模块: lightgbm

此模块实现了LightGBM中未包含的指标函数。目前这是二元和多类问题的F1-和准确率分数。使用方式如下

bst = lgb.train(param,
                train_data,
                valid_sets=[validation_data]
                early_stopping_rounds=10,
                evals_result=evals_result,
                feval=mltb.lightgbm.multi_class_f1_score_factory(num_classes, 'macro'),
               )

模块:keras(用于tf.keras)

BinaryClassifierMetricsCallback

此模块以回调的形式提供自定义指标。因为回调将这些值添加到内部logs字典中,所以可以使用EarlyStopping回调在这些指标上执行早期停止。

参数

参数 描述 类型 默认值
val_data 验证输入 列表
val_label 验证输出 列表
pos_label 哪个索引是正标签 Optional[int] 1
metrics 支持的指标名称或自定义指标函数的列表 List[Union[str, Callable]] ['val_roc_auc', 'val_average_precision', 'val_f1', 'val_acc']

可用指标

  • val_roc_auc : ROC-AUC
  • val_f1 : F1分数
  • val_acc: 准确率
  • val_average_precision: 平均精度
  • val_mcc: Matthew相关系数

使用方法如下

bcm_callback = mltb.keras.BinaryClassifierMetricsCallback(val_data, val_labels)
es_callback = callbacks.EarlyStopping(monitor='val_roc_auc', patience=5,  mode='max')

history = network.fit(train_data, train_labels,
                      epochs=1000,
                      batch_size=128,

                      #do not give validation_data here or validation will be done twice
                      #validation_data=(val_data, val_labels),

                      #always provide BinaryClassifierMetricsCallback before the EarlyStopping callback
                      callbacks=[bcm_callback, es_callback],
)

您也可以定义自己的自定义指标

def custom_average_recall_score(y_true, y_pred, pos_label):
    rounded_pred = np.rint(y_pred)
    return sklearn.metrics.recall_score(y_true, rounded_pred, pos_label)


bcm_callback = mltb.keras.BinaryClassifierMetricsCallback(val_data, val_labels,metrics=[custom_average_recall_score])
es_callback = callbacks.EarlyStopping(monitor='custom_average_recall_score', patience=5,  mode='max')

history = network.fit(train_data, train_labels,
                      epochs=1000,
                      batch_size=128,

                      #do not give validation_data here or validation will be done twice
                      #validation_data=(val_data, val_labels),

                      #always provide BinaryClassifierMetricsCallback before the EarlyStopping callback
                      callbacks=[bcm_callback, es_callback],
)

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

mltb-0.8.0.tar.gz (14.4 kB 查看哈希)

上传时间

构建分布

mltb-0.8.0-py3-none-any.whl (14.2 kB 查看哈希)

上传时间 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面