以样式查找和实例化类。
项目描述
Class Resolver
以样式查找和实例化类。
💪 入门指南
from class_resolver import ClassResolver
from dataclasses import dataclass
class Base: pass
@dataclass
class A(Base):
name: str
@dataclass
class B(Base):
name: str
# Index
resolver = ClassResolver([A, B], base=Base)
# Lookup
assert A == resolver.lookup('A')
# Instantiate with a dictionary
assert A(name='hi') == resolver.make('A', {'name': 'hi'})
# Instantiate with kwargs
assert A(name='hi') == resolver.make('A', name='hi')
# A pre-instantiated class will simply be passed through
assert A(name='hi') == resolver.make(A(name='hi'))
🤖 使用 class-resolver
编写可扩展的机器学习模型
假设您已经在PyTorch中实现了一个简单的多层感知器
from itertools import chain
from more_itertools import pairwise
from torch import nn
class MLP(nn.Sequential):
def __init__(self, dims: list[int]):
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
nn.ReLU(),
)
for in_features, out_features in pairwise(dims)
))
这个MLP使用硬编码的ReLU作为层之间的非线性激活函数。我们可以通过向其 __init__()
函数添加一个参数来将这个MLP泛化,使用多种非线性激活函数,例如
from itertools import chain
from more_itertools import pairwise
from torch import nn
class MLP(nn.Sequential):
def __init__(self, dims: list[int], activation: str = "relu"):
if activation == "relu":
activation = nn.ReLU()
elif activation == "tanh":
activation = nn.Tanh()
elif activation == "hardtanh":
activation = nn.Hardtanh()
else:
raise KeyError(f"Unsupported activation: {activation}")
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
这个实现的第一问题是它依赖于硬编码的条件语句集,因此难以扩展。可以通过使用字典查找来改进
from itertools import chain
from more_itertools import pairwise
from torch import nn
activation_lookup: dict[str, nn.Module] = {
"relu": nn.ReLU(),
"tanh": nn.Tanh(),
"hardtanh": nn.Hardtanh(),
}
class MLP(nn.Sequential):
def __init__(self, dims: list[int], activation: str = "relu"):
activation = activation_lookup[activation]
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
这种方法是刚性的,因为它需要预先实例化激活。如果我们需要改变nn.HardTanh
类的参数,之前的方法就不再适用。我们可以将实现方式修改为在实例化之前查询类,然后可选地传递一些参数。
from itertools import chain
from more_itertools import pairwise
from torch import nn
activation_lookup: dict[str, type[nn.Module]] = {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"hardtanh": nn.Hardtanh,
}
class MLP(nn.Sequential):
def __init__(
self,
dims: list[int],
activation: str = "relu",
activation_kwargs: None | dict[str, any] = None,
):
activation_cls = activation_lookup[activation]
activation = activation_cls(**(activation_kwargs or {}))
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation,
)
for in_features, out_features in pairwise(dims)
))
这很好,但仍然存在一些问题
- 您必须手动维护
activation_lookup
字典 - 您不能通过
activation
关键字传递实例或类 - 您必须正确地处理大小写
- 默认值是硬编码为字符串,这意味着在创建MLP的任何地方都需要复制(易出错)
- 您必须为所有类重写此逻辑
现在介绍class_resolver
包,它使用以下方法处理所有这些问题
from itertools import chain
from class_resolver import ClassResolver, Hint
from more_itertools import pairwise
from torch import nn
activation_resolver = ClassResolver(
[nn.ReLU, nn.Tanh, nn.Hardtanh],
base=nn.Module,
default=nn.ReLU,
)
class MLP(nn.Sequential):
def __init__(
self,
dims: list[int],
activation: Hint[nn.Module] = None, # Hint = Union[None, str, nn.Module, type[nn.Module]]
activation_kwargs: None | dict[str, any] = None,
):
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation_resolver.make(activation, activation_kwargs),
)
for in_features, out_features in pairwise(dims)
))
由于这是一个非常常见的模式,我们已经在class_resolver.contrib.torch
中的贡献模块中提供了它
from itertools import chain
from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn
class MLP(nn.Sequential):
def __init__(
self,
dims: list[int],
activation: Hint[nn.Module] = None,
activation_kwargs: None | dict[str, any] = None,
):
super().__init__(chain.from_iterable(
(
nn.Linear(in_features, out_features),
activation_resolver.make(activation, activation_kwargs),
)
for in_features, out_features in pairwise(dims)
))
现在,您可以使用以下方式实例化MLP
MLP(dims=[10, 200, 40]) # uses default, which is ReLU
MLP(dims=[10, 200, 40], activation="relu") # uses lowercase
MLP(dims=[10, 200, 40], activation="ReLU") # uses stylized
MLP(dims=[10, 200, 40], activation=nn.ReLU) # uses class
MLP(dims=[10, 200, 40], activation=nn.ReLU()) # uses instance
MLP(dims=[10, 200, 40], activation="hardtanh", activation_kwargs={"min_val": 0.0, "max_value": 6.0}) # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.HardTanh, activation_kwargs={"min_val": 0.0, "max_value": 6.0}) # uses kwargs
MLP(dims=[10, 200, 40], activation=nn.HardTanh(0.0, 6.0)) # uses instance
在实践中,最好使用字符串与超参数优化库(如Optuna)结合使用。
⬇️ 安装
可以从PyPI安装最新版本
$ pip install class_resolver
最新代码和数据可以直接从GitHub安装
$ pip install git+https://github.com/cthoyt/class-resolver.git
要在开发模式下安装,请使用以下命令
$ git clone git+https://github.com/cthoyt/class-resolver.git
$ cd class-resolver
$ pip install -e .
🙏 贡献
欢迎提交问题、创建pull request或fork。有关参与的更多信息,请参阅CONTRIBUTING.rst。
👋 授权
⚖️ 许可证
此包中的代码根据MIT许可证授权。
🍪 Cookiecutter
此包是用@audreyfeldroy的cookiecutter包和@cthoyt的cookiecutter-snekpack模板创建的。
🛠️ 对于开发者
请参阅开发者说明
README的最后一部分是如果您想通过代码贡献来参与其中。
❓ 测试
在克隆存储库并使用pip install tox
安装tox
后,可以使用以下命令可重复运行tests/
文件夹中的单元测试
$ tox
此外,这些测试在每个提交时都会自动重新运行,在GitHub Action中。
📦 发布版本
在开发模式下安装包并使用pip install tox
安装tox
后,在tox.ini
的finish
环境中包含制作新版本的命令。在shell中运行以下命令
$ tox -e finish
此脚本执行以下操作
- 使用BumpVersion切换
setup.cfg
和src/{{cookiecutter.package_name}}/version.py
中的版本号,去掉-dev
后缀 - 将代码打包成tar存档和wheel
- 使用
twine
上传到PyPI。确保配置了.pypirc
文件,以避免在此步骤需要手动输入 - 推送到GitHub。您需要创建一个带有版本提升提交的发布版本
- 将版本提升到下一个补丁。如果您进行了重大更改并希望通过次要版本提升版本,则可以在之后使用
tox -e bumpversion minor
项目详情
下载文件
下载适用于您的平台的文件。如果您不确定选择哪一个,请了解更多关于安装包的信息。