(a * b) / c MulDiv 并处理中间乘法的溢出

问题描述 投票:0回答:6

我需要做以下算术:

long a,b,c;
long result = a*b/c;

虽然结果保证适合

long
,但乘法却不然,因此它可能会溢出。

我试图一步一步地做到这一点(先乘法然后除法),同时通过将

a*b
的中间结果拆分为最大大小为 4 的 int 数组来处理溢出(就像 BigInteger 使用其
int[] mag
一样)变量)。

在这里我陷入了分裂。我无法理解进行精确除法所需的按位移位。我只需要商(不需要余数)。

假设的方法是:

public static long divide(int[] dividend, long divisor)

另外,我不考虑使用

BigInteger
,因为这部分代码需要快速(我想坚持使用基元和基元数组)。

任何帮助将不胜感激!

编辑: 我并不是想自己实现整个

BigInteger
。我想做的是比使用通用
a*b/c
更快地解决特定问题(
a*b
,其中
BigInteger
可能会溢出)。

Edit2:如果能以一种聪明的方式完成,完全不溢出,那就太理想了,评论中出现了一些提示,但我仍在寻找正确的提示。

更新: 我尝试将 BigInteger 代码移植到我的特定需求,而不创建对象,并且在第一次迭代中,与使用 BigInteger(在我的开发电脑上)相比,我的速度提高了约 46%。

然后我尝试了一些修改过的@David Eisenstat解决方案,与BigInteger(即〜18%)相比,这给了我〜56%(我从

Long.MIN_VALUE
Long.MAX_VALUE
运行了100_000_000_000个随机输入)减少了运行时间(超过2倍)与我改编的 BigInteger 算法相比)。

还会有更多的优化和测试迭代,但在这一点上,我想我必须接受这个答案是最好的。

java algorithm long-integer division
6个回答
3
投票

我一直在修改一种方法,即 (1) 将

a
b
与学校算法在 21 位肢体上相乘 (2) 继续除以
c
,并以不寻常的残差表示
a*b - c*q
,使用
double
存储高位,使用
long
存储低位。我不知道它是否可以与标准长除法竞争,但是为了您的享受,

public class MulDiv {
  public static void main(String[] args) {
    java.util.Random r = new java.util.Random();
    for (long i = 0; true; i++) {
      if (i % 1000000 == 0) {
        System.err.println(i);
      }
      long a = r.nextLong() >> (r.nextInt(8) * 8);
      long b = r.nextLong() >> (r.nextInt(8) * 8);
      long c = r.nextLong() >> (r.nextInt(8) * 8);
      if (c == 0) {
        continue;
      }
      long x = mulDiv(a, b, c);
      java.math.BigInteger aa = java.math.BigInteger.valueOf(a);
      java.math.BigInteger bb = java.math.BigInteger.valueOf(b);
      java.math.BigInteger cc = java.math.BigInteger.valueOf(c);
      java.math.BigInteger xx = aa.multiply(bb).divide(cc);
      if (java.math.BigInteger.valueOf(xx.longValue()).equals(xx) && x != xx.longValue()) {
        System.out.printf("a=%d b=%d c=%d: %d != %s\n", a, b, c, x, xx);
      }
    }
  }

  // Returns truncate(a b/c), subject to the precondition that the result is
  // defined and can be represented as a long.
  private static long mulDiv(long a, long b, long c) {
    // Decompose a.
    long a2 = a >> 42;
    long a10 = a - (a2 << 42);
    long a1 = a10 >> 21;
    long a0 = a10 - (a1 << 21);
    assert a == (((a2 << 21) + a1) << 21) + a0;
    // Decompose b.
    long b2 = b >> 42;
    long b10 = b - (b2 << 42);
    long b1 = b10 >> 21;
    long b0 = b10 - (b1 << 21);
    assert b == (((b2 << 21) + b1) << 21) + b0;
    // Compute a b.
    long ab4 = a2 * b2;
    long ab3 = a2 * b1 + a1 * b2;
    long ab2 = a2 * b0 + a1 * b1 + a0 * b2;
    long ab1 = a1 * b0 + a0 * b1;
    long ab0 = a0 * b0;
    // Compute a b/c.
    DivBy d = new DivBy(c);
    d.shift21Add(ab4);
    d.shift21Add(ab3);
    d.shift21Add(ab2);
    d.shift21Add(ab1);
    d.shift21Add(ab0);
    return d.getQuotient();
  }
}

public strictfp class DivBy {
  // Initializes n <- 0.
  public DivBy(long d) {
    di = d;
    df = (double) d;
    oneOverD = 1.0 / df;
  }

  // Updates n <- 2^21 n + i. Assumes |i| <= 3 (2^42).
  public void shift21Add(long i) {
    // Update the quotient and remainder.
    q <<= 21;
    ri = (ri << 21) + i;
    rf = rf * (double) (1 << 21) + (double) i;
    reduce();
  }

  // Returns truncate(n/d).
  public long getQuotient() {
    while (rf != (double) ri) {
      reduce();
    }
    // Round toward zero.
    if (q > 0) {
      if ((di > 0 && ri < 0) || (di < 0 && ri > 0)) {
        return q - 1;
      }
    } else if (q < 0) {
      if ((di > 0 && ri > 0) || (di < 0 && ri < 0)) {
        return q + 1;
      }
    }
    return q;
  }

  private void reduce() {
    // x is approximately r/d.
    long x = Math.round(rf * oneOverD);
    q += x;
    ri -= di * x;
    rf = repairLowOrderBits(rf - df * (double) x, ri);
  }

  private static double repairLowOrderBits(double f, long i) {
    int e = Math.getExponent(f);
    if (e < 64) {
      return (double) i;
    }
    long rawBits = Double.doubleToRawLongBits(f);
    long lowOrderBits = (rawBits >> 63) ^ (rawBits << (e - 52));
    return f + (double) (i - lowOrderBits);
  }

  private final long di;
  private final double df;
  private final double oneOverD;
  private long q = 0;
  private long ri = 0;
  private double rf = 0;
}

1
投票

您可以使用最大公约数(gcd)来帮助。

a * b / c = (a / gcd(a,c)) * (b / (c / gcd(a,c)))

编辑:OP要求我解释上面的方程。基本上,我们有:

a = (a / gcd(a,c)) * gcd(a,c)
c = (c / gcd(a,c)) * gcd(a,c)

Let's say x=gcd(a,c) for brevity, and rewrite this.

a*b/c = (a/x) * x * b 
        --------------
        (c/x) * x

Next, we cancel

a*b/c = (a/x) * b 
        ----------
        (c/x) 

您可以更进一步。设 y = gcd(b, c/x)

a*b/c = (a/x) * (b/y) * y 
        ------------------
        ((c/x)/y) * y 

a*b/c = (a/x) * (b/y) 
        ------------
           (c/(xy))

这是获取 gcd 的代码。

static long gcd(long a, long b) 
{ 
  if (b == 0) 
    return a; 
  return gcd(b, a % b);  
} 

1
投票

David Eisenstat 让我思考更多。
我希望简单的情况能够快速完成:让

double
来处理。 Newton-Raphson 对于其他人来说可能是更好的选择。

 /** Multiplies both <code>factor</code>s
  *  and divides by <code>divisor</code>.
  * @return <code>Long.MIN_VALUE</code> if result out of range,<br/>
  *     else <code>factorA * factor1 / divisor</code> */
    public static long
    mulDiv(long factorA, long factor1, long divisor) {
        final double dd = divisor,
            product = (double)factorA * factor1,
            a1_d = product / dd;
        if (a1_d < -TOO_LARGE || TOO_LARGE < a1_d)
            return tooLarge();
        if (-ONE_ < a1_d && a1_d < ONE_)
            return 0;
        if (-EXACT < product && product < EXACT)
            return (long) a1_d;
        long pLo = factorA * factor1, //diff,
            pHi = high64(factorA, factor1);
        if (a1_d < -LONG_MAX_ || LONG_MAX_ < a1_d) {
            long maxdHi = divisor >> 1;
            if (maxdHi < pHi
                || maxdHi == pHi
                   && Long.compareUnsigned((divisor << Long.SIZE-1),
                                           pLo) <= 0)
                return tooLarge();
        }
        final double high_dd = TWO_POWER64/dd;
        long quotient = (long) a1_d,
            loPP = quotient * divisor,
            hiPP = high64(quotient, divisor);
        long remHi = pHi - hiPP, // xxx overflow/carry
            remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        double fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        long //fHi = (long)fudge/TWO_POWER64,
            fLo = (long) Math.floor(fudge); //*round
        quotient += fLo;
        loPP = quotient * divisor;
        hiPP = high64(quotient, divisor);
        remHi = pHi - hiPP; // should be 0?!
        remLo = pLo - loPP;
        if (Long.compareUnsigned(pLo, remLo) < 0)
            remHi -= 1;
        if (0 == remHi && 0 <= remLo && remLo < divisor)
            return quotient;

        fudge = remHi * high_dd;
        if (remLo < 0)
            fudge += high_dd;
        fudge += remLo/dd;
        fLo = (long) Math.floor(fudge);
        return quotient + fLo;
    }

 /** max <code>double</code> trusted to represent
  *  a value in the range of <code>long</code> */
    static final double
        LONG_MAX_ = Double.valueOf(Long.MAX_VALUE - 0xFFF);
 /** max <code>double</code> trusted to represent a value below 1 */
    static final double
        ONE_ = Double.longBitsToDouble(
                    Double.doubleToRawLongBits(1) - 4);
 /** max <code>double</code> trusted to represent a value exactly */
    static final double
        EXACT = Long.MAX_VALUE >> 12;
    static final double
        TWO_POWER64 = Double.valueOf(1L<<32)*Double.valueOf(1L<<32);

    static long tooLarge() {
//      throw new RuntimeException("result too large for long");
        return Long.MIN_VALUE;
    }
    static final long   ONES_32 = ~(~0L << 32);

    static long high64(long factorA, long factor1) {
        long loA = factorA & ONES_32,
            hiA = factorA >>> 32,
            lo1 = factor1 & ONES_32,
            hi1 = factor1 >>> 32;
        return ((loA * lo1 >>> 32)
                +loA * hi1 + hiA * lo1 >>> 32)
               + hiA * hi1;
    }

(我在 IDE 之外重新排列了这段代码,使

mulDiv()
位于顶部。 由于懒惰,我有一个用于标志处理的包装器 - 可能会在地狱结冰之前尝试正确地完成它。
对于计时,迫切需要输入模型:
让每个可能的结果都有相同的可能性怎么样?)


0
投票

也许不聪明,但结果时间是线性的

#define MUL_DIV_TYPE    unsigned int
#define BITS_PER_TYPE   (sizeof(MUL_DIV_TYPE)*8)
#define TOP_BIT_TYPE    (1<<(BITS_PER_TYPE-1))

//
//    result = ( a * b ) / c, without intermediate overflow.
//
MUL_DIV_TYPE mul_div( MUL_DIV_TYPE a, MUL_DIV_TYPE b, MUL_DIV_TYPE c ) {
    MUL_DIV_TYPE    st, sb;     // product sum top and bottom

    MUL_DIV_TYPE    d, e;       // division result

    MUL_DIV_TYPE    i,      // bit counter
            j;      // overflow check

    st = 0;
    sb = 0;

    d = 0;
    e = 0;

    for( i = 0; i < BITS_PER_TYPE; i++ ) {
        //
        //  Shift sum left to make space
        //  for next partial sum
        //
        st <<= 1;
        if( sb & TOP_BIT_TYPE ) st |= 1;
        sb <<= 1;
        //
        //  Add a to s if top bit on b
        //  is set.
        //
        if( b & TOP_BIT_TYPE ) {
            j = sb;
            sb += a;
            if( sb < j ) st++;
        }
        //
        //  Division.
        //
        d <<= 1;
        if( st >= c ) {
            d |= 1;
            st -= c;
            e++;
        }
        else {
            if( e ) e++;
        }
        //
        //  Shift b up by one bit.
        //
        b <<= 1;
    }
    //
    //  Roll in missing bits.
    //
    for( i = e; i < BITS_PER_TYPE; i++ ) {
        //
        //  Shift across product sum
        //
        st <<= 1;
        if( sb & TOP_BIT_TYPE ) st |= 1;
        sb <<= 1;
        //
        //  Division, continued.
        //
        d <<= 1;
        if( st >= c ) {
            d |= 1;
            st -= c;
        }
    }
    return( d );  // remainder should be in st
}

0
投票

我创建了以下看起来相当不错的算法:

  • 避免使用缓慢的
    BigInteger
    数学
  • 仅对
    long
    值使用快速运算
    /**
     * Long computation of (a*b)/c. Arguments and result must be in range 0..Long.MAX_VALUE
     * 
     * @param a First multiplicand
     * @param b Second multiplicand
     * @param c Divisor
     * @return Result of (a*b)/c, or -1 if result overflows long
     */
    static long fastMulDiv(long a, long b, long c) {
        // 128 bit multiply
        long m0=a*b;
        long m1=Math.multiplyHigh(a, b);
        long acc=0;
        
        // we are going to do base 2^63 long division :-)
        // a * b = ab1 * 2^63 + ab0 
        long ab1 = (m1<<1)|(m0>>>63);
        if (c<=ab1) return -1;   // we know this will overflow
        if (ab1==0) return m0/c; // fast path
        long ab0 = (m0&~0x8000000000000000L);

        // d = 2^63 / c
        long dq=-(0x8000000000000000L/c); // note we need to reverse sign
        long dr=Long.remainderUnsigned(0x8000000000000000L, c);  // dr < c

        while (ab1>0) {         
            // a * b = c*(ab1*dq) + (ab1*dr + ab0)
            acc+=ab1*dq;
            
            // so we just need to divide (ab1*dr + ab0) by c to get the rest
            m0=ab1*dr;
            m1=Math.multiplyHigh(ab1, dr);
            ab1 = (m1<<1)|(m0>>>63);
            ab0 = (m0&~0x8000000000000000L)+ab0;  // note we add in the previous ab0
            if (ab0<0) {
                // overflow carry
                ab0=(ab0&~0x8000000000000000L);
                ab1++;
            }
        }
        long result = acc + ab0/c;
        return result;
    }

-1
投票

将 a/c 和 b/c 分成整数部分和分数(余数)部分,然后你有:

a*b/c 
= c * a/c * b/c 
= c * (x/c + y/c) * (z/c + w/c)
= xz/c + xw/c + yz/c + yw/c where x and z are multiples of c

因此,您可以轻松计算前三个因子而不会溢出。根据我的经验,这通常足以涵盖典型的溢出情况。然而,如果你的 divisor 太大,导致

(a % c) * (b % c)
溢出,这个方法仍然会失败。如果这对您来说是一个典型问题,您可能需要考虑其他方法(例如,将 a 和 b 以及 c 中的最大值除以 2,直到不再发生溢出,但是如何做到这一点而不引入额外的错误,因为过程中的偏差并非微不足道——您可能需要在单独的变量中保留错误的运行分数,可能)

无论如何,上面的代码:

long a,b,c;
long bMod = (b % c)
long result = a * (b / c) + (a / c) * bMod + ((a % c) * bMod) / c;

如果速度是一个大问题(我假设它至少在某种程度上是一个大问题,因为你问这个),你可能需要考虑将

a/c
b/c
存储在变量中并通过乘法计算 mod,例如将
(a % c)
替换为
(a - aDiv * c)
——这样您可以将每次调用的 4 个分区变为 2 个分区。

© www.soinside.com 2019 - 2024. All rights reserved.