我有实现向前和向后传递的
Softmax
函数的练习,但是当我运行它时,它有这样的错误:
terminate called after throwing an instance of 'std::invalid_argument'
what(): A row can only be accessed on an expression with exact two dimensions
Aborted (core dumped)
这是我的代码,我被要求使用
xtensor
库来实现。我尝试修复但没有成功:
class Softmax {
public:
Softmax(int axis) : m_nAxis(axis) {}
xt::xarray<double> forward(xt::xarray<double> X) {
cached_Y = softmax(X, m_nAxis);
return cached_Y;
}
xt::xarray<double> backward(const xt::xarray<double>& DY) {
xt::xarray<double> DX = xt::zeros<double>(DY.shape());
int nclasses = DY.shape()[0];
for (int i = 0; i < nclasses; i++) {
xt::xarray<double> y = xt::row(cached_Y, i);
xt::xarray<double> dy_i = xt::row(DY, i);
xt::xarray<double> J = xt::diag(y) - xt::linalg::outer(y, y);
xt::view(DX, i) = xt::linalg::dot(J, dy_i);
}
return DX;
}
private:
int m_nAxis;
xt::xarray<double> cached_Y;
int positive_index(int idx, int size) {
if (idx < 0) return idx + size;
return idx;
}
xt::xarray<double> softmax(xt::xarray<double> X, int axis) {
xt::svector<unsigned long> shape = X.shape();
axis = positive_index(axis, shape.size());
shape[axis] = 1;
xt::xarray<double> Xmax = xt::amax(X, {axis});
X = xt::exp(X - Xmax.reshape(shape));
xt::xarray<double> SX = xt::sum(X, {axis});
SX = SX.reshape(shape);
X = X / SX;
return X;
}
};
int main(int argc, char* argv[]) {
xt::xarray<double> X = {1, 2, 3};
std::cout << "Input X: " << X << std::endl;
Softmax softmax_layer(0);
xt::xarray<double> Y = softmax_layer.forward(X);
std::cout << "Softmax Output Y: " << Y << std::endl;
xt::xarray<double> DY = {0.1, 0.2, 0.3};
std::cout << "Input DY: " << DY << std::endl;
xt::xarray<double> DX = softmax_layer.backward(DY);
std::cout << "Backward Output DX: " << DX << std::endl;
return 0;
}
不知道为什么调用backward函数时无法访问。
在你的前向传播中,你正确地选择了一个通用的softmax,它接收在构造函数中计算softmax的轴。在向后传递中,您忘记了这一点并要求 softmax 始终沿着轴 1,并且不仅输入数组,而且forward() 的一般输出都将是二维的。这些都是不必要的假设,因为您的样本数组是一般 softmax 的有效输入。
nclasses
是一个用词不当,这并没有帮助,softmax 不是在 softmax 轴(类的数量)上独立迭代,而是在其他一切上迭代。
您应该重用变量
m_nAxis
并沿其计算向后(),即迭代大小为 shape()[m_nAxis]
的所有视图/切片而不是所有行。从文档来看,似乎 xaxis_slice_iterator
就是您所需要的。
此外,如果您做出任何具体假设(一般情况下不应该在这里),您需要确保它们在您的代码中得到正确记录或验证,最好两者兼而有之。如果您在backward()开始时打印一个错误,指出某些数组不是二维的,但需要是二维的,这将有助于调试。如果这是一个你应该上交的练习,它也会看起来更令人印象深刻。