跳转到主要内容

PyTorch的可微分量化框架。

项目描述

通过伪量化噪声进行可微分模型压缩

linter badge tests badge cov badge

DiffQ使用伪量化噪声执行可微分量化。它可以自动调整每个权重或权重组使用的位数,以达到在模型大小和准确性之间的特定权衡。

阅读我们的论文以获取更多详细信息。

发生了什么?

查看变更日志以获取关于发布的详细信息。

  • 2022-08-24: v0.2.3: 修复了加载旧量化状态时的错误。
  • 2021-11-25: 版本 0.2.2: 增加了torchscript的支持。

需求

DiffQ需要Python 3.7,以及一个合理近期的PyTorch版本(理想情况下为1.7.1)。要安装DiffQ,您可以从存储库的根目录运行

pip install .

您也可以通过pip install diffq直接从PyPI安装。

用法

import torch
from torch.nn import functional as F
import diffq
from diffq import DiffQuantizer

model = MyModel()
optim = ...  # The optimizer must be created before the quantizer
quantizer = DiffQuantizer(model)
quantizer.setup_optimizer(optim)

# Distributed data parallel must be created after DiffQuantizer!
dmodel = torch.distributed.DistributedDataParallel(...)

penalty = 1e-3
model.train()  # call model.eval() on eval to automatically use true quantized weights.
for batch in loader:
    ...
    optim.zero_grad()

    # The `penalty` parameter here will control the tradeoff between model size and model accuracy.
    loss = F.mse_loss(x, y) + penalty * quantizer.model_size()
    optim.step()

# To get the true model size with when doing proper bit packing.
print(f"Model is {quantizer.true_model_size():.1f} MB")

# When you want to dump your final model:
torch.save(quantizer.get_quantized_state(), "some_file.th")

# You can later load back the model with
model = MyModel()
diffq.restore_quantized_state(model, torch.load("some_file.th"))

# For DiffQ models, we support exporting the model to Torscript with optimal storage.
# Once loaded, the model will be stored in fp32 in memory (int8 support coming up).
from diffq.ts_export import export
export(quantizer, 'quantized.ts')

文档

请参阅API文档以获取详细文档。以下将介绍一些方面。

量化器对象

量化器在其创建时附加到模型上。所有量化器对象都提供相同的基本功能

  • 如果模型处于评估模式,则自动切换到量化权重。
  • 在训练正向过程中提供特定的量化器代码(例如,对于使用QAT的UniformQuantizer的STE,DiffQ的噪声注入)。
  • 提供对量化模型大小和状态访问。

量化大小和状态

方法quantizer.model_size()提供可微模型大小(对于DiffQ),而quantizer.true_model_size()提供真实、最优的位打包模型大小(不可微)。使用quantizer.compressed_model_size(),您可以使用gzip获取模型大小。这实际上可能大于真实模型大小,并揭示了关于特定量化方法熵使用的有趣信息。

通过 quantizer.get_quantized_state() 获取位打包的量化状态,并通过 quantizer.restore_quantized_state() 恢复。位打包针对速度进行了优化,可能会产生一些开销(实际中均匀和LSQ不会超过120B,DiffQ不会超过1KB)。

如果您无法访问原始量化器,例如在推理时,可以使用 diffq.restore_quantized_state(model, quantized_state) 加载状态。

量化器和优化

某些量化器会添加额外的可优化参数(DiffQuantizer和LSQ)。这些参数可能需要与主模型权重不同的优化器或超参数。通常,DiffQ位参数始终使用Adam进行优化。因此,您应该始终在量化器之前创建主优化器。然后,您可以使用此优化器或其他优化器设置量化器。

model = MyModel(...)
opt = torch.optim.Adam(model.parameters())
quantizer = diffq.DiffQuantizer(model)
quantizer.setup_optimizer(opt, **optim_overrides)

这提供了使用单独超参数的自由。例如,DiffQuantizer 总是禁用位参数的 weight_decay。

如果主优化器是SGD,建议为量化器使用第二个Adam优化器。

注意:您必须在创建量化器后始终用 DistributedDataParallel 包装您的模型,否则量化器参数不会被优化!

TorchScript 支持

目前,TorchScript 支持是实验性的。我们支持将模型以最佳存储方式保存到磁盘上的TorchScript。一旦加载,模型将存储在内存中的FP32。我们正在努力添加对内存中int8的支持。请参阅API中的 diffq.ts_export.export 函数。

示例

我们在 examples/ 文件夹中提供了三个示例。一个用于CIFAR-10/100,使用标准架构,如Wide-ResNet、ResNet或MobileNet。第二个基于 DeiT 视觉变换器。第三个是在Wikitext-103上的语言建模任务,使用 Fairseq

DeiT和Fairseq示例作为原始代码库中特定提交的补丁提供。您可以初始化git子模块并运行以下命令应用补丁:

make examples

有关每个示例的更多详细信息,请查看它们的特定README。

开发安装

这将安装依赖项和开发者模式下的 diffq(文件更改将直接反映),以及运行单元测试的依赖项。

pip install -e '.[dev]'

更新基于补丁的示例

要更新补丁,首先运行 make examples 正确初始化子仓库。然后执行所有您想要的更改,提交它们并运行 make patches。这将更新每个仓库的补丁。一旦完成,并确认您所做的所有更改都已正确包含在新的补丁文件中,您可以在运行 git add -u .; git commit -m "my changes" 和推送之前运行 make reset(这将从子模块中删除您所做的所有更改,因此在调用此命令之前请检查补丁文件)。

测试

您可以使用以下命令运行单元测试:

make tests

引用

如果您在论文中使用此代码或结果,请引用我们的工作如下:

@article{defossez2021differentiable,
  title={Differentiable Model Compression via Pseudo Quantization Noise},
  author={D{\'e}fossez, Alexandre and Adi, Yossi and Synnaeve, Gabriel},
  journal={TMLR},
  year={2022}
}

许可

此存储库根据CC-BY-NC 4.0许可证发布,如LICENSE文件中所示,以下部分除外,该部分根据MIT许可证发布。文件 examples/cifar/src/mobilenet.pyexamples/cifar/src/resnet.py 来自 kuangliu/pytorch-cifar,以MIT许可证发布。文件 examples/cifar/src/wide_resnet.py 来自 meliketoy/wide-resnet,以MIT许可证发布。请参阅每个文件的头部以获取详细许可证信息。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码分发

diffq-0.2.4.tar.gz (157.1 kB 查看哈希值)

上传时间 源代码

构建分发

diffq-0.2.4-cp310-cp310-win_amd64.whl (91.8 kB 查看哈希值)

上传时间 CPython 3.10 Windows x86-64

diffq-0.2.4-cp310-cp310-win32.whl (82.2 kB 查看哈希值)

上传时间 CPython 3.10 Windows x86

diffq-0.2.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (418.8 kB 查看哈希值)

上传时间 CPython 3.10 manylinux: glibc 2.12+ x86-64 manylinux: glibc 2.5+ x86-64

diffq-0.2.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl (401.1 kB 查看哈希值)

上传时间 CPython 3.10 manylinux: glibc 2.12+ i686 manylinux: glibc 2.5+ i686

diffq-0.2.4-cp310-cp310-macosx_11_0_arm64.whl (97.4 kB 查看哈希值)

上传时间 CPython 3.10 macOS 11.0+ ARM64

diffq-0.2.4-cp310-cp310-macosx_10_9_x86_64.whl (106.2 kB 查看哈希值)

上传时间 CPython 3.10 macOS 10.9+ x86-64

diffq-0.2.4-cp310-cp310-macosx_10_9_universal2.whl (175.8 kB 查看哈希值)

上传时间 CPython 3.10 macOS 10.9+ universal2 (ARM64, x86-64)

diffq-0.2.4-cp39-cp39-win_amd64.whl (93.2 kB 查看哈希值)

上传于 CPython 3.9 Windows x86-64

diffq-0.2.4-cp39-cp39-win32.whl (83.5 kB 查看哈希值)

上传于 CPython 3.9 Windows x86

diffq-0.2.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (425.5 kB 查看哈希值)

上传于 CPython 3.9 manylinux: glibc 2.12+ x86-64 manylinux: glibc 2.5+ x86-64

diffq-0.2.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl (409.1 kB 查看哈希值)

上传于 CPython 3.9 manylinux: glibc 2.12+ i686 manylinux: glibc 2.5+ i686

diffq-0.2.4-cp39-cp39-macosx_11_0_arm64.whl (96.7 kB 查看哈希值)

上传于 CPython 3.9 macOS 11.0+ ARM64

diffq-0.2.4-cp39-cp39-macosx_10_9_x86_64.whl (106.0 kB 查看哈希值)

上传于 CPython 3.9 macOS 10.9+ x86-64

diffq-0.2.4-cp39-cp39-macosx_10_9_universal2.whl (174.9 kB 查看哈希值)

上传于 CPython 3.9 macOS 10.9+ universal2 (ARM64, x86-64)

diffq-0.2.4-cp38-cp38-win_amd64.whl (93.1 kB 查看哈希值)

上传于 CPython 3.8 Windows x86-64

diffq-0.2.4-cp38-cp38-win32.whl (83.4 kB 查看哈希值)

上传于 CPython 3.8 Windows x86

diffq-0.2.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (446.6 kB 查看哈希值)

上传于 CPython 3.8 manylinux: glibc 2.12+ x86-64 manylinux: glibc 2.5+ x86-64

diffq-0.2.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl (428.1 kB 查看哈希值)

上传于 CPython 3.8 manylinux: glibc 2.12+ i686 manylinux: glibc 2.5+ i686

diffq-0.2.4-cp38-cp38-macosx_11_0_arm64.whl (95.1 kB 查看哈希值)

上传于 CPython 3.8 macOS 11.0+ ARM64

diffq-0.2.4-cp38-cp38-macosx_10_9_x86_64.whl (103.8 kB 查看哈希值)

上传于 CPython 3.8 macOS 10.9+ x86-64

diffq-0.2.4-cp38-cp38-macosx_10_9_universal2.whl (171.1 kB 查看哈希值)

上传于 CPython 3.8 macOS 10.9+ universal2 (ARM64, x86-64)

diffq-0.2.4-cp37-cp37m-win_amd64.whl (91.9 kB 查看哈希值)

上传于 CPython 3.7m Windows x86-64

diffq-0.2.4-cp37-cp37m-win32.whl (81.9 kB 查看哈希值)

上传于 CPython 3.7m Windows x86

diffq-0.2.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (393.8 kB 查看哈希值)

上传于 CPython 3.7m manylinux: glibc 2.12+ x86-64 manylinux: glibc 2.5+ x86-64

diffq-0.2.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl (376.0 kB 查看哈希值)

上传于 CPython 3.7m manylinux: glibc 2.12+ i686 manylinux: glibc 2.5+ i686

diffq-0.2.4-cp37-cp37m-macosx_10_9_x86_64.whl (104.5 kB 查看哈希值)

上传于 CPython 3.7m macOS 10.9+ x86-64

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面