在下面的函数定义中,执行 retrieved_indices_set = set(retrieved_indices) 这一行时出现错误 TypeError: unhashable type: 'numpy.ndarray':
def evaluate_retrieval(query_idx, retrieved_indices, relevant_indices):
    """Return the precision of a retrieval result.

    Precision is the number of distinct retrieved entries that also appear
    in ``relevant_indices`` divided by the number of distinct retrieved
    entries.

    Args:
        query_idx: Identifier of the query. Accepted for interface
            compatibility; not used in the computation.
        retrieved_indices: Array-like of retrieved entries. It is converted
            with ``np.array``, transposed, stacked with ``np.vstack``, and
            every resulting row becomes one tuple.
        relevant_indices: Iterable of hashable ground-truth entries
            (tuples, to be comparable with the row tuples built here).

    Returns:
        ``0`` when nothing was retrieved, otherwise the precision as a
        float between 0 and 1.
    """
    arr = np.array(retrieved_indices)

    # np.vstack raises ValueError on an empty array; nothing retrieved
    # simply means a precision of 0.
    if arr.size == 0:
        return 0

    # NOTE(review): for a 3-D (H, W, C) input, the rows of
    # np.vstack(arr.T) are single-channel columns, not per-pixel triples.
    # Kept as in the original; confirm this is the intended entry format.
    stacked = np.vstack(arr.T)

    # tolist() yields plain Python scalars, so every row tuple is hashable.
    retrieved_indices_set = {tuple(row) for row in stacked.tolist()}

    # The relevant set was never built before, which raised NameError.
    relevant_indices_set = set(relevant_indices)

    relevant_retrieved = len(retrieved_indices_set & relevant_indices_set)
    return relevant_retrieved / len(retrieved_indices_set)
I tried this, but it didn't work:
# Turn every innermost sequence into a tuple, then every row, then the whole
# collection, so that the result is a fully nested (hashable) tuple.
retrieved_indices_tuples = tuple(
    tuple(map(tuple, row)) for row in retrieved_indices
)
numpy.ndarray 是不可哈希的类型,因此不能直接作为集合(set)的元素。
您需要先将 numpy.ndarray 转换为可哈希的类型,例如元组(tuple)。
例如:
import numpy as np
def _to_hashable(item):
    """Recursively convert arrays, lists and tuples into nested tuples.

    Anything else (numbers, strings, numpy scalars, ...) is returned
    unchanged.
    """
    if isinstance(item, np.ndarray):
        # tolist() also turns numpy scalars into plain Python scalars.
        item = item.tolist()
    if isinstance(item, (list, tuple)):
        return tuple(_to_hashable(sub) for sub in item)
    return item


def evaluate_retrieval(query_idx, retrieved_indices, relevant_indices):
    """Return the precision of a retrieval result.

    Each top-level element of ``retrieved_indices`` and ``relevant_indices``
    is one entry. Entries may be scalars or arbitrarily nested
    arrays/lists/tuples; nested entries are converted to nested tuples so
    they can be stored in a set. Precision is the number of distinct
    retrieved entries that are also relevant divided by the number of
    distinct retrieved entries.

    Args:
        query_idx: Identifier of the query. Accepted for interface
            compatibility; not used in the computation.
        retrieved_indices: Iterable of retrieved entries (for example a
            flat list of indices or an ``np.ndarray`` of any dimension).
        relevant_indices: Iterable of ground-truth entries in the same
            format as the retrieved ones.

    Returns:
        ``0`` when nothing was retrieved, otherwise the precision as a
        float between 0 and 1.
    """
    # The previous version hard-coded exactly three nesting levels and
    # failed with TypeError on flat index lists or 2-D arrays.
    retrieved_indices_set = {_to_hashable(item) for item in retrieved_indices}
    relevant_indices_set = {_to_hashable(item) for item in relevant_indices}

    if not retrieved_indices_set:
        return 0

    relevant_retrieved = len(retrieved_indices_set & relevant_indices_set)
    return relevant_retrieved / len(retrieved_indices_set)
# Sample data: a random 128 x 128 x 3 array of integers in [0, 256).
retrieved_indices = np.random.randint(low=0, high=256, size=(128, 128, 3))

# Sample relevant entries: a single entry made of the triples
# (0, 0, 0), (1, 1, 1) and (2, 2, 2).
relevant_indices = [tuple((value, value, value) for value in range(3))]

# Score the sample retrieval for query 0 and report the result.
precision = evaluate_retrieval(
    0,
    retrieved_indices,
    relevant_indices,
)
print(f"Precision: {precision}")