我参考了以下链接中的笔记本来实现图像相似性搜索:https://github.com/towhee-io/examples/blob/main/image/reverse_image_search/1_build_image_search_engine.ipynb
下面是我正在使用的代码:
import csv
from glob import glob
from pathlib import Path
from statistics import mean
from towhee import pipe, ops, DataCollection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
# Towhee parameters
MODEL = 'resnet50'

# Milvus parameters
HOST = [MY_HOST]
PORT = [MY_PORT]
# NOTE(review): TOPK, DIM and COLLECTION_NAME are not referenced again in the
# visible snippets (later code hard-codes limit=5 and the "clothes"
# collection) -- confirm whether they are still needed.
TOPK = 5
DIM = 2048
COLLECTION_NAME = 'images'

# Index settings for the vector field
INDEX_TYPE = 'IVF_FLAT'
METRIC_TYPE = 'L2'
index_params = {
    'metric_type': METRIC_TYPE,
    'index_type': INDEX_TYPE,
    'params': {"nlist": 2048},
}

# NOTE(review): `collection` has no visible assignment before this statement;
# it must already be bound to a Milvus Collection when this runs -- confirm.
collection.create_index(
    field_name='image',
    index_params=index_params,
    index_name="image_index",
)
我只对图像进行了向量化,然后将向量连同元数据一起保存:
# Load image path
def load_image(x):
    """Yield image paths described by *x*.

    If *x* ends with ``'csv'`` it is opened as a CSV file: the first row is
    skipped as a header and the second column of every remaining row is
    yielded. Otherwise *x* is treated as a glob pattern and every matching
    path is yielded.

    An empty CSV file yields nothing.
    """
    if x.endswith('csv'):
        # newline='' lets the csv module do its own line-ending handling.
        with open(x, newline='') as f:
            reader = csv.reader(f)
            # A default keeps StopIteration from escaping the generator
            # (which would surface as RuntimeError) when the file is empty.
            next(reader, None)
            for item in reader:
                yield item[1]
    else:
        yield from glob(x)
# Embedding pipeline: 'src' is expanded into image paths by load_image, each
# path is passed through the image_decode operator, and the result is passed
# through the timm image-embedding operator configured with MODEL.
p_embed = (
    pipe.input('src')
    .flat_map('src', 'img_path', load_image)
    .map('img_path', 'img', ops.image_decode())
    .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL))
)

image_save_dir = [MY_IMAGE_PATH]

# Output the path, the decoded image and the vector for every input image.
p_display = p_embed.output('img_path', 'img', 'vec')
result = DataCollection(p_display(image_save_dir))

# check result
result.show()
print(result[0]['img_path'])
print(result[0]['vec'])
# Connect to Milvus and open the target collection.
connections.connect(alias='default', host=HOST, port=PORT)
collection_name = "clothes"
collection = Collection(name=collection_name)

# Insert rows in batches instead of issuing one insert request per image.
# The batch size is kept moderate so a single request stays small
# (assumes 2048-dim float vectors, per DIM -- TODO confirm).
INSERT_BATCH_SIZE = 500
rows = []
for i, r in enumerate(result):
    vector = r['vec']
    rows.append({
        "clothes_id": i,
        "category": "top",
        "color": "black",
        "image": vector,
        "gender": ["F"],
        "style": ["casual"],
        "thickness": [],
        "season": ["spring"],
    })
    if len(rows) >= INSERT_BATCH_SIZE:
        collection.insert(rows)
        rows = []

# Send whatever is left over from the last partial batch.
if rows:
    collection.insert(rows)

# Flush once after all inserts so the inserted data is sealed and persisted.
collection.flush()
然后我尝试通过图像相似度搜索来获取ID值
# Search pipeline: extend the embedding pipeline with a Milvus lookup on each
# vector, then keep only the first element of every returned hit.
p_search_pre = (
    p_embed
    .map(
        'vec',
        'search_res',
        ops.ann_search.milvus_client(
            host=HOST,
            port=PORT,
            limit=5,
            collection_name="clothes",
        ),
    )
    .map('search_res', 'pred', lambda hits: [hit[0] for hit in hits])  # get id
)
p_search = p_search_pre.output('img_path', 'pred')

# Search for example query image(s)
collection.load()
dc = p_search('[MY_IMAGE_PATH]/test37.png')

# Display search results with image paths
DataCollection(dc).show()
但是,显示时没有任何结果。为了确认搜索是否正确,我下载了 Milvus 中存储的图像,将它们保存到本地路径中,并计算了同一张图像的余弦相似度。
from numpy import dot
from numpy.linalg import norm
import numpy as np
def cos_similarity(A, B):
    """Return the cosine similarity of vectors *A* and *B*.

    Computed as ``dot(A, B) / (norm(A) * norm(B))``. No special handling is
    done for zero-length vectors; the division is performed as-is.
    """
    return dot(A, B) / (norm(A) * norm(B))
test_image = '[MY_IMAGE_PATH]/test38.png'
p_display = p_embed.output('img_path', 'img', 'vec')
# Embed the query image itself. The original embedded `image_save_dir` and
# left `test_image` unused, so result[0] was the first image found there
# rather than the image under test.
result = DataCollection(p_display(test_image))

collection_name = 'clothes'
collection = Collection(name=collection_name)

# get image vector where clothes_id=38
save_results = collection.query(
    expr="clothes_id == 38",
    output_fields=["clothes_id", "category", "color", "gender", "style", "thickness", "season", "image"]
)

if save_results:
    saved_image_vector = np.array(save_results[0]["image"])
    result_vector = np.array(result[0]['vec'])
    # get similarity
    similarity = cos_similarity(result_vector, saved_image_vector)
    print(f"cosine similarity : {similarity}")
else:
    print("there is no images")
我得到的结果约为 1。我不清楚问题出在哪里,也不确定为什么无法得到正确的图像相似性搜索结果。
我遵循您的代码结构,并尝试在您执行相似性搜索以获取 ID 值的部分使用下面的代码,它成功地返回了结果。
# Search pipeline: look up each embedding in the "clothes" collection and
# keep only the first element of every returned hit.
p_search_pre = (
    p_embed
    .map(
        'vec',
        'search_res',
        ops.ann_search.milvus_client(
            host=HOST,
            port=PORT,
            limit=5,
            collection_name="clothes",
        ),
    )
    .map('search_res', 'pred', lambda matches: [match[0] for match in matches])
)