python – Numba中的生成器参数

这个问题的后续跟进:
function types in numba.

我正在编写一个需要将生成器作为其参数之一的函数.粘贴在这里太复杂了,所以请考虑这个玩具示例:

def take_and_sum(gen):
    @numba.jit(nopython=False)
    def inner(n):
        s = 0
        for _ in range(n):
            s += next(gen)
        return s
    return inner

它返回生成器的前n个元素的总和.用法示例:

@numba.njit()
def odd_numbers():
    n = 1
    while True:
        yield n
        n += 2

take_and_sum(odd_numbers())(3) # prints 9

这是因为我想用nopython = True编译然后我不能将gen(一个pyobject)作为参数传递.不幸的是,使用nopython = True我收到一个错误:

TypingError: Failed at nopython (nopython frontend)
Untyped global name 'gen'

即使我nopython编译我的发电机.

令人困惑的是输入的硬编码:

def take_and_sum():
    @numba.njit()
    def inner(n):
        gen = odd_numbers()
        s = 0.0
        for _ in range(n):
            s += next(gen)
        return s
    return inner

take_and_sum()(3)

我也尝试将我的发电机变成一个类:

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1
    def next(self):
        n = self.n
        self.n += 2
        return n

同样,这在对象模式下工作,但在nopython模式下,我得到了不可搜索的:

LoweringError: Failed at nopython (nopython mode backend)
Internal error:
NotImplementedError: instance.jitclass.Odd#4aa9758<n:uint64> as constant unsupported

最佳答案 我实际上无法解决你的问题,因为据我所知,这是不可能的.我只是突出了一些方面(对于numba 0.30有效):

你不能创建一个numba-jitclass生成器:

import numba

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def __iter__(self):
        return self

    def __next__(self):
        n = self.n
        self.n += 2
        return n

试一试:

>>> next(Odd())
TypeError: 'Odd' object is not an iterator

删除numba.jitclass时,它的工作原理如下:

>>> next(Odd())
1

使用硬编码生成器的示例不相同.您的原始尝试创建一个生成器对象将其传递给numba函数并修改生成器.您可能希望它更新生成器的状态.

>>> t = odd_numbers()
>>> take_and_sum(t)(3)
9
>>> next(t)   # State has been updated, unfortunatly that requires nopython=False!
7

但对于numba来说,这是不可能的.

第二个示例是不同的,因为每次调用函数时都会创建生成器,因此函数外部没有需要更新的状态:

>>> take_and_sum()(3) # using your hardcoded version
9.0
>>> take_and_sum()(3) # no updated state so this returns the same:
9.0

它绝对有可能改变它,但没有选择使用仲裁功能:

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def calculate(self, n):
        s = 0.0
        for _ in range(n):
            s += self.n
            self.n += 2
        return s

>>> x = Odd()
>>> x.calculate(3)
9.0
>>> x.calculate(3)
27.0

我知道这不是你想要的,但至少它在某种程度上有效:-)

点赞