在CIFAR10、CIFAR100、TinyImagenet200、ImageNet上使决策树与最先进的神经网络相竞争。将任何图像分类神经网络转换成可解释的基于神经网络的决策树。
项目描述
神经背靠决策树
项目页面 // 论文 // 无代码Web演示 // Colab笔记本
由Alvin Wan,*Lisa Dunlap,*Daniel Ho,Jihan Yin,Scott Lee,Henry Jin,Suzanne Petryk,Sarah Adel Bargal,Joseph E. Gonzalez *表示贡献相同
运行决策树,在CIFAR10、CIFAR100、TinyImagenet200和ImageNet上实现可解释模型的最优准确率。NBDTs在CIFAR10、CIFAR100和TinyImagenet200上使用最新的WideResNet实现了与原始神经网络1%以内的准确率;在ImageNet上使用最新的EfficientNet实现了与原始神经网络2%以内的准确率。我们实现了ImageNet的top-1准确率为75.13%。
目录
根据上述流程图,我们(1) 生成层次结构,(2)使用树监督损失 训练神经网络。然后,我们(3) 通过使用网络骨干和运行嵌入式决策规则进行特征化来运行推理。
快速入门
在示例上运行预训练的NBDT
不想下载?在网络演示上尝试使用自己的图片。
安装nbdt
工具,并在您选择的图片上运行它。这可以是本地图片路径或图片URL。以下,我们在网络上的一张猫的图片上进行评估。这张猫的图片如下。
pip install nbdt
nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
这将输出类别预测和所有中间决策,如下所示
Prediction: cat // Decisions: animal (99.47%), chordate (99.20%), carnivore (99.42%), cat (99.86%)
默认情况下,这个评估工具使用在CIFAR10上预训练的WideResNet。您也可以传递CIFAR10中没有看到的类别。以下,我们传递了一头熊和一匹斑马的图片。这张斑马的图片也如下所示。
nbdt https://images.pexels.com/photos/750539/pexels-photo-750539.jpeg?auto=compress&cs=tinysrgb&dpr=2&h=32
nbdt https://images.pexels.com/photos/158109/kodiak-brown-bear-adult-portrait-wildlife-158109.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32
与之前一样,这将输出类别预测和中间决策。尽管熊和斑马类别在训练时没有看到,但模型仍然正确地将两者都选择为动物而不是车辆。注意,对于斑马,最接近的明显类别是马,与模型下面的预测相匹配。
Prediction: horse // Decisions: animal (99.31%), ungulate (99.25%), horse (99.62%)
Prediction: dog // Decisions: animal (99.51%), chordate (99.35%), carnivore (99.69%), dog (99.22%)
图片来自pexels.com,根据Pexels许可免费使用。
在代码中加载预训练的NBDTs
不想下载?尝试在预先填充的Google Colab笔记本上进行推理。
如果您还没有安装,请使用pip安装nbdt
工具。
pip install nbdt
然后,选择NBDT推理模式(硬或软)、数据集和骨干网络。默认情况下,我们支持CIFAR10、CIFAR100和TinyImagenet200的ResNet18和WideResNet28x10。有关使用ImageNet上的EfficientNet-EdgeTPUSmall的信息,请参阅nbdt-pytorch-image-models。
from nbdt.model import SoftNBDT
from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10 # use wrn28_10 for TinyImagenet200
model = wrn28_10_cifar10()
model = SoftNBDT(
pretrained=True,
dataset='CIFAR10',
arch='wrn28_10_cifar10',
model=model)
注意torchvision.models.resnet18
仅支持224x224输入。然而,nbdt.models.resnet.ResNet18
支持可变大小的输入。有关使用您喜欢的图像分类神经网络的信息,请参阅模型。
30行示例:请参阅nbdt/bin/nbdt
,该文件在约30行内加载预训练模型、加载图像并在图像上运行推理。此文件是前节中的可执行文件nbdt
。在Google Colab笔记本中尝试此操作。
将神经网络转换为决策树
要将您的神经网络转换为神经背靠决策树,执行以下3个步骤
- 首先,如果您还没有安装,请使用pip安装
nbdt
工具:pip install nbdt
- 其次,在训练期间,使用自定义NBDT损失包装您的损失
criterion
。以下,我们展示了在CIFAR10数据集上的软树监督损失。默认情况下,我们支持CIFAR10
、CIFAR100
、TinyImagenet200
和Imagenet1000
。
from nbdt.loss import SoftTreeSupLoss
criterion = SoftTreeSupLoss(dataset='CIFAR10', criterion=criterion) # `criterion` is your original loss function e.g., nn.CrossEntropyLoss
- 第三,在推理或验证期间,使用以下方式包装您的
model
。这只是为了在验证或推理时作为NBDT运行预测。不要在训练期间像下面那样包装您的模型。
from nbdt.model import SoftNBDT
model = SoftNBDT(dataset='CIFAR10', model=model) # `model` is your original model
与存储库的示例集成:请参阅nbdt-pytorch-image-models
,该文件将此3步集成应用于流行的图像分类存储库pytorch-image-models
。
16行中的随机神经网络集成示例 [点击展开]
您还可以包括本存储库中未明确支持的任意图像分类神经网络。例如,在安装pretrained-models.pytorch
后,您可以使用pip实例化和传递任何预训练模型到我们的NBDT实用函数中。
from nbdt.model import SoftNBDT
from nbdt.loss import SoftTreeSupLoss
from nbdt.hierarchy import generate_hierarchy
import pretrainedmodels
model = pretrainedmodels.__dict__['fbresnet152'](num_classes=1000, pretrained='imagenet')
# 1. generate hierarchy from pretrained model
generate_hierarchy(dataset='Imagenet1000', arch='fbresnet152', model=model)
# 2. Fine-tune model with tree supervision loss
criterion = ...
criterion = SoftTreeSupLoss(dataset='Imagenet1000', hierarchy='induced-fbresnet152', criterion=criterion)
# 3. Run inference using embedded decision rules
model = SoftNBDT(model=model, dataset='Imagenet1000', hierarchy='induced-fbresnet152')
有关生成不同层次结构的信息,请参阅诱导层次结构。
想要构建和使用您自己的诱导层次结构? (点击展开)
使用nbdt-hierarchy
工具从预训练模型生成新的诱导层次结构。
nbdt-hierarchy --arch=efficientnet_b0 --dataset=Imagenet1000
然后,将层次结构名称传递给损失和模型。您还可以传递完全合格的path_graph
路径。
from nbdt.loss import SoftTreeSupLoss
from nbdt.model import SoftNBDT
criterion = SoftTreeSupLoss(dataset='Imagenet1000', criterion=criterion, hierarchy='induced-efficientnet_b0')
model = SoftNBDT(dataset='Imagenet1000', model=model, hierarchy='induced-efficientnet_b0')
有关生成不同层次结构的信息,请参阅诱导层次结构。
训练和评估
要重现实验结果,首先克隆存储库并安装所有要求。
git clone git@github.com:alvinwan/neural-backed-decision-trees.git # or http addr if you don't have private-public github key setup
cd neural-backed-decision-trees
python setup.py develop
要重现论文中的核心实验结果(忽略消融研究),只需运行以下bash脚本
bash scripts/gen_train_eval_wideresnet.sh
需要更详细的分步指导?上面的bash脚本在以下章节中有更详细的解释: 诱导层次结构,软树监督损失,以及 软推理。这些脚本重现了我们的 CIFAR10、CIFAR100 和 TinyImagenet200 结果。要重现我们的 ImageNet 结果,请参阅 nbdt-pytorch-image-models
。
对于所有脚本,您可以使用任何torchvision
模型或任何pytorchcv
模型,因为我们直接支持这两个模型库。每个步骤的定制将在下面解释。
1. 层次结构
诱导层次结构
运行以下命令以生成和测试基于WideResNet模型的CIFAR10诱导层次结构。
nbdt-hierarchy --arch=wrn28_10_cifar10 --dataset=CIFAR10
查看如何工作以及如何配置。 (点击展开)
该脚本加载预训练模型(步骤A),用全连接层权重填充树的叶子节点(步骤B)并执行层次聚类(步骤C)。请注意,上述命令可以用不同的架构、不同的数据集或随机的神经网络检查点重新运行,以产生不同的层次结构。
# different architecture: ResNet18
nbdt-hierarchy --arch=ResNet18 --dataset=CIFAR10
# different dataset: ImageNet
nbdt-hierarchy --arch=efficientnet_b7 --dataset=Imagenet1000
# arbitrary checkpoint
wget https://download.pytorch.org/models/resnet18-5c106cde.pth -O resnet18.pth
nbdt-hierarchy --checkpoint=resnet18.pth --dataset=Imagenet1000
您还可以通过传递预训练模型直接从源代码运行层次结构生成,而不使用命令行工具。
from nbdt.hierarchy import generate_hierarchy
from nbdt.models import wrn28_10_cifar10
model = wrn28_10_cifar10(pretrained=True)
generate_hierarchy(dataset='Imagenet1000', arch='wrn28_10_cifar10', model=model)
查看示例可视化。 (点击展开)
默认情况下,生成脚本输出包含d3可视化的HTML文件。所有可视化都存储在out/
中。我们将生成另一个具有更大字体大小和包含wordnet ID的可视化。
nbdt-hierarchy --vis-sublabels --vis-zoom=1.25 --dataset=CIFAR10 --arch=wrn28_10_cifar10
上述脚本的输出将结束于以下内容。
==> Reading from ./nbdt/hierarchies/CIFAR10/graph-induced-wrn28_10_cifar10.json
Found just 1 root.
==> Wrote HTML to out/induced-wrn28_10_cifar10-tree.html
在您的浏览器中打开out/induced-wrn28_10_cifar10-tree.html
以查看d3树可视化。
想要重现论文中的层次结构可视化? (点击展开)
要生成论文中的图,请使用更大的缩放比例,并且不要包括子标签。用于生成诱导层次结构可视化的检查点包含在此存储库的模型库中。
nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=ResNet10 --vis-force-labels-left conveyance vertebrate chordate vehicle motor_vehicle mammal placental
nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --vis-leaf-images --vis-image-resize-factor=1.5 --vis-force-labels-left motor_vehicle craft chordate vertebrate carnivore ungulate craft
nbdt-hierarchy --vis-zoom=2.5 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --vis-color-nodes whole --vis-no-color-leaves --vis-force-labels-left motor_vehicle craft chordate vertebrate carnivore ungulate craft
WordNet层次结构
运行以下命令以生成和测试CIFAR10、CIFAR100和TinyImagenet200的WordNet层次结构。该脚本还下载NLTK WordNet语料库。
bash scripts/generate_hierarchies_wordnet.sh
查看如何工作。 (点击展开)
以下仅解释上述generate_hierarchies_wordnet.sh
,使用CIFAR10。您在运行上述bash脚本后不需要运行以下内容。
# Generate mapping from classes to WNID. This is required for CIFAR10 and CIFAR100.
nbdt-wnids --dataset=CIFAR10
# Generate hierarchy, using the WNIDs. This is required for all datasets: CIFAR10, CIFAR100, TinyImagenet200
nbdt-hierarchy --method=wordnet --dataset=CIFAR10
查看示例可视化。 (点击展开)
我们可以生成一个具有略微改进的缩放比例和包含wordnet ID的可视化。默认情况下,脚本为CIFAR10构建Wordnet层次结构。
nbdt-hierarchy --method=wordnet --vis-zoom=1.25 --vis-sublabels
随机层次结构
使用--method=random
随机生成一个类似二进制的层次结构。可选地使用--seed
(--seed=-1
表示不洗牌叶子节点)和--branching-factor
标志。在调试时,我们将分支因子设置为类别的数量。例如,CIFAR10的健全性检查层次结构为
nbdt-hierarchy --seed=-1 --branching-factor=10 --dataset=CIFAR10
2. 树监督损失
以下训练命令中,我们统一使用--path-resume=<path/to/checkpoint> --lr=0.01
进行微调,而不是从头开始训练。我们使用最近最先进的预训练检查点(WideResNet)进行微调。运行以下命令,在CIFAR10上使用软树监督损失进行WideResNet的微调。
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss
查看如何工作以及如何配置。 (点击展开)
树监督损失有两种变体:一种是硬版本,另一种是软版本。只需将损失更改为HardTreeSupLoss
或SoftTreeSupLoss
,取决于您想要使用哪种。
# fine-tune the wrn pretrained checkpoint on CIFAR10 with hard tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=HardTreeSupLoss
# fine-tune the wrn pretrained checkpoint on CIFAR10 with soft tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss
要从头开始训练,使用--lr=0.1
,并且不传递--path-resume
或--pretrained
标志。我们在CIFAR10和CIFAR100上微调WideResnet,但在基线神经网络准确度可复制的位置,我们从头开始训练。
3. 推理
与树监督损失的变体类似,也有两种推理变体:一种是硬推理,另一种是软推理。下面,我们在使用软损失训练的模型上运行软推理。
运行以下bash脚本以获得这些数字。
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules
查看如何工作以及如何配置。 (点击展开)
请注意,以下命令几乎与相应的训练命令相同--我们省略了lr
和pretrained
标志,并添加了resume
、eval
和analysis
类型(硬推理或软推理)。令人惊讶的是,我们在论文中报告的最佳结果是通过在由软树监督损失监督的神经网络上同时运行硬推理和软推理获得的。这反映在下面的命令中。
# running soft inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules
# running hard inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=HardEmbeddedDecisionRules
结果
我们与所有之前在CIFAR10、CIFAR100和/或ImageNet上报告的基于决策树的方法进行比较,包括使用不纯的叶子或随机森林来损害可解释性的方法。我们报告了所有这些方法中准确度最高的基线:深度神经网络决策森林(DNDF,更新为ResNet18)、可解释观察者-分类器(XOC)、深度卷积决策丛林(DCDJ)、专家网络(NofE)、深度决策网络(DDN)和自适应神经网络树(ANT)。
CIFAR10 | CIFAR100 | TinyImagenet200 | ImageNet | |
---|---|---|---|---|
NBDT-S(我们自己的) | 97.57% | 82.87% | 66.66% | 75.13% |
NBDT-H(我们自己的) | 97.55% | 82.21% | 64.39% | 74.79% |
最佳预NBDT准确率 | 94.32% | 76.24% | 44.56% | 61.29% |
最佳预NBDT方法 | DNDF | NofE | DNDF | NofE |
我们的改进 | 3.25% | 6.63% | 22.1% | 13.84% |
如最后一行所示,我们在CIFAR10、CIFAR100和TinyImagenet200等中小型数据集上的准确率比所有之前的方法高出3%到13%以上。请注意,由于我们使用当前公共版本重新训练了所有模型,因此在小到中型数据集(CIFAR10、CIFAR100和TinyImagenet200)上预训练检查点的准确率可能波动0.1-0.2%。
开发设置
如上所述,您可以使用nbdt
Python库将NBDT训练集成到任何现有的训练管道中。但是,如果您希望使用这里的裸骨训练实用工具,请参阅以下部分以添加自定义模型和数据集。
如果您尚未这样做,请首先克隆存储库并安装所有要求。
git clone git@github.com:alvinwan/neural-backed-decision-trees.git # or http addr if you don't have private-public github key setup
cd neural-backed-decision-trees
python setup.py develop
作为一个示例,我们包括WideResNet bash脚本的副本,但针对ResNet18。
bash scripts/gen_train_eval_resnet.sh
对于任何具有感兴趣数据集的预训练检查点的模型(例如,来自pytorchcv
的CIFAR10、CIFAR100和ImageNet模型,或来自torchvision
的ImageNet模型),修改scripts/gen_train_eval_pretrained.sh
;只需更改模型名称。对于所有没有感兴趣数据集的预训练检查点的模型,修改scripts/gen_train_eval_nopretrained.sh
。
模型
在不修改main.py
的情况下,您可以用您喜欢的网络替换ResNet18:将任何torchvision.models
模型或任何pytorchcv
模型传递到--arch
,因为我们直接支持这两个模型库。请注意,前者仅支持在ImageNet上预训练的模型。后者支持在CIFAR10、CIFAR100和ImageNet上预训练的模型;对于每个数据集,相应的模型名称包括数据集,例如wrn28_10_cifar10
。但是,两者都不支持在TinyImagenet上预训练的模型。
从头开始添加新模型
- 创建一个包含您网络的新文件,例如
./nbdt/models/yournet.py
。此文件应仅包含一个__all__
,仅公开返回模型的函数。这些函数应接受pretrained: bool
和progress: bool
,然后将所有其他关键字参数传递给模型构造函数。 - 通过
./nbdt/models/__init__.py
公开您的文件:from .yournet import *
。 - 在目标数据集上训练原始神经网络。例如,
python main.py --arch=yournet18
。
数据集
在不修改 main.py
的情况下,您可以通过将其传递给 --dataset
使用任何在 torchvision.datasets
中找到的图像分类数据集。要从头开始添加新数据集
- 创建一个包含您数据集的新文件,例如
./nbdt/data/yourdata.py
。假设数据类是YourData10
。像之前一样,仅通过__all__
公开数据集类。此数据集类应支持一个.classes
属性,它返回一个人类可读的类名列表。 - 通过
'./nbdt/data/__init__.py'
公开您的文件:from .yourdata import *
。 - 在
./nbdt/wnids/{dataset}.txt
中创建一个包含 wordnet IDs 的文本文件。此列表应与您的数据集的.classes
的顺序相同。您可以选择使用nbdt-wnids
工具生成 wnids(见以下注释) - 在目标数据集上训练原始神经网络。例如,
python main.py --dataset=YourData10
注意:您可以选择使用
nbdt-wnids
工具生成 wnidsnbdt-wnids --dataset=YourData10
,其中
YourData
是您的数据集名称。如果从YourData.classes
提供的类名在 WordNet 语料库中不存在,则脚本将生成一个假的 wnid。这不会影响训练,但随后的分析脚本将无法提供 WordNet 估计的节点意义。
引用
如果您认为这项工作对您的科研有用,请引用我们的论文
@article{wan2020nbdt,
title={NBDT: Neural-Backed Decision Trees},
author={Alvin Wan and Lisa Dunlap and Daniel Ho and Jihan Yin and Scott Lee and Henry Jin and Suzanne Petryk and Sarah Adel Bargal and Joseph E. Gonzalez},
year={2020},
eprint={},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
项目详情
nbdt-0.0.4.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | da5873b68e62d7215a72bbb46e8b6f0fc2044380e73b111e654422c5601a3d14 |
|
MD5 | 4b3a8a0725f2942266985539a0a288f0 |
|
BLAKE2b-256 | d93a75fb13e538bb75df5bb4802a7296311e88923eec0c1f76e9da5e2887f6b9 |