我是 PyTorch 的新手,正在创建一个多输出线性回归模型,根据字母为单词着色。 (这将帮助有字素颜色联觉的人更轻松地阅读。)它接收单词并输出 RGB 值。每个单词都表示为 45 个浮点数 [0,1] 的向量,其中 (0, 1] 代表字母,0 代表该位置不存在字母。每个样本的输出应该是一个向量 [r-value, g -值,b-值]。
我得到了
运行时错误:mat1 和 mat2 形状无法相乘(90x1 和 45x3)
当我尝试在训练循环中运行我的模型时。
查看现有的 Stack Overflow 帖子,我认为这意味着我需要重塑我的数据,但我不知道如何/在哪里以解决此问题的方式进行此操作。特别是考虑到我不知道那个 90x1 矩阵来自哪里。
我的模特
我从简单开始;在我可以让单层发挥作用之后,可以出现多层。
class ColorPredictor(torch.nn.Module):
#Constructor
def __init__(self):
super(ColorPredictor, self).__init__()
self.linear = torch.nn.Linear(45, 3, device= device) #length of encoded word vectors & size of r,g,b vectors
# Prediction
def forward(self, x: torch.Tensor) -> torch.Tensor:
y_pred = self.linear(x)
return y_pred
我如何加载数据
# Dataset Class
class Data(Dataset):
# Constructor
def __init__(self, inputs, outputs):
self.x = inputs # a list of encoded word vectors
self.y = outputs # a Pandas dataframe of r,g,b values converted to a torch tensor
self.len = len(inputs)
# Getter
def __getitem__(self, index):
return self.x[index], self.y[index]
# Get number of samples
def __len__(self):
return self.len
# create train/test split
train_size = int(0.8 * len(data))
train_data = Data(inputs[:train_size], outputs[:train_size])
test_data = Data(inputs[train_size:], outputs[train_size:])
# create DataLoaders for training and testing sets
train_loader = DataLoader(dataset = train_data, batch_size=2)
test_loader = DataLoader(dataset = test_data, batch_size=2)
发生错误的测试循环
for epoch in range(epochs):
# Train
model.train() #training mode
for x,y in train_loader:
y_pred = model(x) #ERROR HERE
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
将 45x1 输入张量更改为 2x45 输入张量,第二列全为零。这适用于第一次运行 train_loader 循环,但在第二次运行 train_loader 循环期间,我得到另一个矩阵乘法错误,这次是大小为 90x2 和 45x3 的矩阵。
我将编码的词向量从 (45, 1) 重塑为 (1,45)
如果输入大小为(1,45)并且batch_size = 2:
size of weight matrix = output_features x input_features = 3x45
bias vector size = output_features = 3
input x weight transposed bias
y = [ [1,2,3,...,45], * [ [1, 2, 3], + [ [b1, b2, b3],
[3,2,1,...,45]] [2, 2, 3], [b1, b2, b3] ]
[3, 2, 3],
[. . .],
[45,45,45] ]
2x45 * 45x3
2x3 + 2x3