JAX中的Frechet Inception Distance。
项目描述
FID JAX
在JAX中干净实现的Frechet Inception Distance。
- 重现了OpenAI的TensorFlow实现。
- 纯JAX实现可在CPU/GPU/TPU和JIT内运行。
- 可以使用
pathlib
API从GCS加载权重。 - 代码干净且简单。
说明
1️⃣ FID JAX是一个单个文件,所以您可以直接将其复制到项目目录中。或者,您可以安装该软件包
pip install fidjax
2️⃣ 下载Inception权重(归功于Matthias Wright)
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1
3️⃣ 下载所需分辨率的ImageNet参考统计数据(为其他数据集生成您自己的)
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ 在JAX中计算激活、统计信息和得分
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax.FID(weights, reference)
fid_total = 50000
fid_batch = 1000
acts = []
for range(fid_total // fid_batch):
samples = ... # (B, H, W, 3) jnp.uint8
acts.append(fid.compute_acts(samples))
stats = fid.compute_stats(acts)
score = fid.compute_score(stats)
print(float(score)) # FID
准确性
数据集 | 模型 | FID JAX | OpenAI TF |
---|---|---|---|
ImageNet 256 | ADM (引导,上采样) | 3.937 | 3.943 |
教程
使用云存储
通过支持您的云存储的pathlib.Path
实现指向文件。例如,对于GCS
import elements # pip install elements
import fidjax
weights = elements.Path('gs://bucket/fid/inception_v3_weights_fid.pickle')
reference = elements.Path('gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz')
fid = fidjax.FID(weights, reference)
自定义数据集
为自定义数据集生成参考统计数据
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax.FID(weights)
acts = fid.compute_acts(images)
mu, sigma = fid.compute_stats(acts)
np.savez('reference.npz', {'mu': mu, 'sigma': sigma})
资源
问题
请在Github上提交问题。