跳转到主要内容

JAX中的Frechet Inception Distance。

项目描述

PyPI

FID JAX

在JAX中干净实现的Frechet Inception Distance

  • 重现了OpenAI的TensorFlow实现。
  • 纯JAX实现可在CPU/GPU/TPU和JIT内运行。
  • 可以使用pathlib API从GCS加载权重。
  • 代码干净且简单。

说明

1️⃣ FID JAX是一个单个文件,所以您可以直接将其复制到项目目录中。或者,您可以安装该软件包

pip install fidjax

2️⃣ 下载Inception权重(归功于Matthias Wright

wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1

3️⃣ 下载所需分辨率的ImageNet参考统计数据(为其他数据集生成您自己的

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz

4️⃣ 在JAX中计算激活、统计信息和得分

import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax.FID(weights, reference)

fid_total = 50000
fid_batch = 1000
acts = []
for range(fid_total // fid_batch):
  samples = ...  # (B, H, W, 3) jnp.uint8
  acts.append(fid.compute_acts(samples))
stats = fid.compute_stats(acts)
score = fid.compute_score(stats)

print(float(score))  # FID

准确性

数据集 模型 FID JAX OpenAI TF
ImageNet 256 ADM (引导,上采样) 3.937 3.943

教程

使用云存储

通过支持您的云存储的pathlib.Path实现指向文件。例如,对于GCS

import elements  # pip install elements
import fidjax

weights = elements.Path('gs://bucket/fid/inception_v3_weights_fid.pickle')
reference = elements.Path('gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz')

fid = fidjax.FID(weights, reference)

自定义数据集

为自定义数据集生成参考统计数据

import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax.FID(weights)

acts = fid.compute_acts(images)
mu, sigma = fid.compute_stats(acts)

np.savez('reference.npz', {'mu': mu, 'sigma': sigma})

资源

问题

请在Github上提交问题

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装软件包的更多信息。

源分布

fidjax-1.0.1.tar.gz (5.3 kB 查看散列)

上传于

支持者