PointNet++ implementation

Problem description (votes: 0, answers: 1)

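My task is to detect points close to the centerline using the PointNet++ algorithm. I have a dataset of 50 meshes in OBJ format, each with its centerline points in .dat format. I applied a PointNet++ implementation; below is the error output from my code: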

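The approach treats this as a classification problem: the points of each mesh are labelled according to how close they are to the centerline.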
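One way to turn "proximity to the centerline" into per-point labels is sketched below. It assumes a Euclidean distance threshold (the hypothetical radius parameter) expressed in the same units as the mesh coordinates. Note that this differs from classify_point in the code further down, which marks only the single nearest cloud point for each centerline point; here every cloud point within radius of any centerline point is labelled 1.

```python
import numpy as np
from scipy.spatial import cKDTree


def label_by_centerline_distance(point_cloud, centerline, radius):
    """Label each point 1 if it lies within `radius` of any centerline point, else 0.

    point_cloud: (N, 3) array of sampled points.
    centerline:  (M, 3) array of centerline points.
    radius:      hypothetical distance threshold, same units as the coordinates.
    """
    tree = cKDTree(centerline)
    # Distance from every cloud point to its nearest centerline point
    dists, _ = tree.query(point_cloud, k=1)
    return (dists <= radius).astype(np.int64)


# Toy example: a straight centerline along the x axis
centerline = np.stack([np.linspace(0.0, 1.0, 11), np.zeros(11), np.zeros(11)], axis=1)
cloud = np.array([[0.5, 0.05, 0.0],   # 0.05 away from the centerline
                  [0.5, 0.50, 0.0],   # 0.5 away
                  [0.0, 0.0, 0.09]])  # 0.09 away from the first centerline point
labels = label_by_centerline_distance(cloud, centerline, radius=0.1)
print(labels)  # [1 0 1]
```

Querying from the cloud into a tree built on the centerline gives exactly one distance per cloud point, which maps directly onto a per-point label vector.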

/home/aniruddha/PycharmProjects/Thesis_centerline/venv/bin/python /home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py 
data shape: torch.Size([32, 3, 1000])
group_idx shape before assignment: torch.Size([32, 512, 3])
group_first shape: torch.Size([32, 512, 3])
group_idx shape after assignment: torch.Size([32, 512, 3])
Traceback (most recent call last):
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 365, in <module>
    outputs = model(data)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 334, in forward
    l1_xyz, l1_points = self.sa1(xyz, None)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 234, in forward
    grouped_points = F.relu(bn(conv(grouped_points)))
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 3, 1, 1], expected input[32, 1000, 1536, 1] to have 3 channels, but got 1000 channels instead

Process finished with exit code 1

Here is what I have tried so far:

import os
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import trimesh
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")

# Rename the files to ensure they match the OBJ and DAT files
directory_path = '2023_RCSE_Centerline'
current_file_names = ['Kiel_BB_Pat_1_centerline.dat',
                      'Kiel_BB_Pat_2_centerline.dat',
                      'Kiel_BB_Pat_3_centerline.dat',
                      'Kiel_BB_Pat_4_centerline.dat',
                      'Kiel_BB_Pat_5_centerline.dat',
                      'Kiel_BB_Pat_6_centerline.dat',
                      'Kiel_BB_Pat_7_centerline.dat',
                      'Kiel_BB_Pat_8_centerline.dat',
                      'Kiel_BB_Pat_9_centerline.dat',
                      'HW_20170703_centerline_rescaled.dat',
                      'Kiel_BB_Patient_10_centerline.dat']
new_file_names = ['Kiel_BB_Pat1_centerline.dat',
                  'Kiel_BB_Pat2_centerline.dat',
                  'Kiel_BB_Pat3_centerline.dat',
                  'Kiel_BB_Pat4_centerline.dat',
                  'Kiel_BB_Pat5_centerline.dat',
                  'Kiel_BB_Pat6_centerline.dat',
                  'Kiel_BB_Pat7_centerline.dat',
                  'Kiel_BB_Pat8_centerline.dat',
                  'Kiel_BB_Pat9_centerline.dat',
                  'HW_20170703_rescaled_centerline.dat',
                  'Kiel_BB_Pat10_centerline.dat']

for k in range(len(current_file_names)):
    current_file_name = os.path.join(directory_path, current_file_names[k])
    new_file_name = os.path.join(directory_path, new_file_names[k])

    if os.path.exists(current_file_name):
        os.rename(current_file_name, new_file_name)


def load_centerline(file_path):
    return np.loadtxt(file_path, skiprows=1, usecols=(0, 1, 2))


def load_obj(file_path):
    return trimesh.load(file_path)


def remove_duplicate_points(centerline, threshold=0.01):
    tree = cKDTree(centerline)
    to_remove = set()
    for i, point in enumerate(centerline):
        if i in to_remove:
            continue
        indices = tree.query_ball_point(point, threshold)
        indices.remove(i)
        to_remove.update(indices)

    cleaned_centerline = np.array([point for i, point in enumerate(centerline) if i not in to_remove])
    return cleaned_centerline


def center_data(points):
    centroid = points.mean(axis=0)
    return points - centroid


def align_data_with_pca(points):
    pca = PCA(n_components=3)
    pca.fit(points)
    aligned_points = pca.transform(points)
    return aligned_points, pca.components_


def get_point_cloud(file_path):
    mesh = trimesh.load_mesh(file_path)
    point_cloud = mesh.sample(500)
    new_points = np.random.uniform(mesh.bounds[0], mesh.bounds[1], (500, 3))
    final_point_cloud = np.concatenate([point_cloud, new_points])
    return final_point_cloud.astype(np.float32)


def classify_point(point_cloud, centerline_point):
    if point_cloud.shape[0] > 0:
        tree = cKDTree(point_cloud)
        classifications = np.zeros(len(point_cloud))
        for k in range(len(centerline_point)):
            _, indices = tree.query(centerline_point[k, :])
            classifications[indices] = 1
    else:
        classifications = np.array([])
    return classifications


class PointCloudDataset(Dataset):
    def __init__(self, point_clouds, labels):
        self.point_clouds = point_clouds
        self.labels = labels

    def __len__(self):
        return len(self.point_clouds)

    def __getitem__(self, index):
        point_cloud = self.point_clouds[index]
        label = self.labels[index]
        return point_cloud, label


def prepare_dataset(obj_files, directory_path,target_size=100):
    point_clouds = []
    classifications = []
    for obj_file in obj_files:
        file_path = os.path.join(directory_path, obj_file)
        point_cloud = get_point_cloud(file_path)
        centerline_file = file_path.replace('.obj', '_centerline.dat')
        centerline_points = np.loadtxt(centerline_file, skiprows=1, usecols=(0, 1, 2))

        augmented_clouds = np.expand_dims(point_cloud, axis=0)
        for aug_cloud in augmented_clouds:
            classifications.append(classify_point(aug_cloud, centerline_points))
            point_clouds.append(aug_cloud.T)

    point_clouds = np.array(point_clouds, dtype=np.float32)
    classifications = np.concatenate(classifications, axis=0)
    return PointCloudDataset(point_clouds, classifications)


def random_rotation(pc, max_angle=0.2):
    angle = np.random.uniform(-max_angle, max_angle)
    rotation_matrix = np.array([
        [np.cos(angle), -np.sin(angle), 0],
        [np.sin(angle), np.cos(angle), 0],
        [0, 0, 1]
    ])
    return np.dot(pc, rotation_matrix)


def random_crop(pc, crop_ratio=0.8):
    mins = np.min(pc, axis=0)
    maxs = np.max(pc, axis=0)
    diff = maxs - mins
    min_crop = mins + diff * (1 - crop_ratio)
    max_crop = maxs - diff * (1 - crop_ratio)
    crop_min = mins + np.random.rand(3) * (max_crop - min_crop)
    crop_max = crop_min + diff * crop_ratio
    indices = np.all((pc >= crop_min) & (pc <= crop_max), axis=1)
    cropped_pc = pc[indices]
    if len(cropped_pc) == 0:
        return pc  # Return the original pc if crop results in empty array
    return cropped_pc


def random_scale(pc, scale_range=(0.8, 1.2)):
    scale = np.random.uniform(*scale_range)
    return pc * scale


def random_noise(pc, noise_level=0.01):
    noise = np.random.normal(0, noise_level, pc.shape)
    return pc + noise


# Load OBJ and corresponding centerline files
obj_files = [f for f in os.listdir(directory_path) if f.endswith('.obj')]

# Prepare the dataset
dataset = prepare_dataset(obj_files, directory_path, target_size=100)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Split dataset
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])


# Adjusting DataLoader to ensure correct shapes
# DataLoader setup with collate function
def collate_fn(batch):
    point_clouds, labels = zip(*batch)
    point_clouds = torch.stack([torch.tensor(pc, dtype=torch.float32) for pc in point_clouds])
    labels = torch.tensor(labels, dtype=torch.long)
    return point_clouds, labels


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)


class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        B, N, C = xyz.shape
        new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz[:, :, None, :]
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz
            # Adjust the shape of grouped_points
            grouped_points = grouped_points.permute(0, 3, 2, 1)
            grouped_points = grouped_points.contiguous().view(B, C, -1, 1)
            for j, (conv, bn) in enumerate(zip(self.conv_blocks[i], self.bn_blocks[i])):
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]
            new_points_list.append(new_points)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat


def index_points(points, idx):
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(xyz.device)
    distance = torch.ones(B, N).to(xyz.device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(xyz.device)
    batch_indices = torch.arange(B, dtype=torch.long).to(xyz.device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape

    # Compute squared distances
    sqrdists = square_distance(new_xyz, xyz)

    # Initialize group_idx and mask points outside the radius
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    group_idx[sqrdists > radius ** 2] = N

    # Sort and select the top-k nearest neighbors
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]

    # Create group_first to handle insufficient neighbors
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat(1, 1, nsample)
    group_first = group_first[:, :, :group_idx.shape[2]]

    # Debug prints
    print(f"group_idx shape before assignment: {group_idx.shape}")
    print(f"group_first shape: {group_first.shape}")

    # Create mask and ensure shapes match before assignment
    mask = group_idx == N
    if mask.shape != group_first.shape:
        print(f"mask shape: {mask.shape}")
        print(f"group_first shape (adjusted): {group_first.shape}")

    # Apply mask to replace invalid indices with the first valid index
    group_idx[mask] = group_first[mask]

    # Debug prints
    print(f"group_idx shape after assignment: {group_idx.shape}")

    return group_idx


def square_distance(src, dst):
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)


class PointNet2ClsMsg(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(PointNet2ClsMsg, self).__init__()
        in_channel = 3 if normal_channel else 0
        self.sa1 = PointNetSetAbstractionMsg(npoint=512, radius_list=[0.1, 0.2, 0.4], nsample_list=[16, 32, 128],
                                             in_channel=in_channel,
                                             mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(npoint=128, radius_list=[0.2, 0.4, 0.8], nsample_list=[32, 64, 128],
                                             in_channel=320, mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(npoint=None, radius_list=[0.4, 0.8, 1.6], nsample_list=[64, 128, 256],
                                             in_channel=640,
                                             mlp_list=[[128, 256, 512], [256, 256, 512], [256, 384, 512]])

        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.drop1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.drop2(x)
        x = self.fc3(x)
        x = F.log_softmax(x, -1)
        return x


# Model, loss, and optimizer
model = PointNet2ClsMsg(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop with validation
num_epochs = 2
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, labels in train_loader:
        optimizer.zero_grad()
        print(f"data shape: {data.shape}")
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    train_loss = total_loss / total
    train_acc = correct / total * 100

    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for data, labels in val_loader:
            outputs = model(data)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    val_loss = val_loss / val_total
    val_acc = val_correct / val_total * 100

    print(
        f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Save the model with the best validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

print("Training completed!")

# Load the best model for testing
model.load_state_dict(torch.load('best_model.pth'))


def evaluate_model(test_loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Evaluate on test data
test_accuracy = evaluate_model(test_loader, model)
print(f'Test Accuracy: {test_accuracy:.2f}%')

Tags: python, machine-learning, deep-learning, artificial-intelligence, medical

1 answer (0 votes)

I see you're running into a shape-mismatch error in your PointNet++ model. The error message says:

RuntimeError: Given groups=1, weight of size [32, 3, 1, 1], expected input[32, 1000, 1536, 1] to have 3 channels, but got 1000 channels instead

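In other words, the first convolution layer expects an input with 3 channels but receives one with 1000. The problem shows up where grouped_points is reshaped in the PointNetSetAbstractionMsg class.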

This is the problematic part of the code:

grouped_points = grouped_points.permute(0, 3, 2, 1)
grouped_points = grouped_points.contiguous().view(B, C, -1, 1)

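The permute followed by view(B, C, -1, 1) scrambles the layout of grouped_points. Conv2d expects its input as [batch, channels, height, width], so the coordinate/feature axis has to be moved to dimension 1 before the convolution, without the extra view. Note, however, that the 1000 itself comes from the input layout: your debug print shows data shape torch.Size([32, 3, 1000]), i.e. [B, 3, N], but forward unpacks it with B, N, C = xyz.shape, so N becomes 3 and C becomes 1000 (which is also why group_idx has only 3 neighbours per centroid). Transpose the input to [B, N, 3] first, for example with xyz = xyz.permute(0, 2, 1) at the start of PointNet2ClsMsg.forward, or by not transposing with aug_cloud.T in prepare_dataset.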
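The 1000 channels can be traced with a small shape reproduction. This sketch assumes, from the debug print data shape: torch.Size([32, 3, 1000]), that batches arrive as [B, 3, N] while forward unpacks them with B, N, C = xyz.shape. B is reduced to 2 to keep memory low, and the 512 centroids and at most 3 neighbours per centroid are hard-coded to match the group_idx print rather than computed by the real helpers.

```python
import torch

B, num_points, S = 2, 1000, 512  # batch, points per cloud, sampled centroids

# The DataLoader yields [B, 3, N] because prepare_dataset stores aug_cloud.T
data = torch.zeros(B, 3, num_points)

# PointNetSetAbstractionMsg.forward unpacks it as if it were [B, N, C]
_, N, C = data.shape
print(N, C)  # 3 1000  -> 3 "points", each with 1000 "coordinates"

# With N == 3, query_ball_point can return at most 3 neighbours per centroid,
# so grouped_xyz ends up as [B, S, K=3, C=1000]
K = min(16, N)
grouped_xyz = torch.zeros(B, S, K, C)

# The reshape from the question: dimension 1 becomes C == 1000
bad = grouped_xyz.permute(0, 3, 2, 1).contiguous().view(B, C, -1, 1)
print(tuple(bad.shape))  # (2, 1000, 1536, 1)

# Moving the last axis to dimension 1 still yields C == 1000 channels
# unless the input itself is transposed to [B, N, 3] first
still_bad = grouped_xyz.permute(0, 3, 1, 2)
print(tuple(still_bad.shape))  # (2, 1000, 512, 3)
```

The 1536 in the error message is simply 3 neighbours times 512 centroids, flattened by the view.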

Try changing this part to:

grouped_points = grouped_points.permute(0, 3, 1, 2)
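A minimal sketch of the layout the first Conv2d(3, 32, 1) of sa1 expects, assuming the batch is transposed to [B, N, 3] before it reaches sa1 (for example at the top of PointNet2ClsMsg.forward). Random neighbour indices stand in for query_ball_point, so only the shapes are meaningful here.

```python
import torch
import torch.nn as nn

B, N, S, K = 2, 1000, 512, 16  # batch, points, centroids, neighbours per centroid

data = torch.randn(B, 3, N)    # what the DataLoader yields: [B, 3, N]
xyz = data.permute(0, 2, 1)    # [B, N, 3], the layout the helper functions assume

# Stand-in for query_ball_point: random neighbour indices [B, S, K]
group_idx = torch.randint(0, N, (B, S, K))
batch_idx = torch.arange(B).view(B, 1, 1).expand(B, S, K)
grouped_xyz = xyz[batch_idx, group_idx, :]   # [B, S, K, 3]

# Channels first for Conv2d: [B, 3, S, K]
grouped = grouped_xyz.permute(0, 3, 1, 2)
conv = nn.Conv2d(3, 32, 1)                   # first layer of sa1's first MLP
out = conv(grouped)
print(tuple(out.shape))  # (2, 32, 512, 16)
```

With the coordinates on the last axis of xyz, permute(0, 3, 1, 2) puts exactly 3 channels on dimension 1, which is what the 1x1 convolution's weight of size [32, 3, 1, 1] requires.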

This should resolve the channel mismatch, provided xyz arrives as [B, N, 3]. The corrected method would look like this:

def forward(self, xyz, points):
    # xyz: [B, N, 3] coordinates; points: [B, N, D] features or None
    B, N, C = xyz.shape
    new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint))  # [B, S, 3]
    new_points_list = []
    for i, radius in enumerate(self.radius_list):
        K = self.nsample_list[i]
        group_idx = query_ball_point(radius, K, xyz, new_xyz)  # [B, S, K]
        grouped_xyz = index_points(xyz, group_idx)  # [B, S, K, 3]
        grouped_xyz -= new_xyz[:, :, None, :]  # coordinates relative to each centroid
        if points is not None:
            grouped_points = index_points(points, group_idx)  # [B, S, K, D]
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
        else:
            grouped_points = grouped_xyz
        # Channels first for Conv2d: [B, S, K, D+3] -> [B, D+3, S, K]
        grouped_points = grouped_points.permute(0, 3, 1, 2)
        for conv, bn in zip(self.conv_blocks[i], self.bn_blocks[i]):
            grouped_points = F.relu(bn(conv(grouped_points)))
        # Max-pool over the K neighbours (last axis): [B, D', S]
        new_points = torch.max(grouped_points, 3)[0]
        new_points_list.append(new_points)
    # Concatenate the scales on the channel axis, then return [B, S, D'_total]
    # so that index_points in the next layer sees points on dimension 1
    new_points_concat = torch.cat(new_points_list, dim=1).permute(0, 2, 1)
    return new_xyz, new_points_concat
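After permute(0, 3, 1, 2) the tensor is [B, D', S, K], so the max-pool has to run over the last axis (the K neighbours of each centroid); pooling over dimension 2 would collapse the S centroids instead. The output also has to go back to [B, S, D'] because index_points in the next set-abstraction layer treats dimension 1 as the point dimension. A shape-only sketch with made-up sizes:

```python
import torch

B, D, S, K = 2, 64, 512, 16  # batch, MLP output channels, centroids, neighbours
features = torch.randn(B, D, S, K)  # stand-in for the conv/bn/relu output

per_centroid = torch.max(features, 3)[0]  # pool over neighbours -> [B, D, S]
wrong_axis = torch.max(features, 2)[0]    # pools over centroids -> [B, D, K]
print(tuple(per_centroid.shape), tuple(wrong_axis.shape))  # (2, 64, 512) (2, 64, 16)

# Concatenate two multi-scale branches on the channel axis, then move channels last
branches = [per_centroid, per_centroid.clone()]
new_points = torch.cat(branches, dim=1).permute(0, 2, 1)  # [B, S, 2*D]
print(tuple(new_points.shape))  # (2, 512, 128)
```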

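This way the channels sit on dimension 1 and your convolution layers receive input of the shape they expect. Give it a try and see whether it resolves the error!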

© www.soinside.com 2019 - 2024. All rights reserved.