How do I blend together (B x C x H x W) tensor tiles seamlessly, so that the tile boundaries are hidden?


For the sake of completeness, here is a text summary of what I am trying to do:

  1. Split the image into tiles.
  2. Run each tile through a new copy of the model for a certain number of iterations.
  3. Feather the tiles and combine them into rows.
  4. Combine the rows and place them back into the original image/tensor.
  5. Possibly save the output, then split the output into tiles again.
  6. Repeat steps 2 and 3 for a certain number of iterations.

I only need help with steps 1, 3, and 4. The style transfer process creates some subtle differences in the processed tiles, so I need to blend them back together. By feathering, I basically mean fading one tile into another to blur the boundaries (as in ImageMagick, Photoshop, etc.). I am trying to accomplish this blending by creating masks with torch.linspace(), though I am not sure if there is a better way.
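The linspace-based feathering mask described above can be sketched as follows. This is a minimal illustration with made-up tile dimensions, not the asker's exact code: the mask ramps from 0 to 1 across the overlap width and then stays at 1, so two overlapping tiles weighted by `mask` and `1 - mask` always sum with total weight 1.

```python
import torch

# Hypothetical sizes for illustration: a 4x6 tile with a 3-pixel overlap.
tile_h, tile_w, overlap = 4, 6, 3

# Ramp from 0 to 1 over the overlap region, then solid 1s for the rest.
ramp = torch.linspace(0, 1, overlap).repeat(tile_h, 1)  # shape (4, 3)
solid = torch.ones(tile_h, tile_w - overlap)            # shape (4, 3)
mask = torch.cat([ramp, solid], dim=1)                  # shape (4, 6)

# Stand-ins for two processed tiles that overlap on this region.
left_tile = torch.zeros(tile_h, tile_w)
right_tile = torch.ones(tile_h, tile_w)

# Weighting one tile by `mask` and the other by `1 - mask` keeps the total
# weight at 1 everywhere, which is what hides the seam.
blended = left_tile * (1 - mask) + right_tile * mask
```

At the left edge `mask` is 0, so the blend equals `left_tile`; past the overlap it is 1, so the blend equals `right_tile`, with a smooth transition in between.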

Although I am working with PyTorch, what I am trying to achieve is based on: https://github.com/VaKonS/neural-style/blob/Multi-resolution/neural_style.lua. The tiling code I am trying to implement can be found here: https://gist.github.com/ProGamerGov/e64fcb309274c2946f5a9a679ed45669, though you shouldn't need to look at it, as everything you need can be found below.

In essence, this is what I want to do (the red regions are where a tile overlaps with another tile):

[Diagram of what I am trying to accomplish; red regions are where a tile overlaps with another tile.]

Here is the code I have so far. Feathering and adding the rows together are not yet implemented, because I haven't been able to get the feathering of individual tiles working yet.

import torch
from PIL import Image
import torchvision.transforms as transforms

def tile_calc(tile_size, v, d):
    max_val = max(min(tile_size*v+tile_size, d), 0)
    min_val = tile_size*v
    if abs(min_val - max_val) < tile_size:
        min_val = max_val-tile_size
    return min_val, max_val

def split_tensor(tensor, tile_size=256):
    tiles, tile_idx = [], []
    tile_size_y, tile_size_x = tile_size + 8, tile_size + 5  # Make H and W different for testing
    h, w = tensor.size(2), tensor.size(3)
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))

    for y in range(h_range):       
        for x in range(w_range):        
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])

    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]

    if tensor.is_cuda:
        base_tensor = torch.zeros(tensor.squeeze(0).size(), device=tensor.get_device())
    else: 
        base_tensor = torch.zeros(tensor.squeeze(0).size())
    return tiles, base_tensor.unsqueeze(0), (h_range, w_range), (h_overlap, w_overlap) 

# Feather across vertical tile boundaries (blend each tile's left edge into the tile to its left)
def feather_tiles(tensor_list, hxw, w_overlap):
    print(len(tensor_list))
    mask_list = []
    if w_overlap > 0:
        for i, tile in enumerate(tensor_list):
            if i % hxw[1] != 0:
                lin_mask = torch.linspace(0,1,w_overlap).repeat(tile.size(2),1)
                mask_part = torch.ones(tile.size(2), tile.size(3)-w_overlap)
                mask = torch.cat([lin_mask, mask_part], 1)
                mask = mask.repeat(3,1,1).unsqueeze(0)
                mask_list.append(mask)
            else:
                mask = torch.ones(tile.squeeze().size()).unsqueeze(0)
                mask_list.append(mask)
    return mask_list


def build_row(tensor_tiles, tile_masks, hxw, w_overlap, bt, tile_size):
    print(len(tensor_tiles), len(tile_masks))
    if bt.is_cuda:
        row_base = torch.ones(bt.size(1),tensor_tiles[0].size(2),bt.size(3), device=bt.get_device()).unsqueeze(0)
    else: 
        row_base = torch.ones(bt.size(1),tensor_tiles[0].size(2),bt.size(3)).unsqueeze(0)
    row_list = []
    for v in range(hxw[1]):
        row_list.append(row_base.clone())

    num_tiles = 0
    row_val = 0
    tile_size_y, tile_size_x = tile_size+8, tile_size +5
    h, w = bt.size(2), bt.size(3)
    h_range, w_range = hxw[0], hxw[1]
    for y in range(h_range):       
        for x in range(w_range):        
            ty, y_val = tile_calc(tile_size_y, y, h)
            tx, x_val = tile_calc(tile_size_x, x, w)

            if num_tiles % hxw[1] != 0: 
                new_mean = (row_list[row_val][:, :, :, tx:x_val].mean() + tensor_tiles[num_tiles].mean()) / 2
                row_list[row_val][:, :, :, tx:x_val] = row_list[row_val][:, :, :, tx:x_val] - row_list[row_val][:, :, :, tx:x_val].mean()
                tensor_tiles[num_tiles] = tensor_tiles[num_tiles] - tensor_tiles[num_tiles].mean()  

                row_list[row_val][:, :, :, tx:x_val] = (row_list[row_val][:, :, :, tx:x_val] + ( tensor_tiles[num_tiles] * tile_masks[num_tiles])) + new_mean

            else:
                row_list[row_val][:, :, :, tx:x_val] = tensor_tiles[num_tiles]          
            num_tiles+=1 
        row_val+=1          
    return row_list


def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


input_tensor = preprocess('test.jpg', 256)

tile_tensors, base_t, hxw, ovlp = split_tensor(input_tensor, 128)
tile_masks = feather_tiles(tile_tensors, hxw, ovlp[1])
row_tensors = build_row(tile_tensors, tile_masks, hxw, ovlp[1], base_t, 128)

ft = deprocess(row_tensors[0]) # save tensor to view it 
ft.save('ft_row_0.png')


python image-processing pytorch tensor tile
1 Answer

You are looking for torch.nn.functional.unfold and torch.nn.functional.fold. These functions let you apply "sliding window" operations to images, with arbitrary window sizes and strides. The documentation for these functions has more information, and this answer provides an example of how to "fuse" overlapping windows using fold. Those references should give you the information you need to implement your blending scheme.
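The fold/unfold approach suggested above can be sketched like this. It is a minimal, hypothetical example (the sizes are illustrative, not from the question): overlapping patches are extracted with `F.unfold`, summed back with `F.fold`, and divided by a per-pixel patch count so overlapping regions are averaged rather than summed.

```python
import torch
import torch.nn.functional as F

# An 8x8 single-channel image, split into 4x4 tiles with stride 2 (overlap).
x = torch.arange(8 * 8, dtype=torch.float32).reshape(1, 1, 8, 8)
kernel, stride = 4, 2

# Each column of `patches` is one flattened 4x4 tile: shape (1, 16, L).
patches = F.unfold(x, kernel_size=kernel, stride=stride)
# ... process each tile here (e.g. run it through the model) ...

# fold() sums overlapping contributions; folding a tensor of ones gives a
# per-pixel count of how many tiles cover each pixel, so dividing averages
# the overlaps instead of summing them.
summed = F.fold(patches, output_size=(8, 8), kernel_size=kernel, stride=stride)
counts = F.fold(torch.ones_like(patches), output_size=(8, 8),
                kernel_size=kernel, stride=stride)
blended = summed / counts
```

If the tiles are left unmodified, this reconstruction is exact (`blended` equals `x`); a smoother ramp-shaped weight in place of the uniform counts would give the feathered falloff described in the question.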
