这是代码:
import numpy as np
from functools import cache, wraps
def np_cache(function):
@cache
def cached_wrapper(*args, **kwargs):
args = [np.array(a) if isinstance(a, tuple) else a for a in args]
kwargs = {
k: np.array(v) if isinstance(v, tuple) else v for k, v in kwargs.items()
}
return function(*args, **kwargs)
@wraps(function)
def wrapper(*args, **kwargs):
args = [tuple(a) if isinstance(a, np.ndarray) else a for a in args]
kwargs = {
k: tuple(v) if isinstance(v, np.ndarray) else v for k, v in kwargs.items()
}
return cached_wrapper(*args, **kwargs)
wrapper.cache_info = cached_wrapper.cache_info
wrapper.cache_clear = cached_wrapper.cache_clear
return wrapper
x = np.array([[1,2],[3,4]])
y = np.array([2, 4])
@np_cache
def test2(x,shown = False):
if shown:
print(x)
return x
#test2(y,True)
test2(x)
我收到此错误:
(array([[1, 2],
[3, 4]]), True)
[(array([1, 2]), array([3, 4])), True]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-79-73b6dd94e3a9> in <cell line: 42>()
40
41 #test2(y,True)
---> 42 test2(x, True)
43
44
<ipython-input-79-73b6dd94e3a9> in wrapper(*args, **kwargs)
23 k: tuple(v) if isinstance(v, np.ndarray) else v for k, v in kwargs.items()
24 }
---> 25 return cached_wrapper(*args, **kwargs)
26
27 wrapper.cache_info = cached_wrapper.cache_info
TypeError: unhashable type: 'numpy.ndarray'
如您所见,数组已正确转换为数组元组,所以我不知道为什么它不起作用。
我认为关键是在包装器中进行一些修改,但我不确定如何修改。
提前致谢
Numpy
ndarray
对象不可散列。但是如果你天真地在二维或更高维的数组上使用tuple
,你会得到一个 ndarray 对象的元组,它仍然不可散列。
因此,如果数组要坚持基本数据类型,您可以使用快速而肮脏的函数来转换为嵌套元组:
def _lol_to_tot_helper(lol, ndim):
if ndim == 1:
return tuple(lol)
elif ndim == 2:
return tuple(map(tuple, lol))
else:
return tuple(lol_to_tot(x, ndim-1) for x in lol)
def array_to_nested_tuple(arr):
lol = arr.tolist()
return _lol_to_tot_helper(lol, arr.ndim)
虽然是递归的,但它不应该非常低效,因为基本情况之一是二维情况,因此您不会递归到单个元素。小心不要传递像
np.array(1)
这样的东西,即 0 维 ndarray。