我理解检索的任务——我已经浏览过代码;还研究了替代方法,例如 SCNN,它是一种超快的最近邻。
但是我还是很难理解下面代码的机制
# Create a model that takes in raw query features, and
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
# recommends movies out of the entire movies dataset.
index.index_from_dataset(
tf.data.Dataset.zip((movies.batch(100), movies.batch(100).map(model.movie_model)))
)
# Get recommendations.
_, titles = index(tf.constant(["42"]))
print(f"Recommendations for user 42: {titles[0, :3]}")
model.user_model
已训练完毕,现在应该返回 user_id 的嵌入。 BruteForce
层的输入是model.user_model
;然后应该对其进行索引?
我猜输出给出了
user_id
42,返回 3 个标题,其中 movies.batch(100)
。但我无法理解 BruteForce 和索引的功能!
BruteForce 层测试从模型最后一层提取的嵌入之间的所有组合。
根据该层的 tensorflow 文档,该层重新调整最接近每个索引的 top k 结果(默认为 10)索引。
您误解为“BruteForce层的输入是model.user_model”。 user_model 不是 BureteForce 层的输入。它是ButeForce类的参数。所以“输入”是 BruteForce 的实例。 user_model 是嵌入输入的 query_model 和两座塔之一。
index.index_from_dataset() 设置另一个候选塔的嵌入。
movies.batch(100) 不仅仅输出 100 部电影,而是输出 100 部电影的许多块。