跳转到主要内容

Keras实现的ViT(视觉Transformer)

项目描述

vit-keras

这是《An Image is Worth 16x16 Words: Transformers For Image Recognition at Scale》中描述的模型的Keras实现。它基于tuvovan的一个早期实现,修改以匹配官方存储库中的Flax实现。

这里的权重是从官方存储库中提供的权重迁移过来的。请参阅utils.load_weights_numpy了解如何进行迁移(它不是很美观,但它确实完成了工作)。

使用方法

使用以下命令安装此包:pip install vit-keras

您可以使用类似以下方法使用模型进行开箱即用的ImageNet 2012类。权重将自动下载。

from vit_keras import vit, utils

image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)
url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
image = utils.read(url, image_size)
X = vit.preprocess_inputs(image).reshape(1, image_size, image_size, 3)
y = model.predict(X)
print(classes[y[0].argmax()]) # Granny smith

您可以使用以下方式加载的模型进行微调。

image_size = 224
model = vit.vit_l32(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=False,
    classes=200
)
# Train this model on your data as desired.

可视化注意力图

有一些功能可以用于绘制给定图像和模型的注意力图。请参阅下面的示例。我不确定我是否正确地做了这件事(官方存储库中没有示例代码)。欢迎反馈/纠正!

import numpy as np
import matplotlib.pyplot as plt
from vit_keras import vit, utils, visualize

# Load a model
image_size = 384
classes = utils.get_imagenet_classes()
model = vit.vit_b16(
    image_size=image_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True
)
classes = utils.get_imagenet_classes()

# Get an image and compute the attention map
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Free%21_%283987584939%29.jpg'
image = utils.read(url, image_size)
attention_map = visualize.attention_map(model=model, image=image)
print('Prediction:', classes[
    model.predict(vit.preprocess_inputs(image)[np.newaxis])[0].argmax()]
)  # Prediction: Eskimo dog, husky

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)

example of attention map

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码分发

本发行版没有提供源代码分发文件。请参阅有关生成分发存档的教程。

构建的分发

vit_keras-0.1.2-py3-none-any.whl (24.5 kB 查看哈希值)

上传时间 Python 3

由以下支持