跳转到主要内容

Python函数重载

项目描述

Ovld

Python中的快速多播,具有许多额外功能。

📋 文档

使用ovld,您可以使用注解而不是编写笨拙的isinstance语句序列来为每个类型签名编写相同函数的版本。与Python的singledispatch不同,它可以用于多个参数。

示例

这是一个递归添加列表、元组和字典的函数

from ovld import ovld, recurse

@ovld
def add(x: list, y: list):
    return [recurse(a, b) for a, b in zip(x, y)]

@ovld
def add(x: tuple, y: tuple):
    return tuple(recurse(a, b) for a, b in zip(x, y))

@ovld
def add(x: dict, y: dict):
    return {k: recurse(v, y[k]) for k, v in x.items()}

@ovld
def add(x: object, y: object):
    return x + y

assert add([1, 2], [3, 4]) == [4, 6]

recurse 函数是特殊的:它将递归调用当前的ovld对象。您可能会问:它与简单地调用add有何不同?区别在于,如果您创建一个add的变体,recurse将自动调用该变体。

例如

变体

一个ovld的变体是它的一个副本,其中添加或更改了一些方法。例如,让我们以上面add的定义为例,并创建一个将数字相乘的变体

@add.variant
def mul(self, x: object, y: object):
    return x * y

assert mul([1, 2], [3, 4]) == [3, 8]

很简单!这意味着您可以定义一个递归遍历通用数据结构的ovld,然后以各种方式对其进行专门化。

优先级和call_next

可以为每个方法定义一个数值优先级(默认优先级为0)

from ovld import call_next

@ovld(priority=1000)
def f(x: int):
    return call_next(x + 1)

@ovld
def f(x: int):
    return x * x

assert f(10) == 121

上面两种定义具有相同的类型签名,但由于第一个具有更高的优先级,因此将被调用。

但这并不意味着无法调用第二个。实际上,当第一个函数调用特殊函数call_next(x + 1)时,它将调用列表中自己的下一个函数。

您在上面的模式中看到的模式是您如何将一些通用行为包装在每个调用中。例如,如果您这样做

@f.variant(priority=1000)
def f2(x: object)
    print(f"f({x!r})")
    return call_next(x)

您将有效地创建一个f的克隆,它会跟踪每个调用。

相关类型

相关类型是依赖于值的类型。ovld支持此功能,无论是通过Literal[value]还是通过Dependent[bound, check]。例如,下面是阶乘的定义

from typing import Literal
from ovld import ovld, recurse, Dependent

@ovld
def fact(n: Literal[0]):
    return 1

@ovld
def fact(n: Dependent[int, lambda n: n > 0]):
    return n * recurse(n - 1)

assert fact(5) == 120
fact(-1)   # Error!

Dependent的第一个参数必须是一个类型界限。在调用逻辑之前,界限必须匹配,这也确保了我们不会因为不相关的类型而受到性能影响。对于类型检查的目的,Dependent[T, A]Annotated[T, A]等价。

dependent_check

使用@dependent_check装饰器定义自己的类型

import torch
from ovld import ovld, dependent_check

@dependent_check
def Shape(tensor: torch.Tensor, *shape):
    return (
        len(tensor.shape) == len(shape)
        and all(s2 is Any or s1 == s2 for s1, s2 in zip(tensor.shape, shape))
    )

@dependent_check
def Dtype(tensor: torch.Tensor, dtype):
    return tensor.dtype == dtype

@ovld
def f(tensor: Shape[3, Any]):
    # Matches 3xN tensors
    ...

@ovld
def f(tensor: Shape[2, 2] & Dtype[torch.float32]):
    # Only matches 2x2 tensors that also have the float32 dtype
    ...

第一个参数是要检查的值。类型注解(例如上面的value: torch.Tensor)由ovld解释为该类型的界限,因此只有类型为torch.Tensor的参数才会调用Shape

方法

可以继承自OvldBase或使用OvldMC元类来使用方法的多重分发。

from ovld import OvldBase, OvldMC

# class Cat(OvldBase):  <= Also an option
class Cat(metaclass=OvldMC):
    def interact(self, x: Mouse):
        return "catch"

    def interact(self, x: Food):
        return "devour"

    def interact(self, x: PricelessVase):
        return "destroy"

子类

子类继承重载方法。它们可以为此方法定义额外的重载,这些重载只对子类有效,但它们需要使用@extend_super装饰器(这是出于清晰度的考虑)

from ovld import OvldMC, extend_super

class One(metaclass=OvldMC):
    def f(self, x: int):
        return "an integer"

class Two(One):
    @extend_super
    def f(self, x: str):
        return "a string"

assert Two().f(1) == "an integer"
assert Two().f("s") == "a string"

基准测试

ovld非常快:开销与isinstancematch相当,在调用Literal类型时只有2-3倍慢。与其他多重分发库相比,它的开销减少了1.5倍到100倍。

相对于最快实现的相对时间(1.00)(越低越好)。

基准测试 自定义 ovld plum multim multid runtype fastcore sd
trivial 1.45 1.00 3.32 4.63 2.04 2.41 51.93 1.91
multer 1.13 1.00 11.05 4.53 8.31 2.19 46.74 7.32
add 1.08 1.00 3.73 5.21 2.37 2.79 59.31 x
ast 1.00 1.08 23.14 3.09 1.68 1.91 28.39 1.66
calc 1.00 1.23 54.61 29.32 x x x x
regexp 1.00 1.87 19.18 x x x x x
fib 1.00 3.30 444.31 125.77 x x x x
tweaknum 1.00 2.09 x x x x x x

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分发

ovld-0.4.5.tar.gz (1.2 MB 查看哈希值

上传时间 源代码

构建分发

ovld-0.4.5-py3-none-any.whl (30.6 kB 查看哈希值)

上传时间: Python 3

由以下支持