
HIPPO

HIPPO is an explainability toolkit for weakly-supervised learning in computational pathology.

Please see our preprint on arXiv: https://arxiv.org/abs/2409.03080

[!NOTE] This codebase is a work in progress. Please check back for updates.

Abstract

Deep learning models have shown great promise in histopathology image analysis, but their opaque decision-making processes pose challenges in high-risk medical scenarios. Here we introduce HIPPO, an explainable AI method that interrogates attention-based multiple instance learning (ABMIL) models in computational pathology by generating counterfactual examples through tissue patch modifications in whole slide images. Applying HIPPO to ABMIL models trained to detect breast cancer metastasis reveals that they may overlook small tumors and can be misled by non-tumor tissue, while attention maps, which are widely used for interpretation, often highlight regions that do not directly influence predictions. By interpreting ABMIL models trained on a prognostic prediction task, HIPPO identified tissue areas with stronger prognostic effects than high-attention regions, which sometimes showed counterintuitive influences on risk scores. These findings demonstrate HIPPO's capacity for comprehensive model evaluation, bias detection, and quantitative hypothesis testing. HIPPO greatly expands the capabilities of explainable AI tools to assess the trustworthy and reliable development, deployment, and regulation of weakly-supervised models in computational pathology.

If you find HIPPO useful, please cite it in your work.

How to use HIPPO

HIPPO is intended for weakly-supervised, multiple instance learning models in computational pathology. Before you can use HIPPO, you need patch embeddings and a trained attention-based multiple instance learning (ABMIL) model. Below, we briefly describe how to go from whole slide images (WSIs) to a trained ABMIL model.

We also provide metastasis detection models trained on CAMELYON16. Metastasis detection models trained with different encoders are available in our HuggingFace repositories.

To simplify reproducibility, we have also uploaded UNI embeddings of CAMELYON16 to https://huggingface.co/datasets/kaczmarj/camelyon16-uni. Embeddings from other models may be uploaded in the future.

Prepare your data for ABMIL

First, separate your whole slide images into smaller, non-overlapping patches. The CLAM toolkit is one popular way to do this. Once you have the patch coordinates, you will need to encode the patches using a pretrained model. There are countless options, but I would choose a recent foundation model trained on a large and diverse set of histopathology images. Keep track of the patch coordinates and patch features. This will be useful for downstream HIPPO experiments and for visualizing attention maps.
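The exact pipeline depends on your tiling tool and encoder, but conceptually it looks like the sketch below. Here, read_patch, the toy encoder, and the output file name are placeholders for illustration, not part of hippo or CLAM.

import numpy as np
import torch

# Placeholder stand-ins: swap in your tiling tool (e.g., CLAM) and a
# pretrained histopathology foundation model.
def read_patch(x: int, y: int) -> torch.Tensor:
    """Stub for reading one RGB patch from the WSI at coordinate (x, y)."""
    return torch.rand(3, 224, 224)

encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 1024))
encoder.eval()

coords = np.array([[0, 0], [0, 224], [224, 0]])  # patch coordinates from tiling

# Encode every patch and stack the embeddings into one bag.
with torch.inference_mode():
    features = torch.stack(
        [encoder(read_patch(x, y).unsqueeze(0)).squeeze(0) for x, y in coords]
    )

# Store features and coordinates together: the coordinates are what let you
# map HIPPO interventions and attention scores back onto the slide.
torch.save({"features": features, "coords": torch.from_numpy(coords)}, "specimen.pt")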

Train an ABMIL model

We provide a training script for classification models at https://huggingface.co/kaczmarj/metastasis-abmil-128um-uni/blob/main/train_classification.py. Alternatively, train a model using CLAM or another toolkit. HIPPO can work with any weakly-supervised model that accepts a bag of patches and returns specimen-level outputs.
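For orientation, here is a minimal sketch of what bag-level training looks like; the optimizer, learning rate, and toy bags are illustrative assumptions, and the linked train_classification.py is the actual script.

import hippo
import torch

# Toy dataset: each specimen is a bag of patch features plus one slide-level label.
bags = [torch.rand(500, 1024), torch.rand(800, 1024)]
labels = torch.tensor([0, 1])

model = hippo.AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(5):
    for bag, label in zip(bags, labels):
        optimizer.zero_grad()
        logits, _ = model(bag)  # the model returns (logits, attention)
        loss = loss_fn(logits, label.unsqueeze(0))
        loss.backward()
        optimizer.step()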

Examples

Minimal reproducible example with synthetic data

The code below is not meant to show the effect of any particular intervention. Rather, the purpose is to demonstrate how to use HIPPO to create an intervention in a specimen and evaluate its effect with a pretrained ABMIL model.

To use real data and pretrained models, see the examples below.

import hippo
import numpy as np
import torch

# Create the ABMIL model. Here, we use random initializations for the example.
# You should use a pretrained model in practice.
model = hippo.AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2)
model.eval()

# We use random features. In practice, use actual features :)
features = torch.rand(1000, 1024)

# Define the intervention. Here, we want to remove five patches.
# We define the indices of the patches to keep.
patches_to_remove = np.array([500, 501, 502, 503, 504])
patches_to_keep = np.setdiff1d(np.arange(features.shape[0]), patches_to_remove)

# Get the model outputs for baseline and "treated" samples.
with torch.inference_mode():
    baseline = model(features).logits.softmax(1)
    treatment = model(features[patches_to_keep]).logits.softmax(1)
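Comparing the two probability vectors quantifies the effect of the intervention, for example:

print(f"Baseline probabilities:  {baseline.squeeze(0).tolist()}")
print(f"Treatment probabilities: {treatment.squeeze(0).tolist()}")
print(f"Change in class-1 probability: {(treatment - baseline)[0, 1].item():+0.3f}")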

Test sufficiency of tumor for metastasis detection

In the example below, we load a UNI-based ABMIL model trained on CAMELYON16 for metastasis detection. Then, we take the embedding of one tumor patch from specimen test_001 and add it to the negative specimen test_003. Adding this single tumor patch is enough to cause a positive metastasis result.

import hippo
import huggingface_hub
import numpy as np
import torch

# Create the ABMIL model and load the pretrained weights below.
model = hippo.AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2)
model.eval()
# You may need to run huggingface_hub.login() to get this file.
state_dict_path = huggingface_hub.hf_hub_download(
    "kaczmarj/metastasis-abmil-128um-uni", filename="seed2/model_best.pt"
)
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

features_positive_path = huggingface_hub.hf_hub_download(
    "kaczmarj/camelyon16-uni", filename="embeddings/test_001.pt", repo_type="dataset"
)
features_positive = torch.load(features_positive_path, weights_only=True)
# This index contains the embedding for the tumor patch shown in Figure 2a of the HIPPO preprint.
tumor_patch = features_positive[7238].unsqueeze(0)  # 1x1024

features_negative_path = huggingface_hub.hf_hub_download(
    "kaczmarj/camelyon16-uni", filename="embeddings/test_003.pt", repo_type="dataset"
)
features_negative = torch.load(features_negative_path, weights_only=True)

# Get the model outputs for baseline and treated samples.
with torch.inference_mode():
    baseline = model(features_negative).logits.softmax(1)[0, 1].item()
    treatment = model(torch.cat([features_negative, tumor_patch])).logits.softmax(1)[0, 1].item()

print(f"Probability of tumor in baseline: {baseline:0.3f}")  # 0.002
print(f"Probability of tumor after adding one tumor patch: {treatment:0.3f}")  # 0.824

Test the effect of high-attention regions

In this example, we evaluate the effect of high-attention regions on metastasis detection. We find the following:

  1. Using the original specimen, the model strongly predicts the presence of metastasis (probability of 0.997).
  2. If we remove the top 1% of attended patches, the probability of metastasis remains high (0.988). This may be because some tumor patches remain in the specimen after the top 1% of attention is removed.
  3. If we remove the top 5% of attention, the probability of metastasis drops to 0.001.

In this way, we can quantify the effect of high-attention regions.

import math
import hippo
import huggingface_hub
import torch

# Create the ABMIL model and load the pretrained weights below.
model = hippo.AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2)
model.eval()
# You may need to run huggingface_hub.login() to get this file.
state_dict_path = huggingface_hub.hf_hub_download(
    "kaczmarj/metastasis-abmil-128um-uni", filename="seed2/model_best.pt"
)
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Load features for positive specimen.
features_path = huggingface_hub.hf_hub_download(
    "kaczmarj/camelyon16-uni", filename="embeddings/test_001.pt", repo_type="dataset"
)
features = torch.load(features_path, weights_only=True)

# Get the model outputs for baseline and treated samples.
with torch.inference_mode():
    logits, attn = model(features)
attn = attn.squeeze(1).numpy()  # flatten tensor
tumor_prob = logits.softmax(1)[0, 1].item()
print(f"Tumor probability at baseline: {tumor_prob:0.3f}")

inds = attn.argsort()[::-1].copy()  # indices high to low, and copy to please torch
num_patches = math.ceil(len(inds) * 0.01)
with torch.inference_mode():
    logits_01pct, _ = model(features[inds[num_patches:]])
tumor_prob_01pct = logits_01pct.softmax(1)[0, 1].item()
print(f"Tumor probability after removing top 1% of attention: {tumor_prob_01pct:0.3f}")

num_patches = math.ceil(len(inds) * 0.05)
with torch.inference_mode():
    logits_05pct, _ = model(features[inds[num_patches:]])
tumor_prob_05pct = logits_05pct.softmax(1)[0, 1].item()
print(f"Tumor probability after removing top 5% of attention: {tumor_prob_05pct:0.3f}")

The following is printed:

Tumor probability at baseline: 0.997
Tumor probability after removing top 1% of attention: 0.988
Tumor probability after removing top 5% of attention: 0.001
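The two ablations differ only in the fraction of patches removed, so the pattern can be factored into a small helper. This is a sketch, not part of the hippo API; it reuses model, features, and attn from the example above.

def prob_after_removing_top_attention(model, features, attn, fraction):
    """Tumor probability after removing the top `fraction` of patches by attention."""
    inds = attn.argsort()[::-1].copy()  # indices high to low
    num_patches = math.ceil(len(inds) * fraction)
    with torch.inference_mode():
        logits, _ = model(features[inds[num_patches:]])
    return logits.softmax(1)[0, 1].item()

for fraction in (0.01, 0.05):
    prob = prob_after_removing_top_attention(model, features, attn, fraction)
    print(f"Tumor probability after removing top {fraction:.0%} of attention: {prob:0.3f}")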

HIPPO greedy search algorithms

HIPPO implements greedy search algorithms to identify important patches. Below, we search for the patches with the greatest effect on metastasis detection. In brief, we identify the patches that, when removed, result in the lowest probability of metastasis detection.

import math
import hippo
import huggingface_hub
import numpy as np
import torch

# Set our device.
device = torch.device("cpu")
# device = torch.device("cuda")  # Uncomment if you have a GPU.
# device = torch.device("mps")  # Uncomment if you have an ARM Apple computer.

# Load ABMIL model.
model = hippo.AttentionMILModel(in_features=1024, L=512, D=384, num_classes=2)
model.eval()
# You may need to run huggingface_hub.login() to get this file.
state_dict_path = huggingface_hub.hf_hub_download(
    "kaczmarj/metastasis-abmil-128um-uni", filename="seed2/model_best.pt"
)
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.to(device)

# Load features.
features_path = huggingface_hub.hf_hub_download(
    "kaczmarj/camelyon16-uni", filename="embeddings/test_064.pt", repo_type="dataset"
)
features = torch.load(features_path, weights_only=True).to(device)


# Define a function that takes in a bag of features and returns model probabilities.
# The output values are the values we want to optimize during our search.
# This is why we use a function -- models can have different outputs. By defining
# a function that returns the values we want to optimize on, we can streamline the code.
def model_probs_fn(features):
    with torch.inference_mode():
        logits, _ = model(features)
    # Shape of logits is 1xC, where C is number of classes.
    probs = logits.softmax(1).squeeze(0)  # C
    return probs


# Find the 1% highest effect patches. These are the patches that, when removed, drop the probability
# of metastasis the most. The `results` variable is a dictionary with.... results of the search!
# The model outputs in `results["model_outputs"]` correspond to the results after removing the patches
# in `results["ablated_patches"][:k]`.
num_rounds = math.ceil(len(features) * 0.01)
results = hippo.greedy_search(
    features=features,
    model_probs_fn=model_probs_fn,
    num_rounds=num_rounds,
    output_index_to_optimize=1,
    # We use minimize because we want to minimize the model outputs
    # when the patches are *removed*.
    optimizer=hippo.minimize,
)

# Now we can test the effect of removing the 1% highest effect patches.
patches_not_ablated = np.setdiff1d(np.arange(len(features)), results["ablated_patches"])
with torch.inference_mode():
    prob_baseline = model(features).logits.softmax(1)[0, 1].item()  # 1.000
    prob_without_high_effect = model(features[patches_not_ablated]).logits.softmax(1)[0, 1].item()  # 0.008

print(f"Probability of metastasis at baseline: {prob_baseline:0.3f}")
print(f"Probability of metastasis after removing 1% highest effect patches: {prob_without_high_effect:0.3f}")

We can also plot the model outputs as we remove the high-effect patches. We would hope to see a monotonically decreasing line.

import matplotlib.pyplot as plt
import numpy as np

model_results = results["model_outputs"][:, results["optimized_class_index"]]
plt.plot(model_results)
plt.xlabel("Number of patches removed")
plt.ylabel("Probability of metastasis")
plt.show()

Citation

@misc{kaczmarzyk2024explainableaicomputationalpathology,
      title={Explainable AI for computational pathology identifies model limitations and tissue biomarkers},
      author={Jakub R. Kaczmarzyk and Joel H. Saltz and Peter K. Koo},
      year={2024},
      eprint={2409.03080},
      archivePrefix={arXiv},
      primaryClass={q-bio.TO},
      url={https://arxiv.org/abs/2409.03080},
}

License

HIPPO's code is licensed under the terms of the 3-Clause BSD license, and the documentation is released under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license (CC BY-NC-SA 4.0).
