如何计算暹罗网络中的相似度/距离(pytorch)

问题描述 投票:0回答:1

如何计算连体网络中的相似度/距离,然后对它们进行分类?

这是我目前的尝试

class SiameseNetwork(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.resnet = torchvision.models.resnet18(num_classes=5)

    def forward_once(self, item):
        output = self.resnet(item)
        return output

    def forward(self, anchor, positive, negative):
        output1 = self.forward_once(anchor)
        output2 = self.forward_once(positive)
        output3 = self.forward_once(negative)

        return output1, output2, output3

我使用 resnet 和 TripletMarginLoss 作为损失函数,但我对如何计算相似度并对输出进行分类感到困惑

如果您发现代码有任何问题,请告诉我。

python deep-learning pytorch
1个回答
0
投票

TripletMarginLoss
使用基础 p 范数距离。如果您使用
p=2
的默认参数,那么您应该使用欧氏距离来计算嵌入的相似度。

分类需要一个带标签的分类数据集和一个单独的分类损失,与任何标准分类问题相同。相似性损失不会给你分类。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.