TypeError: unhashable type: 'numpy.ndarray' when converting an array to a set


In this function I get TypeError: unhashable type: 'numpy.ndarray' at the line retrieved_indices_set = set(retrieved_indices):

def evaluate_retrieval(query_idx, retrieved_indices, relevant_indices):
    # Convert each list element to a tuple
    # Flatten the two-layer list and convert elements to tuples
    arr = np.array(retrieved_indices) #retrieved_indices => 128 * 128 * 3

    # Transpose the array and convert it to a list of tuples
    retrieved_indices = tuple(list(map(tuple, np.vstack(arr.T))))
    print(type(retrieved_indices))

    # Create a set from the tuples
    retrieved_indices_set = set(retrieved_indices)
    relevant_retrieved = len(retrieved_indices_set.intersection(relevant_indices_set))
    precision = relevant_retrieved / len(retrieved_indices_set) if len(retrieved_indices_set) > 0 else 0
    return precision

I also tried this, but it didn't work:
retrieved_indices_tuples = tuple(tuple(tuple(pixel) for pixel in row) for row in retrieved_indices)
python list hash set numpy-ndarray
1 Answer

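You cannot put numpy.ndarray elements directly into a set, because set elements must be hashable and numpy.ndarray is not. You need to convert each numpy.ndarray to a hashable type first, such as a tuple.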
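To illustrate the point, here is a minimal sketch that reproduces the error on a small array and then avoids it. The array `arr` is made-up sample data, not taken from the question.

```python
import numpy as np

arr = np.array([[1, 2], [3, 4], [1, 2]])  # hypothetical sample data

# Iterating over a 2-D array yields 1-D ndarrays, which are not hashable,
# so building a set from them raises TypeError.
try:
    set(arr)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(error_message)

# Converting each row to a tuple makes the elements hashable.
rows = set(map(tuple, arr))
print(len(rows))  # 2, because the duplicate row collapses into one entry
```

The same idea applies to arrays with more dimensions: every level that ends up as a set element has to be a tuple, not an ndarray.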

For your function, something like this:

import numpy as np

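def evaluate_retrieval(query_idx, retrieved_indices, relevant_indices):
    # Convert the nested array into nested tuples so the elements are hashable.
    # Each element of the outer tuple is one row (a tuple of pixel tuples).
    retrieved_indices_tuples = tuple(tuple(tuple(pixel) for pixel in row) for row in retrieved_indices)

    # The set therefore holds whole rows, not individual pixels.
    retrieved_indices_set = set(retrieved_indices_tuples)
    relevant_indices_set = set(relevant_indices)

    # Count the entries that appear in both sets
    relevant_retrieved = len(retrieved_indices_set.intersection(relevant_indices_set))

    # Precision = relevant retrieved / total retrieved (0 if nothing was retrieved)
    precision = relevant_retrieved / len(retrieved_indices_set) if len(retrieved_indices_set) > 0 else 0
    return precision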

retrieved_indices = np.random.randint(0, 256, (128, 128, 3))  # Sample data
relevant_indices = [((0, 0, 0), (1, 1, 1), (2, 2, 2))]  # Sample relevant indices

precision = evaluate_retrieval(0, retrieved_indices, relevant_indices)
print(f"Precision: {precision}")
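
Note that nested tuples like these make each set element a whole row. If the goal is to compare individual pixels, as the "flatten" comment in the question suggests, one possible sketch is to reshape the array to one pixel per row before building the set. The name `pixel_precision` is hypothetical, and the sketch assumes `relevant` is a flat iterable of (r, g, b) tuples, which is an assumption about the data rather than something the question states.

```python
import numpy as np

def pixel_precision(retrieved, relevant):
    """Fraction of distinct retrieved pixels that also appear in `relevant`."""
    arr = np.asarray(retrieved)
    # Collapse all leading axes so each row is one pixel,
    # e.g. shape (128, 128, 3) becomes (16384, 3).
    pixels = arr.reshape(-1, arr.shape[-1])
    # tolist() yields plain Python ints, so the tuples compare cleanly
    # with ordinary tuples such as (0, 0, 0).
    retrieved_set = {tuple(p) for p in pixels.tolist()}
    relevant_set = set(relevant)
    if not retrieved_set:
        return 0
    return len(retrieved_set & relevant_set) / len(retrieved_set)

sample = np.array([[[0, 0, 0], [1, 1, 1]],
                   [[5, 5, 5], [0, 0, 0]]])  # made-up 2x2 image
result = pixel_precision(sample, [(0, 0, 0), (1, 1, 1), (2, 2, 2)])
print(f"Precision: {result}")
```

Because a set drops duplicates, the denominator here is the number of distinct pixels, not the total pixel count; whether that is the right definition of precision depends on what the indices are meant to represent.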