为什么这个函数缓存会给我多维 np.array 的错误消息?

问题描述 投票:0回答:1

这是代码:

import numpy as np
from functools import cache, wraps


def np_cache(function):
    @cache
    def cached_wrapper(*args, **kwargs):

        args = [np.array(a) if isinstance(a, tuple) else a for a in args]
        kwargs = {
            k: np.array(v) if isinstance(v, tuple) else v for k, v in kwargs.items()
        }

        return function(*args, **kwargs)

    @wraps(function)
    def wrapper(*args, **kwargs):
        args = [tuple(a) if isinstance(a, np.ndarray) else a for a in args]
        kwargs = {
            k: tuple(v) if isinstance(v, np.ndarray) else v for k, v in kwargs.items()
        }
        return cached_wrapper(*args, **kwargs)

    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear

    return wrapper

x = np.array([[1,2],[3,4]])
y = np.array([2, 4])

@np_cache
def test2(x,shown = False):
  if shown:
    print(x)
  return x

#test2(y,True)
test2(x)

我收到此错误:

(array([[1, 2],
       [3, 4]]), True)
[(array([1, 2]), array([3, 4])), True]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-79-73b6dd94e3a9> in <cell line: 42>()
     40 
     41 #test2(y,True)
---> 42 test2(x, True)
     43 
     44 

<ipython-input-79-73b6dd94e3a9> in wrapper(*args, **kwargs)
     23             k: tuple(v) if isinstance(v, np.ndarray) else v for k, v in kwargs.items()
     24         }
---> 25         return cached_wrapper(*args, **kwargs)
     26 
     27     wrapper.cache_info = cached_wrapper.cache_info

TypeError: unhashable type: 'numpy.ndarray'

如您所见,数组已正确转换为数组元组,所以我不知道为什么它不起作用。

我认为关键是在包装器中进行一些修改,但我不确定如何修改。

提前致谢

python numpy caching
1个回答
0
投票
正如您所知,

Numpy

ndarray
对象不可散列。但是如果你天真地在二维或更高维的数组上使用
tuple
,你会得到一个 ndarray 对象的元组,它仍然不可散列。

因此,如果数组要坚持基本数据类型,您可以使用快速而肮脏的函数来转换为嵌套元组:

def _lol_to_tot_helper(lol, ndim):
    if ndim == 1:
        return tuple(lol)
    elif ndim == 2:
        return tuple(map(tuple, lol))
    else:
        return tuple(lol_to_tot(x, ndim-1) for x in lol)

def array_to_nested_tuple(arr):
    lol = arr.tolist()
    return _lol_to_tot_helper(lol, arr.ndim)

虽然是递归的,但它不应该非常低效,因为基本情况之一是二维情况,因此您不会递归到单个元素。小心不要传递像

np.array(1)
这样的东西,即 0 维 ndarray。

© www.soinside.com 2019 - 2024. All rights reserved.