如果整数值能被尾数精确表示,是否可以安全地假设 32 位浮点数可以直接用 == 相互比较?

问题描述 投票:0回答:1

在 LeetCode 的“完全平方数”(Perfect Squares)问题中(判断一个整数最少能表示为几个完全平方数之和),使用浮点数而不是整数能带来更大的加速。如果整数值保证小于 10000 且大于或等于 0,是否可以安全地假设我们能用浮点数来代替整数?

比较示例:

if(n == i*i + j*j * 2)
    result3++;

if(n == i*i + k*k)
    result2++;

`int` 和 `float` 两种版本(n、i、j、k 全部为 float 或全部为 int)都通过了所有测试,但我仍然不确定不同 CPU 之间(不确定 LeetCode 是否始终使用完全相同的 CPU)、不同编译器或其他因素(比如计时?)是否会导致差异。

问题和代码链接:https://leetcode.com/problems/perfect-squares/description/

代码:

#include <math.h>

#include <cmath>
#include <iostream>
#include <limits>
class Solution {
public:
    static constexpr int simd = 8;
    // Lane element type for the hand-vectorized k loop. `short` keeps every
    // `simd`-wide array narrow (presumably so more lanes fit per vector
    // register), but it only represents values up to SHRT_MAX, so
    // numSquares() checks that bound before using this path.
    using FAST_TYPE = short;
    // Element type of the per-lane 0/1 comparison masks.
    using MASK_TYPE = short;

    /**
     * Least number of perfect squares that sum to `n` (LeetCode 279).
     *
     * Returns 1 if n is a perfect square. Otherwise enumerates
     * i > j > k >= 1 together with the repeated-term forms (2i^2, 3i^2,
     * i^2 + 2j^2, 2i^2 + j^2, 2i^2 + k^2, 2j^2 + k^2, 3k^2). It returns 2
     * if n is a sum of two squares, 3 if it is a sum of three, and
     * otherwise 4.
     *
     * The k loop is written as fixed-width element-wise loops over `simd`
     * lanes (with `#pragma GCC ivdep`), presumably so GCC auto-vectorizes
     * it. The lanes hold FAST_TYPE values. When n is large enough that
     * those values would not fit in FAST_TYPE, the exact scalar path
     * numSquaresWide() is used instead.
     *
     * @param n  Target value. Must be >= 0: std::sqrt of a negative n
     *           yields NaN, and converting NaN to int is undefined.
     * @return   A value in 1..4 (n == 0 yields 1 via the square test).
     */
    const int numSquares(const int n) const noexcept {
        // Small cases answered directly; the general search below would
        // also find them.
        if (n == 2 || n == 8)
            return 2;
        if (n == 3 || n == 6 || n == 11)
            return 3;

        const int root = static_cast<int>(std::sqrt(n));
        if (root * root == n)
            return 1;

        // The largest value the search below stores in a FAST_TYPE is
        // 3*top*top (the i23 term). n and every other partial sum are
        // smaller, because i <= top and j, k < i. A larger value would wrap
        // silently when narrowed, so use the exact scalar path instead.
        const int top = root + 1;
        if (3LL * top * top > std::numeric_limits<FAST_TYPE>::max())
            return numSquaresWide(n);

        // Hit counters: only "non-zero" is consulted at the end.
        FAST_TYPE found2 = 0;   // some two-square form equals n
        FAST_TYPE found3 = 0;   // some three-square form equals n (scalar checks)
        FAST_TYPE found32 = 0;  // 2*i^2 + k^2 == n (scalar tail)
        FAST_TYPE found33 = 0;  // 2*j^2 + k^2 == n (scalar tail)
        FAST_TYPE found34 = 0;  // 3*k^2 == n (scalar tail)

        // Loop-invariant lane constants and latched per-lane results.
        alignas(64) FAST_TYPE oneSimd[simd];
        alignas(64) FAST_TYPE found3Simd[simd];   // i^2 + j^2 + k^2 == n
        alignas(64) FAST_TYPE found3Simd2[simd];  // 2*i^2 + k^2 == n
        alignas(64) FAST_TYPE found3Simd3[simd];  // 2*j^2 + k^2 == n
        alignas(64) FAST_TYPE found3Simd4[simd];  // 3*k^2 == n
        alignas(64) FAST_TYPE mSimd[simd];        // lane offsets 0..simd-1
        alignas(64) FAST_TYPE nSimd[simd];        // n in every lane
        alignas(64) FAST_TYPE threeSimd[simd];    // 3 in every lane
        // Per-iteration scratch.
        alignas(64) FAST_TYPE kSimd[simd];
        alignas(64) FAST_TYPE k0Simd[simd];
        alignas(64) FAST_TYPE ijSimd[simd];       // i^2 + j^2 in every lane
        alignas(64) FAST_TYPE i2Simd[simd];       // 2*i^2 in every lane
        alignas(64) FAST_TYPE j2Simd[simd];       // 2*j^2 in every lane
        alignas(64) MASK_TYPE mask1Simd[simd];
        alignas(64) MASK_TYPE mask2Simd[simd];
        alignas(64) MASK_TYPE mask3Simd[simd];
        alignas(64) MASK_TYPE mask4Simd[simd];
        alignas(64) FAST_TYPE sum1Simd[simd];
        alignas(64) FAST_TYPE sum2Simd[simd];
        alignas(64) FAST_TYPE sum3Simd[simd];
        alignas(64) FAST_TYPE mulSimd[simd];

        for (int m = 0; m < simd; m++) {
            oneSimd[m] = 1;
            found3Simd[m] = 0;
            found3Simd2[m] = 0;
            found3Simd3[m] = 0;
            found3Simd4[m] = 0;
            mSimd[m] = m;
            nSimd[m] = n;
            // Must be 3: the fourth lane test is 3*k^2 == n (was 2 by mistake).
            threeSimd[m] = 3;
        }

        for (int i = top; i >= 1; i--) {
            const FAST_TYPE i2 = i * i;
            const FAST_TYPE i22 = 2 * i * i;
            const FAST_TYPE i23 = 3 * i * i;
            #pragma GCC ivdep
            for (int m = 0; m < simd; m++)
                i2Simd[m] = i22;
            found2 += (i22 == n);  // i^2 + i^2
            found3 += (i23 == n);  // i^2 + i^2 + i^2
            for (int j = i - 1; j >= 1; j--) {
                const FAST_TYPE j2 = j * j;
                const FAST_TYPE j22 = 2 * j * j;
                const FAST_TYPE j23 = 3 * j * j;
                #pragma GCC ivdep
                for (int m = 0; m < simd; m++)
                    j2Simd[m] = j22;
                #pragma GCC ivdep
                for (int m = 0; m < simd; m++)
                    ijSimd[m] = i2 + j2;
                found2 += (i2 + j2 == n);
                found3 += (i2 + j22 == n) + (i22 + j2 == n) + (j23 == n);

                // Split k in [1, j-1] into whole simd-wide chunks covering
                // [1, k32] and a scalar tail covering [k32+1, j-1].
                const int k32 = j - 1 - ((j - 1) % simd);
                #pragma GCC unroll 2
                for (int k0 = 1; k0 <= k32; k0 += simd) {
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        k0Simd[m] = k0;
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        kSimd[m] = k0Simd[m] + mSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        kSimd[m] = kSimd[m] * kSimd[m];

                    // i^2 + j^2 + k^2 == n
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        sum1Simd[m] = ijSimd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        mask1Simd[m] = sum1Simd[m] == nSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        found3Simd[m] = mask1Simd[m] ? oneSimd[m] : found3Simd[m];

                    // 2*i^2 + k^2 == n
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        sum2Simd[m] = i2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        mask2Simd[m] = (sum2Simd[m] == nSimd[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        found3Simd2[m] = mask2Simd[m] ? oneSimd[m] : found3Simd2[m];

                    // 2*j^2 + k^2 == n
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        sum3Simd[m] = j2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        mask3Simd[m] = (sum3Simd[m] == nSimd[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        found3Simd3[m] = mask3Simd[m] ? oneSimd[m] : found3Simd3[m];

                    // 3*k^2 == n
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        mulSimd[m] = threeSimd[m] * kSimd[m];
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        mask4Simd[m] = (mulSimd[m] == nSimd[m]);
                    #pragma GCC ivdep
                    for (int m = 0; m < simd; m++)
                        found3Simd4[m] = mask4Simd[m] ? oneSimd[m] : found3Simd4[m];
                }

                // Scalar tail. It starts after k32: when k32 > 0, k32 is the
                // last lane of the final chunk above. When k32 == 0, k == 0
                // would only re-test sums already counted in found2.
                for (int k = k32 + 1; k <= j - 1; k++) {
                    const FAST_TYPE k2 = k * k;
                    found3 += (i2 + j2 + k2 == n);
                    found32 += (i22 + k2 == n);
                    found33 += (j22 + k2 == n);
                    found34 += (3 * k2 == n);
                }
            }
        }

        // Fold the latched lane results into the scalar counters.
        for (int m = 0; m < simd; m++) {
            found3 += found3Simd[m];
            found32 += found3Simd2[m];
            found33 += found3Simd3[m];
            found34 += found3Simd4[m];
        }
        found3 += found32 + found33 + found34;
        if (found2)
            return 2;
        if (found3)
            return 3;
        // Not a sum of at most three squares; Lagrange's theorem gives 4.
        return 4;
    }

private:
    // Exact perfect-square test for a non-negative int. The floating-point
    // sqrt is only an estimate; it is corrected with integer arithmetic.
    static bool isPerfectSquare(const int v) noexcept {
        long long r = static_cast<long long>(std::sqrt(static_cast<double>(v)));
        while (r * r > v)
            --r;
        while ((r + 1) * (r + 1) <= v)
            ++r;
        return r * r == v;
    }

    // Exact scalar answer for n > 0 whose search values would overflow
    // FAST_TYPE. It tests 1 and 2 directly. Otherwise it applies Legendre's
    // three-square theorem: n needs 4 squares iff n = 4^a * (8b + 7).
    static int numSquaresWide(const int n) noexcept {
        if (isPerfectSquare(n))
            return 1;
        for (int a = 1; a <= n / a; ++a)  // a*a <= n without overflow
            if (isPerfectSquare(n - a * a))
                return 2;
        int m = n;
        while (m % 4 == 0)
            m /= 4;
        return (m % 8 == 7) ? 4 : 3;
    }
};

int main()
{
    Solution s;
    for(int i=10;i<20;i++)
    {
        std::cout<<i<<" is equal to sum of "<<s.numSquares(i)<< " perfect squares"<<std::endl; 
    }
}

输出:

10 is equal to sum of 2 perfect squares
11 is equal to sum of 3 perfect squares
12 is equal to sum of 3 perfect squares
13 is equal to sum of 2 perfect squares
14 is equal to sum of 3 perfect squares
15 is equal to sum of 4 perfect squares
16 is equal to sum of 1 perfect squares
17 is equal to sum of 2 perfect squares
18 is equal to sum of 2 perfect squares
19 is equal to sum of 3 perfect squares
c++ floating-point integer comparison
1个回答
0
投票

如果整数值能被尾数精确表示,是否可以安全地假设 32 位浮点数可以直接用 == 相互比较?

确定是否安全的一种方法是:将结果与一个直观的、不追求性能的参考实现进行比较,并查找差异。

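如果所有结果都一致,那么至少在这一台机器和这一个编译器上,这种比较是安全的。(从原理上看:若 float 为 IEEE 754 单精度,它能精确表示绝对值不超过 2^24 = 16777216 的所有整数;当加、减、乘的精确结果仍在此范围内时,运算结果也是精确的。本题的中间值远小于 2^24,因此用 == 比较不会引入舍入误差。)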

当 `int` 为 16 位时,OP 的代码可能会出现问题。

下面是一些测试代码来比较。

/*
 * Brute-force check whether `n` can be written as a*a + b*b with
 * 0 <= a <= b.
 *
 * On success, stores the pair with the smallest `a` through `a_ptr` and
 * `b_ptr` (for that `a`, the smallest `b`) and returns true.
 * Negative `n`, or `n` with no such pair, returns false without writing
 * through the pointers.
 */
bool IsSumOf2Squares(int n, int *a_ptr, int *b_ptr) {
  if (n >= 0) {
    for (int a = 0; n - a * a >= 0; ++a) {
      const int remaining = n - a * a;
      for (int b = a; remaining - b * b >= 0; ++b) {
        if (remaining == b * b) {
          *a_ptr = a;
          *b_ptr = b;
          return true;
        }
      }
    }
  }
  return false;
}

int main(void) {
  // Benchmark: count the n in [-42, 10000] that are sums of two squares,
  // echoing every 300th hit, then report the total and CPU time.
  const clock_t start = clock();
  const int limit = 10000;
  int found = 0;
  for (int value = -42; value <= limit; ++value) {
    int x;
    int y;
    if (!IsSumOf2Squares(value, &x, &y)) {
      continue;
    }
    if (found % 300 == 0) {
      printf("%10d: %5d %5d\n", value, x, y);
      fflush(stdout);
    }
    ++found;
  }
  const clock_t stop = clock();
  printf("Count: %d, Time:%gs\n", found, (double) (stop - start) / CLOCKS_PER_SEC);
  return 0;
}

输出

         0:     0     0
       901:     1    30
      1933:    13    42
      3001:    20    51
      4105:     3    64
      5213:    37    62
      6354:    27    75
      7489:    33    80
      8656:    40    84
      9808:    68    72
Count: 2750, Time:0.015s
© www.soinside.com 2019 - 2024. All rights reserved.