如何计算连体网络中的相似度/距离,然后对它们进行分类?
这是我目前的尝试
class SiameseNetwork(nn.Module):
def __init__(self) -> None:
super().__init__()
self.resnet = torchvision.models.resnet18(num_classes=5)
def forward_once(self, item):
output = self.resnet(item)
return output
def forward(self, anchor, positive, negative):
output1 = self.forward_once(anchor)
output2 = self.forward_once(positive)
output3 = self.forward_once(negative)
return output1, output2, output3
我使用 resnet 和 TripletMarginLoss 作为损失函数,但我对如何计算相似度并对输出进行分类感到困惑
如果您发现代码有任何问题,请告诉我。
TripletMarginLoss
使用基础 p 范数距离。如果您使用 p=2
的默认参数,那么您应该使用欧氏距离来计算嵌入的相似度。
分类需要一个带标签的分类数据集和一个单独的分类损失,与任何标准分类问题相同。相似性损失不会给你分类。