对于给定的 n 和 m,我迭代所有 n × m 个部分循环矩阵,其条目为 0 或 1。我想查找是否存在一个矩阵,使得不存在给出相同的金额。在这里,当我们添加列时,我们只是按元素进行操作。我当前的代码通过 ortools 使用约束编程。然而它并没有我想要的那么快。对于 n = 7 和 m = 12,需要超过 3 分钟;对于 n = 10、m = 18,即使只有 2^18 = 262144 个不同的矩阵需要考虑,它也不会终止。这是我的代码。
from scipy.linalg import circulant
import numpy as np
import itertools
from ortools.constraint_solver import pywrapcp as cs
n = 7
m = 12
def isdetecting(matrix):
X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
X1 = X.tolist()
for row in matrix:
x = X[row].tolist()
solver.Add(solver.Sum(x) == 0)
db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
solver.NewSearch(db)
count = 0
while (solver.NextSolution() and count < 2):
solution = [x.Value() for x in X1]
count += 1
solver.EndSearch()
if (count < 2):
return True
values = [-1,0,1]
solver = cs.Solver("scip")
for row in itertools.product([0,1],repeat = m):
M = np.array(circulant(row)[0:n], dtype=bool)
if isdetecting(M):
print M.astype(int)
break
这个问题能否足够快地解决,以便能够解决 n = 10,m = 18?
一个问题是您在全局声明“solver”变量,这似乎会混淆或工具多次重用它。当将其移入“isDetecting”时,(7,12) 问题的解决速度要快得多,大约需要 7 秒(而原始模型需要 2 分 51 分钟)。不过,我还没有检查过更大的问题。
此外,测试不同的标签(而不是solver.INT_VAR_DEFAULT和solver.INT_VALUE_DEFAULT)可能是个好主意,尽管二进制值往往对不同的标签不太敏感。请参阅另一个标签的代码。
def isdetecting(matrix):
solver = cs.Solver("scip") # <----
X = np.array([solver.IntVar(values) for i in range(matrix.shape[1])])
X1 = X.tolist()
for row in matrix:
x = X[row].tolist()
solver.Add(solver.Sum(x) == 0)
# db = solver.Phase(X1, solver.INT_VAR_DEFAULT, solver.INT_VALUE_DEFAULT)
db = solver.Phase(X1, solver.CHOOSE_FIRST_UNBOUND, solver.ASSIGN_CENTER_VALUE)
solver.NewSearch(db)
count = 0
while (solver.NextSolution() and count < 2):
solution = [x.Value() for x in X1]
count += 1
solver.EndSearch()
if (count < 2):
print "FOUND"
return True
编辑:以下是删除评论中提到的全 0 解决方案的约束。据我所知,它需要一个单独的列表。现在需要更长的时间(10.4 秒 vs 7 秒)。
X1Abs = [solver.IntVar(values, 'X1Abs[%i]' % i) for i in range(X1_len)]
for i in range(X1_len):
solver.Add(X1Abs[i] == abs(X1[i]))
solver.Add(solver.Sum(X1Abs) > 0)
我就是这么想的。我估计命令行参数 10 18 在我的机器上的运行时间不到 8 小时。
public class Search {
public static void main(String[] args) {
int n = Integer.parseInt(args[0]);
int m = Integer.parseInt(args[1]);
int row = search(n, m);
if (row >= 0) {
printRow(m, row);
}
}
private static int search(int n, int m) {
if (n < 0 || m < n || m >= 31 || powOverflows(m + 1, n)) {
throw new IllegalArgumentException();
}
long[] column = new long[m];
long[] sums = new long[1 << m];
int row = 1 << m;
while (row-- > 0) {
System.err.println(row);
for (int j = 0; j < m; j++) {
column[j] = 0;
for (int i = 0; i < n; i++) {
column[j] = (column[j] * (m + 1)) + ((row >> ((i + j) % m)) & 1);
}
}
for (int subset = 0; subset < (1 << m); subset++) {
long sum = 0;
for (int j = 0; j < m; j++) {
if (((subset >> j) & 1) == 1) {
sum += column[j];
}
}
sums[subset] = sum;
}
java.util.Arrays.sort(sums);
boolean duplicate = false;
for (int k = 1; k < (1 << m); k++) {
if (sums[k - 1] == sums[k]) {
duplicate = true;
break;
}
}
if (!duplicate) {
break;
}
}
return row;
}
private static boolean powOverflows(long b, int e) {
if (b <= 0 || e < 0) {
throw new IllegalArgumentException();
}
if (e == 0) {
return false;
}
long max = Long.MAX_VALUE;
while (e > 1) {
if (b > Integer.MAX_VALUE) {
return true;
}
if ((e & 1) == 1) {
max /= b;
}
b *= b;
e >>= 1;
}
return b > max;
}
private static void printRow(int m, int row) {
for (int j = 0; j < m; j++) {
System.out.print((row >> j) & 1);
}
System.out.println();
}
}