在关于查找整数是否是完全平方和的 leetcode 问题中,使用浮点数而不是整数会带来更多的加速(“完全平方”问题)。是否可以安全地假设,如果整数值保证小于 10000 且大于或等于 0,我们可以使用浮点数来代替?
示例对比:
if(n == i*i + j*j * 2)
result3++;
if(n == i*i + k*k)
result2++;
int 版本和 float 版本都通过了所有测试(n、i、j、k 全部为 float 或全部为 int),但我仍然不确定在不同的 CPU(不确定 leetcode 是否始终使用完全相同的 CPU)、不同的编译器或其他因素(比如运行的时间点?)之间,结果是否会有差异。
问题和代码链接:https://leetcode.com/problems/perfect-squares/description/
代码:
#include<iostream>
#include<math.h>
/// Brute-force solver for the "perfect squares" problem: the least number of
/// positive perfect squares whose sum is `n`.
///
/// Every square and partial sum is stored in `FAST_TYPE`, so the search is
/// exact only while 3 * (1 + floor(sqrt(n)))^2 fits in that type. With a
/// 16-bit `short` this covers 0 <= n <= 10000 (3 * 101 * 101 = 30603).
/// Larger inputs silently truncate and are outside the supported range.
class Solution {
public:
    /// Number of consecutive `k` candidates tested per inner-loop step.
    static constexpr int simd = 8;
    /// Narrow element type of the lane buffers (squares and sums).
    using FAST_TYPE = short;
    /// Element type of the per-lane comparison results.
    using MASK_TYPE = short;

    /// Returns 1, 2, 3 or 4: the minimal count of perfect squares summing to `n`.
    ///
    /// @param n value to decompose; intended range is 0 <= n <= 10000.
    /// @return 1 if `n` is itself a perfect square, 2 or 3 if a two- or
    ///         three-square decomposition exists, otherwise 4 (four squares
    ///         always suffice by Lagrange's four-square theorem).
    int numSquares(const int n) const noexcept {
        // Hard-coded answers for a handful of tiny inputs.
        if (n == 2 || n == 8)
            return 2;
        if (n == 3 || n == 6 || n == 11)
            return 3;

        // Integer square root, computed once and reused as the loop bound.
        const int root = static_cast<int>(std::sqrt(n));
        if (root * root == n)
            return 1;

        bool twoSquares = false;
        bool threeSquares = false;

        // Lane buffers for testing `simd` consecutive k values per step.
        alignas(64) FAST_TYPE laneOffset[simd];     // 0, 1, ..., simd-1
        alignas(64) FAST_TYPE targetLanes[simd];    // n in every lane
        alignas(64) FAST_TYPE pairSumLanes[simd];   // i*i + j*j in every lane
        alignas(64) FAST_TYPE kSquaredLanes[simd];  // k, then k*k, per lane
        alignas(64) FAST_TYPE tripleSumLanes[simd]; // i*i + j*j + k*k per lane
        alignas(64) MASK_TYPE hitLanes[simd];       // 1 where the sum equals n
        alignas(64) FAST_TYPE foundLanes[simd];     // sticky per-lane "hit seen"
        for (int m = 0; m < simd; m++) {
            laneOffset[m] = static_cast<FAST_TYPE>(m);
            targetLanes[m] = static_cast<FAST_TYPE>(n);
            foundLanes[m] = 0;
        }

        // Enumerate i > j > k >= 1. Decompositions with repeated terms
        // (2*i*i, 3*i*i, i*i + 2*j*j, 2*i*i + j*j) are handled by the scalar
        // checks, so the k loop only needs three distinct squares.
        for (int i = 1 + root; i >= 1; i--) {
            const FAST_TYPE i2 = static_cast<FAST_TYPE>(i * i);
            const FAST_TYPE i22 = static_cast<FAST_TYPE>(2 * i * i);
            const FAST_TYPE i23 = static_cast<FAST_TYPE>(3 * i * i);
            twoSquares |= (i22 == n);
            threeSquares |= (i23 == n);

            for (int j = i - 1; j >= 1; j--) {
                const FAST_TYPE j2 = static_cast<FAST_TYPE>(j * j);
                const FAST_TYPE j22 = static_cast<FAST_TYPE>(2 * j * j);
                const FAST_TYPE pairSum = static_cast<FAST_TYPE>(i2 + j2);
                twoSquares |= (pairSum == n);
                threeSquares |= (i2 + j22 == n) || (i22 + j2 == n);

                #pragma GCC ivdep
                for (int m = 0; m < simd; m++)
                    pairSumLanes[m] = pairSum;

                // Largest multiple of `simd` not exceeding j-1: k = 1..lastFullK
                // is processed in full lane groups, the rest in the scalar tail.
                const int lastFullK = (j - 1) - ((j - 1) % simd);
                #pragma GCC unroll 2
                for (int k0 = 1; k0 <= lastFullK; k0 += simd) {
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        kSquaredLanes[m] = static_cast<FAST_TYPE>(k0 + laneOffset[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        kSquaredLanes[m] = static_cast<FAST_TYPE>(kSquaredLanes[m] * kSquaredLanes[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        tripleSumLanes[m] = static_cast<FAST_TYPE>(pairSumLanes[m] + kSquaredLanes[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        hitLanes[m] = (tripleSumLanes[m] == targetLanes[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        foundLanes[m] = hitLanes[m] ? FAST_TYPE{1} : foundLanes[m];
                }

                // Scalar tail: the k values that did not fill a whole lane group.
                for (int k = lastFullK + 1; k <= j - 1; k++) {
                    const FAST_TYPE k2 = static_cast<FAST_TYPE>(k * k);
                    threeSquares |= (i2 + j2 + k2 == n);
                }
            }
        }

        // Fold the sticky per-lane flags into the scalar result.
        for (int m = 0; m < simd; m++)
            threeSquares |= (foundLanes[m] != 0);

        // A two-square hit wins over any three-square hit found along the way.
        if (twoSquares)
            return 2;
        if (threeSquares)
            return 3;
        return 4;
    }
};
int main()
{
Solution s;
for(int i=10;i<20;i++)
{
std::cout<<i<<" is equal to sum of "<<s.numSquares(i)<< " perfect squares"<<std::endl;
}
}
输出:
10 is equal to sum of 2 perfect squares
11 is equal to sum of 3 perfect squares
12 is equal to sum of 3 perfect squares
13 is equal to sum of 2 perfect squares
14 is equal to sum of 3 perfect squares
15 is equal to sum of 4 perfect squares
16 is equal to sum of 1 perfect squares
17 is equal to sum of 2 perfect squares
18 is equal to sum of 2 perfect squares
19 is equal to sum of 3 perfect squares
如果值适合尾数,可以安全地假设 32 位浮点数可以直接相互比较吗?
判断这种做法是否不安全的一种方法,是把它的结果与一份直接的、未经性能优化的参考代码的结果进行比较,并查找差异。
如果两者的结果完全一致,那么至少在这一台机器和这个编译器上,这种比较是安全的。
当 int 为 16 位时,OP 的代码可能会出现问题。
下面是一些测试代码来比较。
/*
 * Return true if `n` is the sum of 2 perfect squares.
 *
 * Pairs 0 <= a <= b are searched in increasing order of `a`, then `b`. On the
 * first pair with a*a + b*b == n the pair is stored through `a_ptr`/`b_ptr`
 * and true is returned. Either pointer may be null to skip that store.
 * Nothing is written when the result is false (including any negative `n`).
 *
 * Squares are computed in `long long`: each loop probes one step past the
 * integer square root, and that square does not fit in a 32-bit `int` once
 * `n` exceeds 46340 * 46340.
 */
bool IsSumOf2Squares(int n, int *a_ptr, int *b_ptr) {
  if (n < 0) {
    return false;
  }
  for (long long a = 0; a * a <= n; a++) {
    /* What the second square must contribute for this choice of `a`. */
    const long long b_target = n - a * a;
    for (long long b = a; b * b <= b_target; b++) {
      if (b * b == b_target) {
        if (a_ptr) {
          *a_ptr = (int) a;
        }
        if (b_ptr) {
          *b_ptr = (int) b;
        }
        return true;
      }
    }
  }
  return false;
}
/*
 * Reference driver: runs IsSumOf2Squares on every integer from -42 to 10000,
 * prints every 300th hit (starting with the first) together with its pair,
 * then prints the total hit count and the elapsed processor time.
 */
int main(void) {
  const int limit = 10000;
  int hits = 0;
  const clock_t started = clock();
  for (int candidate = -42; candidate <= limit; candidate++) {
    int first, second;
    if (!IsSumOf2Squares(candidate, &first, &second)) {
      continue;
    }
    /* Sample the output so only a handful of lines are printed. */
    if (hits % 300 == 0) {
      printf("%10d: %5d %5d\n", candidate, first, second);
      fflush(stdout);
    }
    hits++;
  }
  const clock_t finished = clock();
  const double seconds = (double) (finished - started) / CLOCKS_PER_SEC;
  printf("Count: %d, Time:%gs\n", hits, seconds);
  return 0;
}
输出
0: 0 0
901: 1 30
1933: 13 42
3001: 20 51
4105: 3 64
5213: 37 62
6354: 27 75
7489: 33 80
8656: 40 84
9808: 68 72
Count: 2750, Time:0.015s