我正在开展一个项目,需要在 C 语言的微控制器上实现神经网络,执行时间至关重要。我正在尝试尝试加快代码运行速度的技术,我发现有效的一件事是进行整数数学而不是浮点数学。这在速度方面确实非常有效,但是,当我去检查时,实际的数学是错误的,这是一个问题。
我有 MATLAB 代码,可以使用浮点计算 NN,然后也使用整数数学进行计算。我通过将输入、权重和偏差缩放 2^15 倍来实现此目的,并遵循以下过程:
activation((scaled_input(1) * scaled_weight(1) / scale_factor) + scaled_bias(1));
在比较浮点与整数数学时,这会产生相当好的精度。我还在 C 中实现了这个函数(对于浮点数),如下所示(权重和偏差已预先缩放):
void NN_relu(int r1, int c1, float m1[][c1], int r2, int c2, const float m2[][c2], const float bias[][c2], float result[][c2]) {
for (int i = 0; i < r1; i++) {
for (int j = 0; j < c2; j++) {
result[i][j] = bias[i][j];
for (int k = 0; k < c1; k++) {
result[i][j] += m1[i][k] * m2[k][j];
}
result[i][j] = (result[i][j] > 0) ? result[i][j] : 0;
}
}
}
这也很有效,获得与 MATLAB 完全相同的值。
但是,我的缩放实现不起作用。第一层的输出如下:
69259 0 0 24448 34106 64463 71738 0 55807 0
与 MATLAB 的正确缩放输出相反:
69259 0 0 24448 0 0 596026 0 0 0
这是有问题的缩放 C 函数
void NN_relu(int r1, int c1, int m1[][c1], int r2, int c2, const int m2[][c2], const int bias[][c2], int result[][c2]) {
for (int i = 0; i < r1; i++) {
for (int j = 0; j < c2; j++) {
result[i][j] = bias[i][j];
for (int k = 0; k < c1; k++) {
result[i][j] += (m1[i][k] * m2[k][j]) / S; // S = 2^15
}
result[i][j] = (result[i][j] > 0) ? result[i][j] : 0;
}
}
}
我已经推断出结果中某些但并非所有元素的错误原因是由于溢出,我已经通过使用 print 语句进行调试来确认这一点,即我得到的正确值:
m1[0][0]: 32768, m2[0][0]: 4752, product: 155713536, scaled_product: 4752
对于不正确的,我得到:
m1[0][0]: 32768, m2[0][4]: 101526, product: -968163328, scaled_product: -29546
如您所见,缩放后的乘积应该再次为 101526,但这并不是因为整数溢出。
我尝试过使用 2^31 或其他缩放因子,但我的 MATLAB 代码给出的精度非常糟糕,比如相差几个数量级。
有没有其他方法可以处理这个问题?或者我应该尝试另一种方法来加快代码执行速度?
您描述的使用整数执行非整数算术的特定方法称为定点算术。
在考虑或使用定点运算时,首要的基本问题是您需要支持什么范围的值(最大值和最小值)以及您需要什么精度(小数位数)。 这些首先是您的数据和预期计算结果(包括中间结果)的问题,而不是算法细节的问题。 您的比例因子 215 对应于分数的 15 位二进制数字,或大约 4-5 位十进制数字。 如果您需要比这更精确,那么您需要更大的比例,但这会减少您可以用相同的总位数表示的值的数字范围。 如果您可以承受不太精确的情况,那么较小的比例因子将为您提供更大的范围。
假设 32 位
int
(看起来是你的),指定小数部分为 15 位,与你的数据和所有计算结果很好匹配,那么你的状态就相当不错了。 乘积 mn 中的最大有效位数是 m 和 n 中有效位数的总和,因此为了避免定点乘法中的整数溢出,只需将操作数强制为64 位,执行 64 位乘法,然后按比例缩小结果。 例如:
result[i][j] += ((long long int) m1[i][k] * m2[k][j]) / S;
最终转换回类型
int
永远不会溢出,这是所选定点表示适合您的数据的方面之一。 也就是说,您需要在选择底层整数类型和小数位数时考虑到这一点。
补充说明:
将基础整数类型转换为更宽的整数类型,而不是仅执行与更宽类型的乘法,只会把问题踢下去。 在您的特定计算中,它可能对您有用,但仅限于您的特定数字不会触发溢出。 它不会消除溢出的风险。
那可能没问题。 但是,您可能会发现较窄的类型速度更快,因为它使用一半的内存,并且将数据从内存移动到 CPU 并返回的成本很高,值得关注。 如果您的矩阵很大,并且如果
int
足够了,但对于您询问的一个定点乘法,那么请考虑坚持使用 int
作为基础类型,并仅扩大整数乘法的瞬态中间结果。
我使用
long long int
代替long int
,因为前者保证至少有64位,但后者不是。 在这种情况下,如果大小很重要,您应该考虑使用 stdint.h
中的显式宽度整数类型。 特别是,int32_t
和/或int64_t
。