按顺序查找n个元素。如何加快 n > 10^6 的程序(时间)?

问题描述 投票:0回答:1

我认为我有正确的算法,但是当值增加到 106 或更多时,我超出了允许的内存或时间限制。起初,我尝试将元素推送到向量,然后我更改了方法以重用变量,并且通过了更多测试。

公式:Ai = (Ai-1 + 2 * Ai-2 + 3 * Ai-3) mod M,其中 M = 109 + 7。
1 <= n <= 1012 时间限制:1秒,内存:256mb

代码:

#include<iostream>
#include<cmath>

using namespace std;
using ull = unsigned long long;

ull func(ull n){
    ull a = 1;
    ull b = 1;
    ull c = 2;
    if (n < 2) return a;
    if (n == 3) return c;
    ull res = 0;
    for (ull i = 0; i < n - 3; i++){
        res = (3 * a + 2 * b + c) % (ull)(pow(10, 9) + 7);
        a = b;
        b = c;
        c = res;
    }
    return c;
}

int main() {
    int x; 
    cin >> x;
    cout << func(x);
}

现在我有了算法,它通过了 3 个初始睾丸(然后失败了 63 个测试,我认为值 > 10^6)

测试1 输入:6 输出:34

测试2 输入:10 输出:1096

测试3 输入:500 输出:340736120

我需要改变算法或通过任何方法加速吗?

c++ algorithm performance memory biginteger
1个回答
0
投票

您当前的解决方案是

O(n)
,当
n
可以大到 1012 时,它太慢了。

我们可以找到一个矩阵

M
,这样我们就可以通过乘法从一个状态转换到下一个状态。
M
满足

[Ai, Ai-1, Ai-2]T = M * [Ai-1, Ai-2, Ai-3]

显然,

M
的最后一行就是简单的
[0, 1, 0]
来得到Ai-2

同样,第二行是

[1, 0, 0]

第一行是

[1, 2, 3]
,直接来自递推方程。

现在,对于

n > 3
,我们可以通过(左)乘以初始条件来找到序列的第
n
元素,[A3,A2,A1] =
[2, 1, 1]
,总共
M
n-3
次。这相当于乘以 Mn-3。矩阵求幂可以在 O(S3 log(N)) 中执行,其中 S 是矩阵的维数(在本例中为常数
3
),N 是二进制求幂的指数。

这导致了以下解决方案:

#include <iostream>
#include <vector>
#include <span>
#include <initializer_list>
#include <stdexcept>
#include <cstddef>
constexpr int MOD = 1e9 + 7;
template<typename T>
class Matrix {
public:
    std::size_t rows, cols;
    std::vector<std::vector<T>> values;

public:
    Matrix(std::size_t rows, std::size_t cols) : rows{rows}, cols{cols}, values(rows, std::vector<T>(cols)) {}
    Matrix(std::initializer_list<std::initializer_list<T>> initVals) : rows{initVals.size()} {
        values.reserve(initVals.size());
        for (auto& row : initVals) {
            values.emplace_back(row);
            if ((cols = row.size()) != values[0].size()) throw std::domain_error("Not a matrix: rows have unequal size");
        }
    }
    std::span<T> operator[](std::size_t r) {
        return values[r];
    }
    std::span<const T> operator[](std::size_t r) const {
        return values[r];
    }
    static Matrix identity(std::size_t size) {
        Matrix id(size, size);
        for (std::size_t i = 0; i < size; ++i) id.values[i][i] = 1;
        return id;
    }
    Matrix operator*(const Matrix& m) const {
        if (cols != m.rows) throw std::domain_error("Matrix dimensions do not match");
        Matrix res(rows, m.cols);
        for (std::size_t r = 0; r < rows; ++r)
            for (std::size_t c = 0; c < m.cols; ++c)
                for (std::size_t i = 0; i < cols; ++i)
                    res.values[r][c] += values[r][i] * m.values[i][c];
        return res;
    }
    Matrix operator%(T mod) const {
        auto res = *this;
        for (std::size_t r = 0; r < rows; ++r)
            for (std::size_t c = 0; c < cols; ++c)
                res.values[r][c] %= mod;
        return res;
    }
    Matrix modPow(std::size_t exp, T mod) const {
        if (rows != cols) throw std::domain_error("Matrix is not square");
        auto res = identity(rows), sq = *this;
        for (; exp; exp >>= 1) {
            if (exp & 1) res = res * sq % mod;
            sq = sq * sq % mod;
        }
        return res;
    }
};
const Matrix<unsigned long long> transition{{1, 2, 3}, {1, 0, 0}, {0, 1, 0}}, initialConditions{{2}, {1}, {1}};
unsigned long long nthValue(unsigned long long n){
    if (n < 3) return 1;
    return (transition.modPow(n - 3, MOD) * initialConditions % MOD)[0][0];
}

int main() {
    unsigned long long n; 
    std::cin >> n;
    std::cout << nthValue(n) << '\n';
}
© www.soinside.com 2019 - 2024. All rights reserved.