如何使用平衡采样器进行火炬数据集/数据加载器

问题描述 投票:0回答:1

我的简化数据集如下所示:

class MyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.images: torch.Tensor[n, w, h, c]   # n images in memmory - specific use case
        self.labels: torch.Tensor[n, w, h, c]   # n images in memmory - specific use case
        self.positive_idx: List                 # positive 1 out of 10000 negative
        self.negative_idx: List
        
    def __len__(self):
        return 10000 # fixed value for training
        
    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]
    

ds = MyDataset()
dl = DataLoader(ds, batch_size=100, shuffle=False, sampler=...)   
# Weighted Sampler? Shuffle False because I guess the sampler should process shuffling.

平衡 Dataloader 采样的最“火炬”方式是什么,以便在每个时期将批次构建为 10 个正值 + 90 个随机负值,并且在没有足够的正值复制可能的情况下?

出于本练习的目的,我不会实施增强来增加正样本的样本量。

python pytorch dataset sampling dataloader
1个回答
0
投票

我认为你可以实现一个批量采样器来选择哪个数据点将通过

__getitem__

为你的数据集产生
class NegativeSampler:

  def __init__(self, positive_idx, negative_idx):
     
    self.positive_idx = positive_idx
    self.negative_idx = negative_idx 

  def __iter__(self): # this function will return index for your custom dataset ```__getitem__(self, idx)```
    positive_idx_batch = random.sample(self.positive_idx, batch_size)
    negative_idx_batch = []

    for pos_idx in positive_idx_batch:
       negative_idx_batch.append()
    
    
    yield positive_idx_batch + negative_idx_batch  

© www.soinside.com 2019 - 2024. All rights reserved.