My task is to use the PointNet++ algorithm to detect the points of a mesh that lie close to its centerline points. I have a dataset of 50 meshes in OBJ format together with their centerlines in DAT format. I adapted a PointNet++ implementation that classifies each point of a mesh according to its proximity to the centerline, and below is the error my code produces.
/home/aniruddha/PycharmProjects/Thesis_centerline/venv/bin/python /home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py
data shape: torch.Size([32, 3, 1000])
group_idx shape before assignment: torch.Size([32, 512, 3])
group_first shape: torch.Size([32, 512, 3])
group_idx shape after assignment: torch.Size([32, 512, 3])
Traceback (most recent call last):
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 365, in <module>
    outputs = model(data)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 334, in forward
    l1_xyz, l1_points = self.sa1(xyz, None)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/PointNET++.py", line 234, in forward
    grouped_points = F.relu(bn(conv(grouped_points)))
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/aniruddha/PycharmProjects/Thesis_centerline/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 3, 1, 1], expected input[32, 1000, 1536, 1] to have 3 channels, but got 1000 channels instead
Process finished with exit code 1
This is what I have tried so far:
import os
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import trimesh
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings("ignore")
# Rename the files to ensure they match the OBJ and DAT files
directory_path = '2023_RCSE_Centerline'
current_file_names = ['Kiel_BB_Pat_1_centerline.dat',
                      'Kiel_BB_Pat_2_centerline.dat',
                      'Kiel_BB_Pat_3_centerline.dat',
                      'Kiel_BB_Pat_4_centerline.dat',
                      'Kiel_BB_Pat_5_centerline.dat',
                      'Kiel_BB_Pat_6_centerline.dat',
                      'Kiel_BB_Pat_7_centerline.dat',
                      'Kiel_BB_Pat_8_centerline.dat',
                      'Kiel_BB_Pat_9_centerline.dat',
                      'HW_20170703_centerline_rescaled.dat',
                      'Kiel_BB_Patient_10_centerline.dat']
new_file_names = ['Kiel_BB_Pat1_centerline.dat',
                  'Kiel_BB_Pat2_centerline.dat',
                  'Kiel_BB_Pat3_centerline.dat',
                  'Kiel_BB_Pat4_centerline.dat',
                  'Kiel_BB_Pat5_centerline.dat',
                  'Kiel_BB_Pat6_centerline.dat',
                  'Kiel_BB_Pat7_centerline.dat',
                  'Kiel_BB_Pat8_centerline.dat',
                  'Kiel_BB_Pat9_centerline.dat',
                  'HW_20170703_rescaled_centerline.dat',
                  'Kiel_BB_Pat10_centerline.dat']
for k in range(len(current_file_names)):
    current_file_name = os.path.join(directory_path, current_file_names[k])
    new_file_name = os.path.join(directory_path, new_file_names[k])
    if os.path.exists(current_file_name):
        os.rename(current_file_name, new_file_name)
def load_centerline(file_path):
    return np.loadtxt(file_path, skiprows=1, usecols=(0, 1, 2))

def load_obj(file_path):
    return trimesh.load(file_path)

def remove_duplicate_points(centerline, threshold=0.01):
    tree = cKDTree(centerline)
    to_remove = set()
    for i, point in enumerate(centerline):
        if i in to_remove:
            continue
        indices = tree.query_ball_point(point, threshold)
        indices.remove(i)
        to_remove.update(indices)
    cleaned_centerline = np.array([point for i, point in enumerate(centerline) if i not in to_remove])
    return cleaned_centerline

def center_data(points):
    centroid = points.mean(axis=0)
    return points - centroid

def align_data_with_pca(points):
    pca = PCA(n_components=3)
    pca.fit(points)
    aligned_points = pca.transform(points)
    return aligned_points, pca.components_

def get_point_cloud(file_path):
    mesh = trimesh.load_mesh(file_path)
    point_cloud = mesh.sample(500)
    new_points = np.random.uniform(mesh.bounds[0], mesh.bounds[1], (500, 3))
    final_point_cloud = np.concatenate([point_cloud, new_points])
    return final_point_cloud.astype(np.float32)

def classify_point(point_cloud, centerline_point):
    if point_cloud.shape[0] > 0:
        tree = cKDTree(point_cloud)
        classifications = np.zeros(len(point_cloud))
        for k in range(len(centerline_point)):
            _, indices = tree.query(centerline_point[k, :])
            classifications[indices] = 1
    else:
        classifications = np.array([])
    return classifications
class PointCloudDataset(Dataset):
    def __init__(self, point_clouds, labels):
        self.point_clouds = point_clouds
        self.labels = labels

    def __len__(self):
        return len(self.point_clouds)

    def __getitem__(self, index):
        point_cloud = self.point_clouds[index]
        label = self.labels[index]
        return point_cloud, label

def prepare_dataset(obj_files, directory_path, target_size=100):
    point_clouds = []
    classifications = []
    for obj_file in obj_files:
        file_path = os.path.join(directory_path, obj_file)
        point_cloud = get_point_cloud(file_path)
        centerline_file = file_path.replace('.obj', '_centerline.dat')
        centerline_points = np.loadtxt(centerline_file, skiprows=1, usecols=(0, 1, 2))
        augmented_clouds = np.expand_dims(point_cloud, axis=0)
        for aug_cloud in augmented_clouds:
            classifications.append(classify_point(aug_cloud, centerline_points))
            point_clouds.append(aug_cloud.T)
    point_clouds = np.array(point_clouds, dtype=np.float32)
    classifications = np.concatenate(classifications, axis=0)
    return PointCloudDataset(point_clouds, classifications)
def random_rotation(pc, max_angle=0.2):
    angle = np.random.uniform(-max_angle, max_angle)
    rotation_matrix = np.array([
        [np.cos(angle), -np.sin(angle), 0],
        [np.sin(angle), np.cos(angle), 0],
        [0, 0, 1]
    ])
    return np.dot(pc, rotation_matrix)

def random_crop(pc, crop_ratio=0.8):
    mins = np.min(pc, axis=0)
    maxs = np.max(pc, axis=0)
    diff = maxs - mins
    min_crop = mins + diff * (1 - crop_ratio)
    max_crop = maxs - diff * (1 - crop_ratio)
    crop_min = mins + np.random.rand(3) * (max_crop - min_crop)
    crop_max = crop_min + diff * crop_ratio
    indices = np.all((pc >= crop_min) & (pc <= crop_max), axis=1)
    cropped_pc = pc[indices]
    if len(cropped_pc) == 0:
        return pc  # Return the original pc if the crop results in an empty array
    return cropped_pc

def random_scale(pc, scale_range=(0.8, 1.2)):
    scale = np.random.uniform(*scale_range)
    return pc * scale

def random_noise(pc, noise_level=0.01):
    noise = np.random.normal(0, noise_level, pc.shape)
    return pc + noise
# Load OBJ and corresponding centerline files
obj_files = [f for f in os.listdir(directory_path) if f.endswith('.obj')]
# Prepare the dataset
dataset = prepare_dataset(obj_files, directory_path, target_size=100)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
# Split dataset
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
# DataLoader setup with a collate function to ensure correct tensor shapes
def collate_fn(batch):
    point_clouds, labels = zip(*batch)
    point_clouds = torch.stack([torch.tensor(pc, dtype=torch.float32) for pc in point_clouds])
    labels = torch.tensor(labels, dtype=torch.long)
    return point_clouds, labels

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        B, N, C = xyz.shape
        new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz[:, :, None, :]
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz
            # Adjust the shape of grouped_points
            grouped_points = grouped_points.permute(0, 3, 2, 1)
            grouped_points = grouped_points.contiguous().view(B, C, -1, 1)
            for j, (conv, bn) in enumerate(zip(self.conv_blocks[i], self.bn_blocks[i])):
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]
            new_points_list.append(new_points)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat
def index_points(points, idx):
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points

def farthest_point_sample(xyz, npoint):
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(xyz.device)
    distance = torch.ones(B, N).to(xyz.device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(xyz.device)
    batch_indices = torch.arange(B, dtype=torch.long).to(xyz.device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    # Compute squared distances
    sqrdists = square_distance(new_xyz, xyz)
    # Initialize group_idx and mask points outside the radius
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    group_idx[sqrdists > radius ** 2] = N
    # Sort and select the top-k nearest neighbors
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Create group_first to handle insufficient neighbors
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat(1, 1, nsample)
    group_first = group_first[:, :, :group_idx.shape[2]]
    # Debug prints
    print(f"group_idx shape before assignment: {group_idx.shape}")
    print(f"group_first shape: {group_first.shape}")
    # Create mask and ensure shapes match before assignment
    mask = group_idx == N
    if mask.shape != group_first.shape:
        print(f"mask shape: {mask.shape}")
        print(f"group_first shape (adjusted): {group_first.shape}")
    # Apply mask to replace invalid indices with the first valid index
    group_idx[mask] = group_first[mask]
    # Debug prints
    print(f"group_idx shape after assignment: {group_idx.shape}")
    return group_idx

def square_distance(src, dst):
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)
class PointNet2ClsMsg(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(PointNet2ClsMsg, self).__init__()
        in_channel = 3 if normal_channel else 0
        self.sa1 = PointNetSetAbstractionMsg(npoint=512, radius_list=[0.1, 0.2, 0.4], nsample_list=[16, 32, 128],
                                             in_channel=in_channel,
                                             mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(npoint=128, radius_list=[0.2, 0.4, 0.8], nsample_list=[32, 64, 128],
                                             in_channel=320,
                                             mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstractionMsg(npoint=None, radius_list=[0.4, 0.8, 1.6], nsample_list=[64, 128, 256],
                                             in_channel=640,
                                             mlp_list=[[128, 256, 512], [256, 256, 512], [256, 384, 512]])
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        B, _, _ = xyz.shape
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.drop1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.drop2(x)
        x = self.fc3(x)
        x = F.log_softmax(x, -1)
        return x
# Model, loss, and optimizer
model = PointNet2ClsMsg(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop with validation
num_epochs = 2
best_val_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data, labels in train_loader:
        optimizer.zero_grad()
        print(f"data shape: {data.shape}")
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    train_loss = total_loss / total
    train_acc = correct / total * 100
    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for data, labels in val_loader:
            outputs = model(data)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)
    val_loss = val_loss / val_total
    val_acc = val_correct / val_total * 100
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    # Save the model with the best validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
print("Training completed!")
# Load the best model for testing
model.load_state_dict(torch.load('best_model.pth'))
def evaluate_model(test_loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total
# Evaluate on test data
test_accuracy = evaluate_model(test_loader, model)
print(f'Test Accuracy: {test_accuracy:.2f}%')
You are hitting a shape-mismatch error in your PointNet++ model. The error message reads:

RuntimeError: Given groups=1, weight of size [32, 3, 1, 1], expected input[32, 1000, 1536, 1] to have 3 channels, but got 1000 channels instead

In other words, the first convolution of sa1 expects an input with 3 channels but received 1000. The root cause is upstream of the convolution: your DataLoader yields batches of shape [32, 3, 1000] (channels first), but PointNetSetAbstractionMsg.forward unpacks the input as B, N, C = xyz.shape, so it treats each batch as 3 points with 1000 coordinates (N = 3, C = 1000). Your own debug output confirms this: group_idx has shape [32, 512, 3] because only 3 candidate neighbors exist per query point. The reshaping of grouped_points then carries the swapped dimensions into the convolution.
This is the problematic part of the code:

grouped_points = grouped_points.permute(0, 3, 2, 1)
grouped_points = grouped_points.contiguous().view(B, C, -1, 1)

After the permute, grouped_points has shape [32, 1000, 3, 512]. The view call then flattens the nsample and npoint dimensions together, producing [32, 1000, 1536, 1] (1536 = 3 × 512), which is exactly the tensor the convolution rejects. The view is unnecessary: the permute alone already puts the feature channels where Conv2d expects them, provided the input was point-major to begin with.
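To make the mismatch concrete, here is a minimal shape trace with dummy tensors (a sketch only; the sizes come from your log, and a random tensor stands in for the grouping step):

import torch

xyz = torch.randn(32, 3, 1000)              # batch as the DataLoader yields it: channels first
B, N, C = xyz.shape                          # the layer reads N = 3, C = 1000 -- already swapped
grouped = torch.randn(B, 512, 3, C)          # [B, npoint, nsample, C]; nsample collapses to 3 since only N = 3 "points" exist
grouped = grouped.permute(0, 3, 2, 1)        # [32, 1000, 3, 512]
grouped = grouped.contiguous().view(B, C, -1, 1)
print(grouped.shape)                         # torch.Size([32, 1000, 1536, 1]) -- the shape in the RuntimeError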
Two changes fix this. First, transpose the input to [B, N, 3] at the top of PointNet2ClsMsg.forward, before the first set-abstraction layer:

xyz = xyz.permute(0, 2, 1)  # [B, 3, N] -> [B, N, 3]

Second, drop the view call in PointNetSetAbstractionMsg.forward and keep only the permute, so the coordinate channels land in the channel slot. It also helps to return the pooled features point-major ([B, npoint, D']), so that sa2 can index them the same way sa1 indexes xyz. The corrected forward should look like this:
def forward(self, xyz, points):
    B, N, C = xyz.shape
    new_xyz = index_points(xyz, farthest_point_sample(xyz, self.npoint))
    new_points_list = []
    for i, radius in enumerate(self.radius_list):
        K = self.nsample_list[i]
        group_idx = query_ball_point(radius, K, xyz, new_xyz)
        grouped_xyz = index_points(xyz, group_idx)
        grouped_xyz -= new_xyz[:, :, None, :]
        if points is not None:
            grouped_points = index_points(points, group_idx)
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
        else:
            grouped_points = grouped_xyz
        # [B, npoint, nsample, C+D] -> [B, C+D, nsample, npoint] for Conv2d
        grouped_points = grouped_points.permute(0, 3, 2, 1)
        for j, (conv, bn) in enumerate(zip(self.conv_blocks[i], self.bn_blocks[i])):
            grouped_points = F.relu(bn(conv(grouped_points)))
        # Max-pool over the nsample dimension -> [B, D', npoint]
        new_points = torch.max(grouped_points, 2)[0]
        new_points_list.append(new_points)
    new_points_concat = torch.cat(new_points_list, dim=1)
    # Return features point-major ([B, npoint, D']) so the next layer can index them
    new_points_concat = new_points_concat.permute(0, 2, 1)
    return new_xyz, new_points_concat
With these changes, the channels sit where the convolution layers expect them, and sa1 and sa2 receive inputs of the correct shape. One caveat: sa3 is constructed with npoint=None, which farthest_point_sample cannot handle, and its concatenated output width (512 + 512 + 512 = 1536) does not match the Linear(1024) head; the reference PointNet++ implementation uses a single group-all PointNetSetAbstraction layer for the final stage, so that part will likely need a similar adjustment. Try the fix above and see whether it resolves the error!
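As a quick sanity check, you can push a dummy batch through the first set-abstraction layer and confirm the shapes (a sketch that assumes the corrected forward above and the helper functions from your script are in scope):

import torch

sa1 = PointNetSetAbstractionMsg(npoint=512, radius_list=[0.1, 0.2, 0.4],
                                nsample_list=[16, 32, 128], in_channel=0,
                                mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
xyz = torch.randn(32, 3, 1000).permute(0, 2, 1)  # [B, N, 3], as sa1 now expects
new_xyz, new_points = sa1(xyz, None)
print(new_xyz.shape)     # torch.Size([32, 512, 3])
print(new_points.shape)  # torch.Size([32, 512, 320]) -- 64 + 128 + 128 feature channels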