如何在Pytorch中实现批量负采样?

问题描述 投票:0回答:1

我尝试使用隐式数据集训练推荐系统的两塔模型。 在训练之前,我想使用批量负采样来预处理数据集。 我认为代码很好,但性能真的很慢,所以我无法训练模型。 可以查一下训练方法吗?

def train_with_in_batch_negative_sampling(
    model,
    num_epoch,
    optimizer,
    data_loader,
    criterion,
    device,
    log_interval=10,
    num_neg_samples=5,
):
    model.train()
    total_loss = 0
    train_loss = 987654321
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)

    for i, (fields, target) in enumerate(tk0):
        fields, target = fields.to(device), target.to(device)
        new_fields = []
        for idx, row in enumerate(fields):
            if idx==0:
                item_tensor = fields[idx+1:,1]
            else:
                item_tensor = torch.cat([fields[:idx,1], fields[idx+1:,1]],dim=0)
            new_fields.append(torch.Tensor([row[0], row[1], torch.Tensor(1)]).view(1, -1).to(device))
            new_fields.append(
                torch.cat(
                    [
                        row[0:1].repeat(511).view(-1, 1),
                        item_tensor.view(-1, 1),
                        torch.zeros(511, 1).to(device),
                    ],
                    dim=1,
                )
            )
        merged = torch.cat(new_fields,dim=0).to(torch.int)
        y = model(merged[:,:2])
        loss = criterion(y, merged[:,2])
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        train_loss = min(train_loss, total_loss / log_interval)
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=total_loss / log_interval, epoch_num=num_epoch)
        total_loss = 0
    return train_loss

在创建Dataloader之前我尝试过负随机采样,但是训练出来的结果不太好。所以我想应用批量负采样。

pytorch sampling
1个回答
0
投票

您可以在tfrs查看方法 至少它提供了一些见解

© www.soinside.com 2019 - 2024. All rights reserved.