快速代码确定任意两个列子集是否具有相同的总和

问题描述 投票:0回答:2

对于给定的 n 和 m,我迭代所有 n × m 个部分循环矩阵,其条目为 0 或 1。我想查找是否存在一个矩阵,使得不存在给出相同的金额。在这里,当我们添加列时,我们只是按元素进行操作。我当前的代码通过 ortools 使用约束编程。然而它并没有我想要的那么快。对于 n = 7 和 m = 12,需要超过 3 分钟;对于 n = 10、m = 18,即使只有 2^18 = 262144 个不同的矩阵需要考虑,它也不会终止。这是我的代码。

from scipy.linalg import circulant
import numpy as np
import itertools
from ortools.constraint_solver import pywrapcp as cs

n = 7
m = 12

def isdetecting(matrix):
    X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
    X1 = X.tolist()
    for row in matrix:
        x = X[row].tolist()
        solver.Add(solver.Sum(x) == 0)
    db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
    solver.NewSearch(db)
    count = 0
    while (solver.NextSolution() and count < 2):
        solution = [x.Value() for x in X1]
        count += 1
    solver.EndSearch()
    if (count < 2):
        return True

values = [-1,0,1]
solver = cs.Solver("scip")

for row in itertools.product([0,1],repeat = m):
    M = np.array(circulant(row)[0:n], dtype=bool)
    if isdetecting(M):
        print M.astype(int)
        break

这个问题能否足够快地解决,以便能够解决 n = 10,m = 18?

algorithm performance math or-tools constraint-programming
2个回答
2
投票

一个问题是您在全局声明“solver”变量,这似乎会混淆或工具多次重用它。当将其移入“isDetecting”时,(7,12) 问题的解决速度要快得多,大约需要 7 秒(而原始模型需要 2 分 51 分钟)。不过,我还没有检查过更大的问题。

此外,测试不同的标签(而不是solver.INT_VAR_DEFAULT和solver.INT_VALUE_DEFAULT)可能是个好主意,尽管二进制值往往对不同的标签不太敏感。请参阅另一个标签的代码。

def isdetecting(matrix):
   solver = cs.Solver("scip") # <----
   X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
   X1 = X.tolist()
   for row in matrix:
       x = X[row].tolist()
       solver.Add(solver.Sum(x) == 0)
   # db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
   db = solver.Phase(X1, solver.CHOOSE_FIRST_UNBOUND, solver.ASSIGN_CENTER_VALUE)    
   solver.NewSearch(db)
   count = 0
   while (solver.NextSolution() and count < 2):
       solution = [x.Value() for x in X1]
       count += 1
   solver.EndSearch()
   if (count < 2):
       print "FOUND"
       return True

编辑:以下是删除评论中提到的全 0 解决方案的约束。据我所知,它需要一个单独的列表。现在需要更长的时间(10.4 秒 vs 7 秒)。

X1Abs = [solver.IntVar(values, 'X1Abs[%i]' % i) for i in range(X1_len)]
for i in range(X1_len):
    solver.Add(X1Abs[i] == abs(X1[i])) 
solver.Add(solver.Sum(X1Abs) > 0)       

1
投票

我就是这么想的。我估计命令行参数 10 18 在我的机器上的运行时间不到 8 小时。

public class Search {
    public static void main(String[] args) {
        int n = Integer.parseInt(args[0]);
        int m = Integer.parseInt(args[1]);
        int row = search(n, m);
        if (row >= 0) {
            printRow(m, row);
        }
    }

    private static int search(int n, int m) {
        if (n < 0 || m < n || m >= 31 || powOverflows(m + 1, n)) {
            throw new IllegalArgumentException();
        }
        long[] column = new long[m];
        long[] sums = new long[1 << m];
        int row = 1 << m;
        while (row-- > 0) {
            System.err.println(row);
            for (int j = 0; j < m; j++) {
                column[j] = 0;
                for (int i = 0; i < n; i++) {
                    column[j] = (column[j] * (m + 1)) + ((row >> ((i + j) % m)) & 1);
                }
            }
            for (int subset = 0; subset < (1 << m); subset++) {
                long sum = 0;
                for (int j = 0; j < m; j++) {
                    if (((subset >> j) & 1) == 1) {
                        sum += column[j];
                    }
                }
                sums[subset] = sum;
            }
            java.util.Arrays.sort(sums);
            boolean duplicate = false;
            for (int k = 1; k < (1 << m); k++) {
                if (sums[k - 1] == sums[k]) {
                    duplicate = true;
                    break;
                }
            }
            if (!duplicate) {
                break;
            }
        }
        return row;
    }

    private static boolean powOverflows(long b, int e) {
        if (b <= 0 || e < 0) {
            throw new IllegalArgumentException();
        }
        if (e == 0) {
            return false;
        }
        long max = Long.MAX_VALUE;
        while (e > 1) {
            if (b > Integer.MAX_VALUE) {
                return true;
            }
            if ((e & 1) == 1) {
                max /= b;
            }
            b *= b;
            e >>= 1;
        }
        return b > max;
    }

    private static void printRow(int m, int row) {
        for (int j = 0; j < m; j++) {
            System.out.print((row >> j) & 1);
        }
        System.out.println();
    }
}
© www.soinside.com 2019 - 2024. All rights reserved.