将配置变量传递给函数,以便它们表现为编译时常量

问题描述 投票:0回答:1

在 numba 中,我想将配置变量作为编译时常量传递给函数。具体来说我想做的是

    @njit
    def physics(config):
        flagA = config.flagA
        flagB = config.flagB
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

如果配置变量是编译时常量,这就会通过,但事实并非如此,所以它给我错误,说有两个候选者

There are 2 candidate implementations:
                 - Of which 2 did not match due to:
                 ...
                 ...

我查看了一份 numba 会议纪要,发现有一种方法可以做到这一点Numba 会议:2024-03-05 我尝试过,但仍然出现同样的错误。

这是带有错误消息的代码:

.. code:: ipython3

    from numba import jit, types, njit
    from numba.extending import overload
    from numba.typed import List
    import functools

.. code:: ipython3

    class Config():
        def __init__(self, flagA, flagB):
            self._flagA = flagA
            self._flagB = flagB
    
        @property
        def flagA(self):
            return self._flagA
    
        @property
        def flagB(self):
            return self._flagB

.. code:: ipython3

    @functools.cache
    def obj2strkeydict(obj, config_name):
    
        # unpack object to freevars and close over them
        tmp_a = obj.flagA
        tmp_b = obj.flagB
        assert isinstance(config_name, str)
        tmp_force_heterogeneous = config_name
    
        @njit
        def configurator():
            d = {'flagA': tmp_a,
                 'flagB': tmp_b,
                 'config_name': tmp_force_heterogeneous}
            return d
    
        # return a configuration function that returns a string-key-dict
        # representation of the configuration object.
        return configurator

.. code:: ipython3

    @njit
    def physics(cfig_func):
        config = cfig_func()
        flagA = config['flagA']
        flagB = config['flagB']
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

.. code:: ipython3

    def demo():
        configuration1 = Config(True, False)
        jit_config1 = obj2strkeydict(configuration1, 'config1')
        physics(jit_config1)

.. code:: ipython3

    demo()


::


    ---------------------------------------------------------------------------

    TypingError                               Traceback (most recent call last)

    Cell In[83], line 1
    ----> 1 demo()


    Cell In[82], line 4, in demo()
          2 configuration1 = Config(True, False)
          3 jit_config1 = obj2strkeydict(configuration1, 'config1')
    ----> 4 physics(jit_config1)


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
        464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
        465                f"by the following argument(s):\n{args_str}\n")
        466         e.patch_message(msg)
    --> 468     error_rewrite(e, 'typing')
        469 except errors.UnsupportedError as e:
        470     # Something unsupported is present in the user code, add help info
        471     error_rewrite(e, 'unsupported_error')


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
        407     raise e
        408 else:
    --> 409     raise e.with_traceback(None)


    TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    - Resolution failure for literal arguments:
    No implementation of function Function(<function impl_append at 0x7fd87d253920>) found for signature:
    
     >>> impl_append(ListType[int64], float64)
    
    There are 2 candidate implementations:
          - Of which 2 did not match due to:
          Overload in function 'impl_append': File: numba/typed/listobject.py: Line 592.
            With argument(s): '(ListType[int64], float64)':
           Rejected as the implementation raised a specific error:
             TypingError: Failed in nopython mode pipeline (step: nopython frontend)
           No implementation of function Function(<intrinsic _cast>) found for signature:
    
            >>> _cast(float64, class(int64))
    
           There are 2 candidate implementations:
                 - Of which 2 did not match due to:
                 Intrinsic in function '_cast': File: numba/typed/typedobjectutils.py: Line 22.
                   With argument(s): '(float64, class(int64))':
                  Rejected as the implementation raised a specific error:
                    TypingError: cannot safely cast float64 to int64. Please cast explicitly.
             raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/typedobjectutils.py:75
           
           During: resolving callee type: Function(<intrinsic _cast>)
           During: typing of call at /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py (600)
           
           
           File "../anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py", line 600:
               def impl(l, item):
                   casteditem = _cast(item, itemty)
                   ^
    
      raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/typeinfer.py:1086
    
    - Resolution failure for non-literal arguments:
    None
    
    During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[int64])
    During: typing of call at /tmp/ipykernel_9889/739598600.py (11)
    
    
    File "../../../tmp/ipykernel_9889/739598600.py", line 11:
    <source missing, REPL/exec in use?>

任何帮助或任何相关材料的参考都会对我很有帮助。 谢谢你。

python numba
1个回答
0
投票

在 Numba 中,全局变量是编译时常量,因此您可以使用它来执行您想要的操作。这是一个例子:

import numba as nb   # v0.58.1

flagA = True

@nb.njit
def physics(flagA):
    aNumbaList = nb.typed.List()
    for i in range(100):
        if flagA:
            aNumbaList.append(i)
        else:
            aNumbaList.append(i/10)
    return aNumbaList

在参数中传递

flagA
时效果很好,不会出现错误,因为
if
else
中的项目属于不同类型。

话虽这么说,全局变量在软件工程方面并不是很好,您可能希望在运行时为不同的配置编译函数(例如,基于初始化过程,同时避免写入全局变量)。

另一种解决方案是返回一个函数,该函数读取父函数中定义的变量,因此它被视为该函数的全局变量,因此被视为编译时常量。编译函数读取的变量可以作为参数传递给父函数。这是一个例子:

import numba as nb

def make_physics(flagA):
    @nb.njit
    def fun():
        aNumbaList = nb.typed.List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

    return fun

physics = make_physics(True)  # Compile a specialized function every time it is called
physics()                     # Call the compiled function generated just before

这也不会导致任何错误,并且实际上按预期工作。以下是

physics
函数生成的汇编代码,显示主循环中没有对
flagA
进行运行时检查:

    [...]

    movq    %rax, %r12                 ; r12 = an allocated Python object (the list?)
    movq    24(%rax), %rax
    movq    %r14, (%rax)
    xorl    %ebx, %ebx                 ; i = 0
    movabsq $NRT_incref, %r13
    movabsq $numba_list_append, %rbp
    leaq    48(%rsp), %r15             ; (r15 is a pointer on i)

.LBB0_6:                               ; Main loop
    movq    %r12, %rcx                 
    callq   *%r13                      ; Call NRT_incref(r12)
    movq    %rbx, 48(%rsp)             
    movq    %r14, %rcx                 
    movq    %r15, %rdx                 
    callq   *%rbp                      ; Call numba_list_append(r14, pointer_of(i))
    testl   %eax, %eax                 
    jne .LBB0_7                        ; Stop the loop if numba_list_append returned a non-zero value
    incq    %rbx                       ; i += 1
    movq    %r12, %rcx                 
    movabsq $NRT_decref, %rax          
    callq   *%rax                      ; Call NRT_decref(r12)
    cmpq    $100, %rbx                 
    jne .LBB0_6                        ; Loop as long as i < 100

    [...]

关于实际用例,记忆和 Numba 函数缓存可以帮助避免针对相同配置多次编译目标函数。

© www.soinside.com 2019 - 2024. All rights reserved.