跳转到主要内容

重载NumPy函数

项目描述

overload_numpy's PyPI Status overload_numpy's Coverage Status overload_numpy's Documentation Status Codestyle Black pre-commit

overload_numpy 提供了用于处理NumPy的 __array_(u)func(tion)__ 的易于使用的工具。该库是完全类型化的,并且使用mypyc编译了wheel。

实现重载

首先,一些导入

>>> from dataclasses import dataclass, fields
>>> from typing import ClassVar
>>> import numpy as np
>>> from overload_numpy import NumPyOverloader, NPArrayOverloadMixin

现在我们可以定义一个 NumPyOverloader 实例

>>> W_FUNCS = NumPyOverloader()

重载适用于数组包装类。让我们定义一个

>>> @dataclass
... class Wrap1D(NPArrayOverloadMixin):
...     '''A simple array wrapper.'''
...     x: np.ndarray
...     NP_OVERLOADS: ClassVar[NumPyOverloader] = W_FUNCS
>>> w1d = Wrap1D(np.arange(3))

现在可以重载和注册给 Wrap1Dnumpy.ufunc(例如 numpy.add)和 numpy 函数(例如 numpy.concatenate)。

>>> @W_FUNCS.implements(np.add, Wrap1D)
... def add(w1, w2):
...     return Wrap1D(np.add(w1.x, w2.x))
>>> @W_FUNCS.implements(np.concatenate, Wrap1D)
... def concatenate(w1ds):
...     return Wrap1D(np.concatenate(tuple(w.x for w in w1ds)))

现在检查这些是否工作

>>> np.add(w1d, w1d)
Wrap1D(x=array([0, 2, 4]))
>>> np.concatenate((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))

ufunc 也有一些方法:“at”,“accumulate”等。NEP13中的函数分发机制指出,“如果输入或输出参数中的任何一个实现了 __array_ufunc__,则执行它而不是ufunc。”目前,重载的 numpy.add 对于任何 ufunc 方法都不起作用。

>>> try: np.add.accumulate(w1d)
... except Exception: print("failed")
failed

可以在包装的 add 实现上注册 ufunc 方法重载

>>> @add.register('accumulate')
... def add_accumulate(w1):
...     return Wrap1D(np.add.accumulate(w1.x))
>>> np.add.accumulate(w1d)
Wrap1D(x=array([0, 1, 3]))

为子类分发重载

如果我们定义一个Wrap1D的子类会怎么样呢?

>>> @dataclass
... class Wrap2D(Wrap1D):
...     '''A simple 2-array wrapper.'''
...     y: np.ndarray

Wrap1D上注册的numpy.concatenate的重载对于Wrap2D将不会正常工作。然而,NumPyOverloader支持对调用类型的单分派,因此可以针对子类自定义重载。

>>> @W_FUNCS.implements(np.add, Wrap2D)
... def add(w1, w2):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.add(w1.x, w2.x),
...                   np.add(w1.y, w2.y))
>>> @W_FUNCS.implements(np.concatenate, Wrap2D)
... def concatenate2(w2ds):
...     print("using Wrap2D implementation...")
...     return Wrap2D(np.concatenate(tuple(w.x for w in w2ds)),
...                   np.concatenate(tuple(w.y for w in w2ds)))

检查这些是否工作

>>> w2d = Wrap2D(np.arange(3), np.arange(3, 6))
>>> np.add(w2d, w2d)
using Wrap2D implementation...
Wrap2D(x=array([0, 2, 4]), y=array([ 6, 8, 10]))
>>> np.concatenate((w2d, w2d))
using Wrap2D implementation...
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))

太好了!但与其为每个子类定义新的实现,让我们看看如何编写一个更通用的重载。

>>> @W_FUNCS.implements(np.add, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.add, Wrap2D)  # overriding both
... def add_general(w1, w2):
...     WT = type(w1)
...     return WT(*(np.add(getattr(w1, f.name), getattr(w2, f.name))
...                 for f in fields(WT)))
>>> @W_FUNCS.implements(np.concatenate, Wrap1D)  # overriding both
... @W_FUNCS.implements(np.concatenate, Wrap2D)  # overriding both
... def concatenate_general(ws):
...     WT = type(ws[0])
...     return WT(*(np.concatenate(tuple(getattr(w, f.name) for w in ws))
...                 for f in fields(WT)))

检查这些是否工作

>>> np.add(w2d, w2d)
Wrap2D(x=array([0, 2, 4]), y=array([ 6, 8, 10]))
>>> np.concatenate((w2d, w2d))
Wrap2D(x=array([0, 1, 2, 0, 1, 2]), y=array([3, 4, 5, 3, 4, 5]))
>>> @dataclass
... class Wrap3D(Wrap2D):
...     '''A simple 3-array wrapper.'''
...     z: np.ndarray
>>> w3d = Wrap3D(np.arange(2), np.arange(3, 5), np.arange(6, 8))
>>> np.add(w3d, w3d)
Wrap3D(x=array([0, 2]), y=array([6, 8]), z=array([12, 14]))
>>> np.concatenate((w3d, w3d))
Wrap3D(x=array([0, 1, 0, 1]), y=array([3, 4, 3, 4]), z=array([6, 7, 6, 7]))

协助重载组

在前面的例子中,我们为单个NumPy函数编写了实现。以这种方式重载全部NumPy函数将花费很长时间。

如果我们能够根据NumPy函数组编写更少的代码,不是更好吗?

>>> add_funcs = {np.add, np.subtract}
>>> @W_FUNCS.assists(add_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def add_assists(cls, func, w1, w2, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), getattr(w2, f.name), *args, **kwargs)
...                     for f in fields(cls)))
>>> stack_funcs = {np.vstack, np.hstack, np.dstack, np.column_stack, np.row_stack}
>>> @W_FUNCS.assists(stack_funcs, types=Wrap1D, dispatch_on=Wrap1D)
... def stack_assists(cls, func, ws, *args, **kwargs):
...     return cls(*(func(tuple(getattr(v, f.name) for v in ws), *args, **kwargs)
...                     for f in fields(cls)))

检查这些是否工作

>>> np.subtract(w2d, w2d)
Wrap2D(x=array([0, 0, 0]), y=array([0, 0, 0]))
>>> np.vstack((w1d, w1d))
Wrap1D(x=array([[0, 1, 2],
                    [0, 1, 2]]))
>>> np.hstack((w1d, w1d))
Wrap1D(x=array([0, 1, 2, 0, 1, 2]))

我们还希望为所有的add_funcs重载实现accumulate方法。

>>> @add_assists.register("accumulate")
... def add_accumulate_assists(cls, func, w1, *args, **kwargs):
...     return cls(*(func(getattr(w1, f.name), *args, **kwargs)
...                  for f in fields(cls)))
>>> np.subtract.accumulate(w2d)
Wrap2D(x=array([ 0, -1, -3]), y=array([ 3, -1, -6]))

细节

想了解类型约束和API?请查看文档!

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

overload_numpy-0.1.0.tar.gz (52.3 kB 查看哈希值)

上传时间:

构建分布

overload_numpy-0.1.0-py3-none-any.whl (28.6 kB 查看哈希值)

上传时间: Python 3

overload_numpy-0.1.0-cp310-cp310-win_amd64.whl (126.4 kB 查看哈希值)

上传时间: CPython 3.10 Windows x86-64

overload_numpy-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (237.0 kB 查看哈希值)

上传时间: CPython 3.10 manylinux: glibc 2.17+ x86-64

overload_numpy-0.1.0-cp310-cp310-macosx_11_0_arm64.whl (136.2 kB 查看哈希值)

上传时间: CPython 3.10 macOS 11.0+ ARM64

overload_numpy-0.1.0-cp310-cp310-macosx_10_9_x86_64.whl (140.0 kB 查看哈希值)

上传时间: CPython 3.10 macOS 10.9+ x86-64

overload_numpy-0.1.0-cp310-cp310-macosx_10_9_universal2.whl (246.6 kB 查看哈希值)

上传于 CPython 3.10 macOS 10.9+ universal2 (ARM64, x86-64)

overload_numpy-0.1.0-cp39-cp39-win_amd64.whl (126.3 kB 查看哈希值)

上传于 CPython 3.9 Windows x86-64

overload_numpy-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (236.5 kB 查看哈希值)

上传于 CPython 3.9 manylinux: glibc 2.17+ x86-64

overload_numpy-0.1.0-cp39-cp39-macosx_11_0_arm64.whl (136.3 kB 查看哈希值)

上传于 CPython 3.9 macOS 11.0+ ARM64

overload_numpy-0.1.0-cp39-cp39-macosx_10_9_x86_64.whl (140.0 kB 查看哈希值)

上传于 CPython 3.9 macOS 10.9+ x86-64

overload_numpy-0.1.0-cp39-cp39-macosx_10_9_universal2.whl (246.6 kB 查看哈希值)

上传于 CPython 3.9 macOS 10.9+ universal2 (ARM64, x86-64)

overload_numpy-0.1.0-cp38-cp38-win_amd64.whl (126.1 kB 查看哈希值)

上传于 CPython 3.8 Windows x86-64

overload_numpy-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (234.4 kB 查看哈希值)

上传于 CPython 3.8 manylinux: glibc 2.17+ x86-64

overload_numpy-0.1.0-cp38-cp38-macosx_11_0_arm64.whl (134.8 kB 查看哈希值)

上传于 CPython 3.8 macOS 11.0+ ARM64

overload_numpy-0.1.0-cp38-cp38-macosx_10_9_x86_64.whl (138.2 kB 查看哈希值)

上传于 CPython 3.8 macOS 10.9+ x86-64

overload_numpy-0.1.0-cp38-cp38-macosx_10_9_universal2.whl (243.4 kB 查看哈希值)

上传于 CPython 3.8 macOS 10.9+ universal2 (ARM64, x86-64)

支持

AWSAWS云计算和安全赞助商DatadogDatadog监控FastlyFastlyCDNGoogleGoogle下载分析MicrosoftMicrosoftPSF赞助商PingdomPingdom监控SentrySentry错误日志StatusPageStatusPage状态页面