我认为我有正确的算法,但是当值增加到 106 或更多时,我超出了允许的内存或时间限制。起初,我尝试将元素推送到向量,然后我更改了方法以重用变量,并且通过了更多测试。
公式:Ai = (Ai-1 + 2 * Ai-2 + 3 * Ai-3) mod M,其中 M = 109 + 7。
1 <= n <= 1012 时间限制:1秒,内存:256mb
代码:
#include<iostream>
#include<cmath>
using namespace std;
using ull = unsigned long long;
ull func(ull n){
ull a = 1;
ull b = 1;
ull c = 2;
if (n < 2) return a;
if (n == 3) return c;
ull res = 0;
for (ull i = 0; i < n - 3; i++){
res = (3 * a + 2 * b + c) % (ull)(pow(10, 9) + 7);
a = b;
b = c;
c = res;
}
return c;
}
int main() {
int x;
cin >> x;
cout << func(x);
}
现在我有了算法,它通过了 3 个初始睾丸(然后失败了 63 个测试,我认为值 > 10^6)
测试1 输入:6 输出:34
测试2 输入:10 输出:1096
测试3 输入:500 输出:340736120
我需要改变算法或通过任何方法加速吗?
您当前的解决方案是
O(n)
,当 n
可以大到 1012 时,它太慢了。
我们可以找到一个矩阵
M
,这样我们就可以通过乘法从一个状态转换到下一个状态。 M
满足
[Ai, Ai-1, Ai-2]T = M * [Ai-1, Ai-2, Ai-3]
显然,
M
的最后一行就是简单的[0, 1, 0]
来得到Ai-2。
同样,第二行是
[1, 0, 0]
。
第一行是
[1, 2, 3]
,直接来自递推方程。
现在,对于
n > 3
,我们可以通过(左)乘以初始条件来找到序列的第n
元素,[A3,A2,A1] = [2, 1, 1]
,总共 M
n-3
次。这相当于乘以 Mn-3。矩阵求幂可以在 O(S3 log(N)) 中执行,其中 S 是矩阵的维数(在本例中为常数 3
),N 是二进制求幂的指数。
这导致了以下解决方案:
#include <iostream>
#include <vector>
#include <span>
#include <initializer_list>
#include <stdexcept>
#include <cstddef>
constexpr int MOD = 1e9 + 7;
template<typename T>
class Matrix {
public:
std::size_t rows, cols;
std::vector<std::vector<T>> values;
public:
Matrix(std::size_t rows, std::size_t cols) : rows{rows}, cols{cols}, values(rows, std::vector<T>(cols)) {}
Matrix(std::initializer_list<std::initializer_list<T>> initVals) : rows{initVals.size()} {
values.reserve(initVals.size());
for (auto& row : initVals) {
values.emplace_back(row);
if ((cols = row.size()) != values[0].size()) throw std::domain_error("Not a matrix: rows have unequal size");
}
}
std::span<T> operator[](std::size_t r) {
return values[r];
}
std::span<const T> operator[](std::size_t r) const {
return values[r];
}
static Matrix identity(std::size_t size) {
Matrix id(size, size);
for (std::size_t i = 0; i < size; ++i) id.values[i][i] = 1;
return id;
}
Matrix operator*(const Matrix& m) const {
if (cols != m.rows) throw std::domain_error("Matrix dimensions do not match");
Matrix res(rows, m.cols);
for (std::size_t r = 0; r < rows; ++r)
for (std::size_t c = 0; c < m.cols; ++c)
for (std::size_t i = 0; i < cols; ++i)
res.values[r][c] += values[r][i] * m.values[i][c];
return res;
}
Matrix operator%(T mod) const {
auto res = *this;
for (std::size_t r = 0; r < rows; ++r)
for (std::size_t c = 0; c < cols; ++c)
res.values[r][c] %= mod;
return res;
}
Matrix modPow(std::size_t exp, T mod) const {
if (rows != cols) throw std::domain_error("Matrix is not square");
auto res = identity(rows), sq = *this;
for (; exp; exp >>= 1) {
if (exp & 1) res = res * sq % mod;
sq = sq * sq % mod;
}
return res;
}
};
const Matrix<unsigned long long> transition{{1, 2, 3}, {1, 0, 0}, {0, 1, 0}}, initialConditions{{2}, {1}, {1}};
unsigned long long nthValue(unsigned long long n){
if (n < 3) return 1;
return (transition.modPow(n - 3, MOD) * initialConditions % MOD)[0][0];
}
int main() {
unsigned long long n;
std::cin >> n;
std::cout << nthValue(n) << '\n';
}