跳转到主要内容

以样式查找和实例化类。

项目描述

Class Resolver

Tests Cookiecutter template from @cthoyt PyPI PyPI - Python Version PyPI - License Documentation Status Codecov status DOI Code style: black

以样式查找和实例化类。

💪 入门指南

from class_resolver import ClassResolver
from dataclasses import dataclass

class Base: pass

@dataclass
class A(Base):
   name: str

@dataclass
class B(Base):
   name: str

# Index
resolver = ClassResolver([A, B], base=Base)

# Lookup
assert A == resolver.lookup('A')

# Instantiate with a dictionary
assert A(name='hi') == resolver.make('A', {'name': 'hi'})

# Instantiate with kwargs
assert A(name='hi') == resolver.make('A', name='hi')

# A pre-instantiated class will simply be passed through
assert A(name='hi') == resolver.make(A(name='hi'))

🤖 使用 class-resolver 编写可扩展的机器学习模型

假设您已经在PyTorch中实现了一个简单的多层感知器

from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int]):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            )
            for in_features, out_features in pairwise(dims)
        ))

这个MLP使用硬编码的ReLU作为层之间的非线性激活函数。我们可以通过向其 __init__() 函数添加一个参数来将这个MLP泛化,使用多种非线性激活函数,例如

from itertools import chain

from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        if activation == "relu":
            activation = nn.ReLU()
        elif activation == "tanh":
            activation = nn.Tanh()
        elif activation == "hardtanh":
            activation = nn.Hardtanh()
        else:
            raise KeyError(f"Unsupported activation: {activation}")
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

这个实现的第一问题是它依赖于硬编码的条件语句集,因此难以扩展。可以通过使用字典查找来改进

from itertools import chain

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, nn.Module] = {
   "relu": nn.ReLU(),
   "tanh": nn.Tanh(),
   "hardtanh": nn.Hardtanh(),
}

class MLP(nn.Sequential):
    def __init__(self, dims: list[int], activation: str = "relu"):
        activation = activation_lookup[activation]
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

这种方法是刚性的,因为它需要预先实例化激活。如果我们需要改变nn.HardTanh类的参数,之前的方法就不再适用。我们可以将实现方式修改为在实例化之前查询类,然后可选地传递一些参数。

from itertools import chain

from more_itertools import pairwise
from torch import nn

activation_lookup: dict[str, type[nn.Module]] = {
   "relu": nn.ReLU,
   "tanh": nn.Tanh,
   "hardtanh": nn.Hardtanh,
}

class MLP(nn.Sequential):
    def __init__(
        self, 
        dims: list[int], 
        activation: str = "relu", 
        activation_kwargs: None | dict[str, any] = None,
    ):
        activation_cls = activation_lookup[activation]
        activation = activation_cls(**(activation_kwargs or {}))
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation,
            )
            for in_features, out_features in pairwise(dims)
        ))

这很好,但仍然存在一些问题

  1. 您必须手动维护activation_lookup字典
  2. 您不能通过activation关键字传递实例或类
  3. 您必须正确地处理大小写
  4. 默认值是硬编码为字符串,这意味着在创建MLP的任何地方都需要复制(易出错)
  5. 您必须为所有类重写此逻辑

现在介绍class_resolver包,它使用以下方法处理所有这些问题

from itertools import chain

from class_resolver import ClassResolver, Hint
from more_itertools import pairwise
from torch import nn

activation_resolver = ClassResolver(
    [nn.ReLU, nn.Tanh, nn.Hardtanh],
    base=nn.Module,
    default=nn.ReLU,
)

class MLP(nn.Sequential):
    def __init__(
        self, 
        dims: list[int], 
        activation: Hint[nn.Module] = None,  # Hint = Union[None, str, nn.Module, type[nn.Module]]
        activation_kwargs: None | dict[str, any] = None,
    ):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))

由于这是一个非常常见的模式,我们已经在class_resolver.contrib.torch中的贡献模块中提供了它

from itertools import chain

from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn

class MLP(nn.Sequential):
    def __init__(
        self, 
        dims: list[int], 
        activation: Hint[nn.Module] = None,
        activation_kwargs: None | dict[str, any] = None,
    ):
        super().__init__(chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation, activation_kwargs),
            )
            for in_features, out_features in pairwise(dims)
        ))

现在,您可以使用以下方式实例化MLP

MLP(dims=[10, 200, 40])  # uses default, which is ReLU
MLP(dims=[10, 200, 40], activation="relu")  # uses lowercase
MLP(dims=[10, 200, 40], activation="ReLU")  # uses stylized
MLP(dims=[10, 200, 40], activation=nn.ReLU)  # uses class
MLP(dims=[10, 200, 40], activation=nn.ReLU())  # uses instance

MLP(dims=[10, 200, 40], activation="hardtanh", activation_kwargs={"min_val": 0.0, "max_value": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.HardTanh, activation_kwargs={"min_val": 0.0, "max_value": 6.0})  # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.HardTanh(0.0, 6.0))  # uses instance

在实践中,最好使用字符串与超参数优化库(如Optuna)结合使用。

⬇️ 安装

可以从PyPI安装最新版本

$ pip install class_resolver

最新代码和数据可以直接从GitHub安装

$ pip install git+https://github.com/cthoyt/class-resolver.git

要在开发模式下安装,请使用以下命令

$ git clone git+https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ pip install -e .

🙏 贡献

欢迎提交问题、创建pull request或fork。有关参与的更多信息,请参阅CONTRIBUTING.rst

👋 授权

⚖️ 许可证

此包中的代码根据MIT许可证授权。

🍪 Cookiecutter

此包是用@audreyfeldroycookiecutter包和@cthoytcookiecutter-snekpack模板创建的。

🛠️ 对于开发者

请参阅开发者说明

README的最后一部分是如果您想通过代码贡献来参与其中。

❓ 测试

在克隆存储库并使用pip install tox安装tox后,可以使用以下命令可重复运行tests/文件夹中的单元测试

$ tox

此外,这些测试在每个提交时都会自动重新运行,在GitHub Action中。

📦 发布版本

在开发模式下安装包并使用pip install tox安装tox后,在tox.inifinish环境中包含制作新版本的命令。在shell中运行以下命令

$ tox -e finish

此脚本执行以下操作

  1. 使用BumpVersion切换setup.cfgsrc/{{cookiecutter.package_name}}/version.py中的版本号,去掉-dev后缀
  2. 将代码打包成tar存档和wheel
  3. 使用twine上传到PyPI。确保配置了.pypirc文件,以避免在此步骤需要手动输入
  4. 推送到GitHub。您需要创建一个带有版本提升提交的发布版本
  5. 将版本提升到下一个补丁。如果您进行了重大更改并希望通过次要版本提升版本,则可以在之后使用tox -e bumpversion minor

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定选择哪一个,请了解更多关于安装包的信息。

源分发

class_resolver-0.5.2.tar.gz (42.8 kB 查看散列)

上传时间

构建分发

class_resolver-0.5.2-py3-none-any.whl (28.8 kB 查看散列)

上传时间 Python 3

由以下提供支持