跳转到主要内容

DeepMind的BigGAN模型的PyTorch版本,带有预训练模型

项目描述

PyTorch预训练BigGAN

DeepMind的BigGAN模型,使用DeepMind的预训练权重在PyTorch上的逐对实现

简介

此存储库包含DeepMind的BigGAN的逐对PyTorch实现,与Andrew Brocky、Jeff Donahuey和Karen Simonyan发表的论文Large Scale GAN Training for High Fidelity Natural Image Synthesis一起发布。

此BigGAN的PyTorch实现提供了DeepMind的预训练的128x128、256x256和512x512模型。我们还提供了用于下载和转换这些模型的TensorFlow Hub模型的脚本。

此实现是从TensorFlow版本的原始计算图中完成的,并且与TensorFlow版本的行为相似(输出差异的方差为1e5的量级)。

安装

此存储库已在Python 3.6和PyTorch 1.0.1上进行了测试

可以使用以下pip命令安装PyTorch预训练BigGAN

pip install pytorch-pretrained-biggan

如果您只想玩GAN,这应该足够了。

如果您想使用转换脚本和ImageNet实用工具,则需要额外的需求,特别是TensorFlow和NLTK。要安装所有需求,请使用full_requirements.txt文件

git clone https://github.com/huggingface/pytorch-pretrained-BigGAN.git
cd pytorch-pretrained-BigGAN
pip install -r full_requirements.txt

模型

本存储库提供直接且简单的访问到 BigGAN 的预训练“深层”版本,包括 128、256 和 512 像素分辨率的模型,如相关出版物所述。以下是关于这些模型的详细信息

  • BigGAN-deep-128:一个具有 50.4M 个参数的模型,生成 128x128 像素的图像,模型权重文件大小 201 MB,
  • BigGAN-deep-256:一个具有 55.9M 个参数的模型,生成 256x256 像素的图像,模型权重文件大小 224 MB,
  • BigGAN-deep-512:一个具有 56.2M 个参数的模型,生成 512x512 像素的图像,模型权重文件大小 225 MB。

有关架构的详细信息,请参阅论文附录 B。

所有模型都包含用于 0 和 1 之间 51 个截断值的前计算批量归一化统计数据(有关详细信息,请参阅论文附录 C.1)。

使用方法

以下是一个使用预训练模型快速入门 BigGAN 的示例。

有关这些类和方法的详细信息,请参阅下面的文档部分

import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# Load pre-trained model tokenizer (vocabulary)
model = BigGAN.from_pretrained('biggan-deep-256')

# Prepare a input
truncation = 0.4
class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)

# All in tensors
noise_vector = torch.from_numpy(noise_vector)
class_vector = torch.from_numpy(class_vector)

# Generate an image
output = model(noise_vector, class_vector, truncation)

# If you have a sixtel compatible terminal you can display the images in the terminal
# (see https://github.com/saitoha/libsixel for details)
display_in_terminal(output)

# Save results as png images
save_as_images(output)

output_0 output_1 output_2

文档

加载 DeepMind 的预训练权重

要加载 DeepMind 的预训练模型之一,使用 from_pretrained() 实例化一个 BigGAN 模型

model = BigGAN.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)

其中

  • PRE_TRAINED_MODEL_NAME_OR_PATH 是以下之一

    • Google AI 或 OpenAI 列表中选择的一个预训练模型的快捷名称

      • biggan-deep-128:12 层,768 隐藏层,12 个头部,110M 参数
      • biggan-deep-256:24 层,1024 隐藏层,16 个头部,340M 参数
      • biggan-deep-512:12 层,768 隐藏层,12 个头部,110M 参数
    • 预训练模型存档的路径或 URL,其中包含

      • config.json:模型的配置文件,以及
      • pytorch_model.bin:保存为 BigGAN 的预训练实例的 PyTorch 权重存档(使用常用的 torch.save() 保存)。

    如果 PRE_TRAINED_MODEL_NAME_OR_PATH 是快捷名称,则预训练权重将从 AWS S3 下载(请参阅此处的链接),并存储在缓存文件夹中以避免未来的下载(缓存文件夹位于 ~/.pytorch_pretrained_biggan/)。

  • cache_dir 可以是预训练模型权重下载和缓存的特定目录的路径。

配置

BigGANConfig 是一个用于存储和加载 BigGAN 配置的类。它在 config.py 中定义。

以下是关于属性的一些详细信息

  • output_dim:预训练模型的 GAN 输出分辨率(128、256 或 512),
  • z_dim:噪声向量的大小(预训练模型为 128)。
  • class_embed_dim:类别嵌入向量的大小(预训练模型为 128)。
  • channel_width:每个通道的大小(预训练模型为 128)。
  • num_classes:训练数据集中的类别数量,例如 ImageNet(预训练模型为 1000)。
  • layers:一个层定义列表。每个层的定义是一个包含 [层中上采样是否进行?(布尔值),输入通道数(整型),输出通道数(整型)] 的三元组。
  • attention_layer_position:自注意力层在层层次结构中的位置(预训练模型为 8)。
  • eps:用于频谱和批量归一化层的 epsilon 值(预训练模型为 1e-4)。
  • n_stats:与 0 和 1 之间各种截断值相关联的批量归一化层预计算统计数据数量(预训练模型为 51)。

模型

BigGAN 是一个定义在 model.py 中的 BigGAN PyTorch 模型(torch.nn.Module)。该模型包含类别嵌入(一个线性层)以及一系列卷积和条件批量归一化。目前尚未实现判别器,因为尚未发布其预训练权重。

输入和输出与 TensorFlow 模型的输入和输出相同。

我们在此详细说明。

BigGAN 将以下作为 输入

  • z:一个形状为 [batch_size, config.z_dim] 的 torch.FloatTensor,其中噪声是从截断正态分布中采样的,并且
  • class_label:一个可选的 torch.LongTensor,形状为 [batch_size, sequence_length],包含在 [0, 1] 中选择的标记类型索引。类型 0 对应于 句子 A,类型 1 对应于 句子 B 标记(更多细节请参阅 BERT 论文)。
  • truncation:一个介于 0(不包括)和 1 之间的浮点数。这是用于创建噪声向量的截断正态分布的截断值。此截断值用于在预计算的统计信息(均值和方差)集合之间进行选择,这些统计信息用于批量归一化层。

BigGAN 输出 一个形状为 [batch_size, 3, resolution, resolution] 的数组,其中 resolution 为 128、256 或 512,具体取决于模型

实用工具:图像、噪声、ImageNet 类别

我们提供了一些实用方法来使用该模型。它们在 utils.py 中定义。

以下是这些方法的详细信息

  • truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):

    创建一个截断噪声向量。

    • 参数
      • batch_size: 批量大小。
      • dim_z: z 的维度
      • truncation: 要使用的截断值
      • seed: 随机生成器的种子
    • 输出:形状为 (batch_size, dim_z) 的数组
  • convert_to_images(obj):

    将 BigGAN 的输出张量转换为图像列表。

    • 参数
      • obj: 形状为 (batch_size, channels, height, width) 的张量或 numpy 数组
    • 输出
      • 大小为 (height, width) 的 Pillow 图像列表
  • save_as_images(obj, file_name='output'):

    将 BigGAN 的输出张量转换为并保存为图像列表。

    • 参数
      • obj: 形状为 (batch_size, channels, height, width) 的张量或 numpy 数组
      • file_name: 要保存的路径和文件名开头。图像将保存为 file_name_{image_number}.png
  • display_in_terminal(obj):

    将 BigGAN 的输出张量转换为并在终端中显示。此函数使用 libsixel,并且仅在兼容 libsixel 的终端中工作。有关更多详细信息,请参阅 https://github.com/saitoha/libsixel

    • 参数
      • obj: 形状为 (batch_size, channels, height, width) 的张量或 numpy 数组
      • file_name: 要保存的路径和文件名开头。图像将保存为 file_name_{image_number}.png
  • one_hot_from_int(int_or_list, batch_size=1):

    从一个类别索引或类别索引列表创建一个 one-hot 向量。

    • 参数
      • int_or_list: int 或 int 列表,表示 ImageNet 类别(介于 0 和 999 之间)
      • batch_size: 批量大小。
        • 如果 int_or_list 是 int,则创建一个相同类别的批次。
        • 如果 int_or_list 是列表,则应有 len(int_or_list) == batch_size
    • 输出
      • 形状为 (batch_size, 1000) 的数组
  • one_hot_from_names(class_name, batch_size=1):

    从 ImageNet 类别的名称(如 'tennis ball', 'daisy' 等)创建一个 one-hot 向量。我们使用 NLTK 的 wordnet 搜索尝试找到相关 synset 并取第一个。如果我们不能直接找到,我们将查看类名的下位词和上位词。

    • 参数
      • class_name: 包含 ImageNet 对象名称的字符串。
    • 输出
      • 形状为 (batch_size, 1000) 的数组

下载和转换脚本

./scripts 中提供了用于从 TensorFlow Hub 下载和转换 TensorFlow 模型的脚本。

脚本可以直接使用

./scripts/download_tf_hub_models.sh
./scripts/convert_tf_hub_models.sh

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关 安装软件包 的更多信息。

源分布

pytorch_pretrained_biggan-0.1.1.tar.gz (28.9 kB 查看哈希值)

上传时间

构建分布

pytorch_pretrained_biggan-0.1.1-py3-none-any.whl (27.2 kB 查看哈希值)

上传于 Python 3

由以下提供支持