Python函数重载
项目描述
Ovld
Python中的快速多播,具有许多额外功能。
使用ovld,您可以使用注解而不是编写笨拙的isinstance
语句序列来为每个类型签名编写相同函数的版本。与Python的singledispatch
不同,它可以用于多个参数。
- ⚡️ 快速:
ovld
是周围最快的多播库,有相当大的差距。 - 🚀 变体 和 混合 函数和方法。
- 🦄 依赖类型: 重载函数可以依赖于比参数类型更多的东西:它们可以依赖于实际值。
- 🔑 广泛: 在函数、方法、位置参数甚至在某些限制下关键字参数上分发。
示例
这是一个递归添加列表、元组和字典的函数
from ovld import ovld, recurse
@ovld
def add(x: list, y: list):
return [recurse(a, b) for a, b in zip(x, y)]
@ovld
def add(x: tuple, y: tuple):
return tuple(recurse(a, b) for a, b in zip(x, y))
@ovld
def add(x: dict, y: dict):
return {k: recurse(v, y[k]) for k, v in x.items()}
@ovld
def add(x: object, y: object):
return x + y
assert add([1, 2], [3, 4]) == [4, 6]
recurse
函数是特殊的:它将递归调用当前的ovld对象。您可能会问:它与简单地调用add
有何不同?区别在于,如果您创建一个add
的变体,recurse
将自动调用该变体。
例如
变体
一个ovld
的变体是它的一个副本,其中添加或更改了一些方法。例如,让我们以上面add
的定义为例,并创建一个将数字相乘的变体
@add.variant
def mul(self, x: object, y: object):
return x * y
assert mul([1, 2], [3, 4]) == [3, 8]
很简单!这意味着您可以定义一个递归遍历通用数据结构的ovld
,然后以各种方式对其进行专门化。
优先级和call_next
可以为每个方法定义一个数值优先级(默认优先级为0)
from ovld import call_next
@ovld(priority=1000)
def f(x: int):
return call_next(x + 1)
@ovld
def f(x: int):
return x * x
assert f(10) == 121
上面两种定义具有相同的类型签名,但由于第一个具有更高的优先级,因此将被调用。
但这并不意味着无法调用第二个。实际上,当第一个函数调用特殊函数call_next(x + 1)
时,它将调用列表中自己的下一个函数。
您在上面的模式中看到的模式是您如何将一些通用行为包装在每个调用中。例如,如果您这样做
@f.variant(priority=1000)
def f2(x: object)
print(f"f({x!r})")
return call_next(x)
您将有效地创建一个f
的克隆,它会跟踪每个调用。
相关类型
相关类型是依赖于值的类型。ovld
支持此功能,无论是通过Literal[value]
还是通过Dependent[bound, check]
。例如,下面是阶乘的定义
from typing import Literal
from ovld import ovld, recurse, Dependent
@ovld
def fact(n: Literal[0]):
return 1
@ovld
def fact(n: Dependent[int, lambda n: n > 0]):
return n * recurse(n - 1)
assert fact(5) == 120
fact(-1) # Error!
Dependent
的第一个参数必须是一个类型界限。在调用逻辑之前,界限必须匹配,这也确保了我们不会因为不相关的类型而受到性能影响。对于类型检查的目的,Dependent[T, A]
与Annotated[T, A]
等价。
dependent_check
使用@dependent_check
装饰器定义自己的类型
import torch
from ovld import ovld, dependent_check
@dependent_check
def Shape(tensor: torch.Tensor, *shape):
return (
len(tensor.shape) == len(shape)
and all(s2 is Any or s1 == s2 for s1, s2 in zip(tensor.shape, shape))
)
@dependent_check
def Dtype(tensor: torch.Tensor, dtype):
return tensor.dtype == dtype
@ovld
def f(tensor: Shape[3, Any]):
# Matches 3xN tensors
...
@ovld
def f(tensor: Shape[2, 2] & Dtype[torch.float32]):
# Only matches 2x2 tensors that also have the float32 dtype
...
第一个参数是要检查的值。类型注解(例如上面的value: torch.Tensor
)由ovld
解释为该类型的界限,因此只有类型为torch.Tensor
的参数才会调用Shape
。
方法
可以继承自OvldBase
或使用OvldMC
元类来使用方法的多重分发。
from ovld import OvldBase, OvldMC
# class Cat(OvldBase): <= Also an option
class Cat(metaclass=OvldMC):
def interact(self, x: Mouse):
return "catch"
def interact(self, x: Food):
return "devour"
def interact(self, x: PricelessVase):
return "destroy"
子类
子类继承重载方法。它们可以为此方法定义额外的重载,这些重载只对子类有效,但它们需要使用@extend_super
装饰器(这是出于清晰度的考虑)
from ovld import OvldMC, extend_super
class One(metaclass=OvldMC):
def f(self, x: int):
return "an integer"
class Two(One):
@extend_super
def f(self, x: str):
return "a string"
assert Two().f(1) == "an integer"
assert Two().f("s") == "a string"
基准测试
ovld
非常快:开销与isinstance
或match
相当,在调用Literal
类型时只有2-3倍慢。与其他多重分发库相比,它的开销减少了1.5倍到100倍。
相对于最快实现的相对时间(1.00)(越低越好)。
基准测试 | 自定义 | ovld | plum | multim | multid | runtype | fastcore | sd |
---|---|---|---|---|---|---|---|---|
trivial | 1.45 | 1.00 | 3.32 | 4.63 | 2.04 | 2.41 | 51.93 | 1.91 |
multer | 1.13 | 1.00 | 11.05 | 4.53 | 8.31 | 2.19 | 46.74 | 7.32 |
add | 1.08 | 1.00 | 3.73 | 5.21 | 2.37 | 2.79 | 59.31 | x |
ast | 1.00 | 1.08 | 23.14 | 3.09 | 1.68 | 1.91 | 28.39 | 1.66 |
calc | 1.00 | 1.23 | 54.61 | 29.32 | x | x | x | x |
regexp | 1.00 | 1.87 | 19.18 | x | x | x | x | x |
fib | 1.00 | 3.30 | 444.31 | 125.77 | x | x | x | x |
tweaknum | 1.00 | 2.09 | x | x | x | x | x | x |
项目详情
下载文件
下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源代码分发
构建分发
ovld-0.4.5.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | c0f0393fc1bf5bbf55530a6bedbdae2a22d854590b55954a5a9b8e6292442b0d |
|
MD5 | 76d617c1f2d02a65ccb35f3c34f8f531 |
|
BLAKE2b-256 | 396be9c11f83471750e711130ff55f532e3b99894a85d4092b452184010ffb34 |
ovld-0.4.5-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | fe603d9f1f5eca110aa6a821ef2d12528be380a15969fc666f8cacabc473da61 |
|
MD5 | 36ea95ffaa2799ff32a648d3ef833a5e |
|
BLAKE2b-256 | 797273116206447b302447e1adf3741ab4acb2c30eddca550cca1b0a1d3362fe |