什么是 TensorFlow 中的 BruteForce 层以及为什么它被称为 BruteForce

问题描述 投票:0回答:2

我理解检索的任务——我已经浏览过代码;还研究了替代方法,例如 SCNN,它是一种超快的最近邻。

但是我还是很难理解下面代码的机制

# Create a model that takes in raw query features, and
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
# recommends movies out of the entire movies dataset.
index.index_from_dataset(
  tf.data.Dataset.zip((movies.batch(100), movies.batch(100).map(model.movie_model)))
)

# Get recommendations.
_, titles = index(tf.constant(["42"]))
print(f"Recommendations for user 42: {titles[0, :3]}")

model.user_model
已训练完毕,现在应该返回 user_id 的嵌入。
BruteForce
层的输入是
model.user_model
;然后应该对其进行索引?

我猜输出给出了

user_id
42,返回 3 个标题,其中
movies.batch(100)
。但我无法理解 BruteForce 和索引的功能!

keras tensorflow2.0 recommendation-engine
2个回答
2
投票

BruteForce 层测试从模型最后一层提取的嵌入之间的所有组合。

根据该层的 tensorflow 文档,该层重新调整最接近每个索引的 top k 结果(默认为 10)索引。


0
投票

您误解为“BruteForce层的输入是model.user_model”。 user_model 不是 BureteForce 层的输入。它是ButeForce类的参数。所以“输入”是 BruteForce 的实例。 user_model 是嵌入输入的 query_model 和两座塔之一。

index.index_from_dataset() 设置另一个候选塔的嵌入。

movies.batch(100) 不仅仅输出 100 部电影,而是输出 100 部电影的许多块。

© www.soinside.com 2019 - 2024. All rights reserved.