使用nDCG作为火炬损失函数

问题描述 投票:0回答:1

我实现了一个基于nDCG的loss,如下代码片段所示:

import torch

class NDCGLoss(torch.nn.Module):

    def __init__(self, relevance_map):
        super(NDCGLoss, self).__init__()
        self.relevance_map = relevance_map

    def get_relevance(self, queries_ids, docs_ids):
        relevance = torch.empty((queries_ids.shape[0], docs_ids.shape[0]), device=queries_ids.device,
                                requires_grad=True)
        for i, query_idx in enumerate(queries_ids.tolist()):
            for j, doc_idx in enumerate(docs_ids.tolist()):
                if doc_idx in self.relevance_map[query_idx]:
                    relevance[i, j] = 1.0
                else:
                    relevance[i, j] = 0.0
        return relevance

    def _get_dcg(self, scores, relevances):
        discount = 1.0 / (torch.log2(torch.arange(relevances.shape[-1], device=relevances.device) + 2.0))
        ranking = scores.argsort(descending=True)
        ranked = torch.gather(relevances, dim=-1, index=ranking)
        return torch.sum(discount * ranked, dim=-1)

    def forward(self, query_idx, query_rpr, doc_idx, doc_rpr):

        scores = torch.einsum("ab,cb->ac", query_rpr, doc_rpr)  # inner product
        relevance = self.get_relevance(query_idx, doc_idx)  # ground-truth
        dcg = self._get_dcg(scores, relevance)
        idcg = self._get_dcg(relevance, relevance)
        idcg = torch.where(idcg == 0, 1.0, idcg)
        return -torch.log(
            torch.mean(
                torch.div(dcg, idcg)
            ) + 1e-11
        )

相似度得分是使用 query_rpr 和 doc_rpr 嵌入作为内积来计算的。 query_idx 和 doc_idx 张量用于根据相关性映射中存储的内容来定义真实相关性,该映射指示哪些文档与每个查询相关。

但是,我收到以下错误。

RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
pytorch torch loss-function
1个回答
0
投票

我认为这是来自

get_relevance
函数,您可以在其中进行就地分配(即
relevance[i, j] = 1.0
)。查看您的代码,您确定
relevance
张量需要
requires_grad=True
吗?如果我正确地阅读了代码,
relevance
是你的基本事实值 - 你不应该需要通过它进行反向传播。

© www.soinside.com 2019 - 2024. All rights reserved.