我正在尝试使用 Pollard rho 算法求解椭圆曲线离散对数(即找到 k,使得 G = kP),于是搜索了一个 C 语言实现。但在
main
函数中填入我的具体问题数据后,运行时出现了 segmentation fault (core dumped)。
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <time.h>               // clock(), CLOCKS_PER_SEC
#include <gmp.h>
#include <openssl/bn.h>
#include <openssl/ec.h>
#include <openssl/err.h>        // ERR_print_errors_fp
#include <openssl/obj_mac.h> // for NID_secp256k1
// Number of precomputed random walk steps R[0..POLLARD_SET_COUNT-1];
// ec_point_partition() maps every point to an index in [0, POLLARD_SET_COUNT).
#define POLLARD_SET_COUNT 16
// Symbol-export annotation for building a Windows DLL; expands to nothing elsewhere.
// NOTE(review): EXPORT is not applied to any declaration in the visible code.
#if defined(WIN32) || defined(_WIN32)
#define EXPORT __declspec(dllexport)
#else
#define EXPORT
#endif
// NOTE(review): MAX_RESTART is not referenced in the visible code.
#define MAX_RESTART 100
/*
 * Map a point to one of POLLARD_SET_COUNT walk partitions.
 *
 * The partition is the last byte of the point's uncompressed octet encoding
 * (for a finite point that is the low byte of y) reduced modulo
 * POLLARD_SET_COUNT.  Returns 0 if the point cannot be encoded, so the
 * result is always a valid index into the step tables.
 */
int ec_point_partition(const EC_GROUP *ecgrp, const EC_POINT *x) {
    size_t len = EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, NULL, 0, NULL);
    if (len == 0)
        return 0;  /* encoding failed: avoid a zero-length VLA and ret[len - 1] underflow */
    unsigned char ret[len];
    if (EC_POINT_point2oct(ecgrp, x, POINT_CONVERSION_UNCOMPRESSED, ret, len, NULL) != len)
        return 0;
    return ret[len - 1] % POLLARD_SET_COUNT;
}
// P generator
// Q result*P
// order of the curve
// result
//Reference: J. Sattler and C. P. Schnorr, "Generating random walks in groups"
int elliptic_pollard_rho_dlog(const EC_GROUP *group, const EC_POINT *P, const EC_POINT *Q, const BIGNUM *order, BIGNUM *res) {
printf("Pollard rho discrete log algorithm... \n");
BN_CTX* ctx;
ctx = BN_CTX_new();
int i, j;
int iterations = 0;
if ( !EC_POINT_is_on_curve(group, P, ctx ) || !EC_POINT_is_on_curve(group, Q, ctx ) ) return 1;
EC_POINT *X1 = EC_POINT_new(group);
EC_POINT *X2 = EC_POINT_new(group);
BIGNUM *c1 = BN_new();
BIGNUM *d1 = BN_new();
BIGNUM *c2 = BN_new();
BIGNUM *d2 = BN_new();
BIGNUM* a[POLLARD_SET_COUNT];
BIGNUM* b[POLLARD_SET_COUNT];
EC_POINT* R[POLLARD_SET_COUNT];
BN_zero(c1); BN_zero(d1);
BN_zero(c2); BN_zero(d2);
for (i = 0; i < POLLARD_SET_COUNT; i++) {
a[i] = BN_new();
b[i] = BN_new();
R[i] = EC_POINT_new(group);
BN_rand_range(a[i], order);
BN_rand_range(b[i], order);
// R = aP + bQ
EC_POINT_mul(group, R[i], a[i], Q, b[i], ctx);
//ep_norm(R[i], R[i]);
}
BN_rand_range(c1, order);
BN_rand_range(d1, order);
// X1 = c1*P + d1*Q
EC_POINT_mul(group, X1, c1, Q, d1, ctx);
//ep_norm(X1, X1);
BN_copy(c2, c1);
BN_copy(d2, d1);
EC_POINT_copy(X2, X1);
double work_time = (double) clock();
do {
j = ec_point_partition(group, X1);
EC_POINT_add(group, X1, X1, R[j], ctx);
BN_mod_add(c1, c1, a[j], order, ctx);
BN_mod_add(d1, d1, b[j], order, ctx);
for (i = 0; i < 2; i++) {
j = ec_point_partition(group, X2);
EC_POINT_add(group, X2, X2, R[j], ctx);
BN_mod_add(c2, c2, a[j], order, ctx);
BN_mod_add(d2, d2, b[j], order, ctx);
}
iterations++;
printf("Iteration %d \r",iterations );
} while ( EC_POINT_cmp(group, X1, X2, ctx) != 0 ) ;
printf("\n ");
work_time = ( (double) clock() - work_time ) / (double)CLOCKS_PER_SEC;
printf("Number of iterations %d %f\n",iterations, work_time );
BN_mod_sub(c1, c1, c2, order, ctx);
BN_mod_sub(d2, d2, d1, order, ctx);
if (BN_is_zero(d2) == 1) return 1;
//d1 = d2^-1 mod order
BN_mod_inverse(d1, d2, order, ctx);
BN_mod_mul(res, c1, d1, order, ctx);
for (int k = 0; k < POLLARD_SET_COUNT; ++k) {
BN_free(a[k]);
BN_free(b[k]);
EC_POINT_free(R[k]);
}
BN_free(c1); BN_free(d1);
BN_free(c2); BN_free(d2);
EC_POINT_free(X1); EC_POINT_free(X2);
BN_CTX_free(ctx);
return 0;
}
/*
 * Build the curve y^2 = x^3 - x over GF(p) and the points P, Q from decimal
 * constants, then solve Q = k*P with elliptic_pollard_rho_dlog and print k in
 * decimal.  Returns 0 on success, 1 on any failure (OpenSSL errors go to stderr).
 *
 * The original segfaulted: `res` and `str` were uninitialized pointers, and
 * BN_bn2mpi writes a binary MPI rather than a printable string.  The constants
 * were also parsed with BN_bin2bn(str, sizeof(ptr)), which reads the first 8
 * ASCII bytes as raw binary instead of the decimal value.
 */
int main(int argc, char *argv[])
{
    (void) argc;
    (void) argv;
    int ret = 1;
    const char *p_str = "134747661567386867366256408824228742802669457";
    const char *a_str = "-1";
    const char *b_str = "0";
    const char *XP_str = "18185174461194872234733581786593019886770620";
    const char *YP_str = "74952280828346465277451545812645059041440154";
    const char *XQ_str = "76468233972358960368422190121977870066985660";
    const char *YQ_str = "33884872380845276447083435959215308764231090";
    /* NOTE(review): N is passed as the order of P; presumably the (prime) order of P — confirm. */
    const char *N_str = "2902021510595963727029";
    BN_CTX *ctx = BN_CTX_new();
    BIGNUM *p = NULL, *a = NULL, *b = NULL;
    BIGNUM *XP = NULL, *YP = NULL, *XQ = NULL, *YQ = NULL, *N = NULL;
    BIGNUM *res = BN_new();
    EC_GROUP *g = NULL;
    EC_POINT *P = NULL, *Q = NULL;
    char *str = NULL;

    /* BN_dec2bn allocates *bn when it is NULL and returns 0 on parse/alloc failure. */
    if (ctx == NULL || res == NULL ||
        !BN_dec2bn(&p, p_str) || !BN_dec2bn(&a, a_str) || !BN_dec2bn(&b, b_str) ||
        !BN_dec2bn(&XP, XP_str) || !BN_dec2bn(&YP, YP_str) ||
        !BN_dec2bn(&XQ, XQ_str) || !BN_dec2bn(&YQ, YQ_str) ||
        !BN_dec2bn(&N, N_str))
        goto cleanup;

    g = EC_GROUP_new(EC_GFp_simple_method());
    if (g == NULL || !EC_GROUP_set_curve_GFp(g, p, a, b, ctx))
        goto cleanup;

    P = EC_POINT_new(g);
    Q = EC_POINT_new(g);
    if (P == NULL || Q == NULL ||
        !EC_POINT_set_affine_coordinates_GFp(g, P, XP, YP, ctx) ||
        !EC_POINT_set_affine_coordinates_GFp(g, Q, XQ, YQ, ctx))
        goto cleanup;

    if (elliptic_pollard_rho_dlog(g, P, Q, N, res) != 0) {
        fprintf(stderr, "Pollard rho failed to find the discrete logarithm\n");
        goto cleanup;
    }

    /* BN_bn2dec returns a heap string that must be released with OPENSSL_free. */
    str = BN_bn2dec(res);
    if (str == NULL)
        goto cleanup;
    printf("%s\n", str);
    ret = 0;

cleanup:
    if (ret != 0)
        ERR_print_errors_fp(stderr);
    OPENSSL_free(str);
    EC_POINT_free(P);
    EC_POINT_free(Q);
    EC_GROUP_free(g);
    BN_free(p); BN_free(a); BN_free(b);
    BN_free(XP); BN_free(YP); BN_free(XQ); BN_free(YQ);
    BN_free(N); BN_free(res);
    BN_CTX_free(ctx);
    return ret;
}
导致
segmentation fault
的语句是:
BN_bn2mpi(res,str);
能否提供一个适用于 secp256k1 曲线的版本?
第 1 部分。 Python 版本。
更新:请参阅我的答案的新第 2 部分,其中我提供了与此 Python 版本相同算法的 C++ 版本。
你的任务很有趣!
也许您希望修复代码,但我决定从头开始实现纯 Python(答案的第 1 部分)和纯 C++(第 2 部分)解决方案,而不使用任何外部非标准模块。我认为这种从头开始、没有依赖关系的解决方案对于教育目的非常有用。
这样的算法相当复杂,而Python很容易在短时间内实现这样的算法。
在下面的代码中,我参考维基百科实现了 Pollard Rho 离散对数算法和椭圆曲线点乘法。
代码不依赖任何外部模块,只用到了少量 Python 内置模块。如果您通过 python -m pip install gmpy2
安装了 gmpy2,并取消代码中
#import gmpy2
这一行的注释,代码就会改用 gmpy2 模块。
您可能会注意到,我是自己生成随机基点并计算其阶的,没有使用任何现成的曲线,例如比特币的 secp256k1 或其他标准曲线。
在
main()
函数的开头,我设置了 bits = 24
,这是曲线素数模数的位数;曲线的阶(不同点的个数)的位数与之大致相同。您可以将其改为 bits = 32
,尝试求解更大曲线上的问题。
众所周知,该算法的复杂度为
$O(\sqrt{n})$(其中 $n$ 为曲线的阶)
,即需要这么多次椭圆曲线点加运算。点加并不是基本操作,本身也需要一定时间。因此,当曲线的阶为 bits = 32
位时,算法大约需要运行 10-15 秒。bits = 64
对 Python 来说耗时太长,但 C++ 版本(我稍后会实现)足够快,大约一小时就能破解 64 位。
运行代码时,您有时会看到 Pollard Rho 失败了几次。当算法在 Pollard Rho 的最后一步尝试计算一个不可逆数(与模数不互质)的模逆,或者椭圆曲线点相加的结果为无穷远点时,就会出现这种情况。在常规的 Pollard Rho 整数分解中,当 GCD 等于 N 时也会不时发生同样的失败。
import random
#random.seed(10)
class ECPoint:
    """A point on y^2 = x^3 + A*x + B over GF(N), with N prime and N % 4 == 3.

    The point at infinity has ``x is None and y is None``.  ``q`` is the order
    attached to the point (0 when not supplied, None if ``find_order`` gave
    up); ``q_ps`` exists only when constructed with ``calc_q=True`` and lists
    the prime factors found by ``find_order``.

    Adding P and -P (or doubling a point with y == 0) does NOT return the
    point at infinity: the modular inverse fails and ``InvError`` is raised.
    ``find_order`` and ``pollard_rho_ec_log`` rely on that signal.
    """
    # Big-integer backend: None means plain Python int.  Uncomment the import
    # below (after `python -m pip install gmpy2`) to rebind this to gmpy2.
    gmpy2 = None
    #import gmpy2
    # Class-level import so methods can reach it as cls.random.
    import random

    class InvError(Exception):
        """Raised when a modular inverse does not exist; ``value`` is (gcd, a, n)."""
        def __init__(self, *args):
            self.value = args

    @classmethod
    def Int(cls, x):
        """Convert x to the active integer type (int or gmpy2.mpz)."""
        return int(x) if cls.gmpy2 is None else cls.gmpy2.mpz(x)

    @classmethod
    def fermat_prp(cls, n, trials = 32):
        """Fermat probable-prime test with ``trials`` random bases; exact for n <= 16."""
        # https://en.wikipedia.org/wiki/Fermat_primality_test
        if n <= 16:
            return n in (2, 3, 5, 7, 11, 13)
        for i in range(trials):
            if pow(cls.random.randint(2, n - 2), n - 1, n) != 1:
                return False
        return True

    @classmethod
    def rand_prime(cls, bits):
        """Return a random odd probable prime in [2**(bits-1), 2**bits)."""
        while True:
            p = cls.random.randrange(1 << (bits - 1), 1 << bits) | 1
            if cls.fermat_prp(p):
                return p

    @classmethod
    def base_gen(cls, bits = 128, *, min_order_pfactor = 0):
        """Generate a random curve and base point with a factored order.

        Picks a ``bits``-bit prime N with N % 4 == 3 and random x0, y0, A,
        derives B so that (x0, y0) lies on the curve, and keeps the choice only
        if y0 equals pow(x0^3 + A*x0 + B, (N+1)//4, N) (the root __init__
        asserts).  It retries until ``find_order`` succeeds and every prime
        factor of the order is >= ``min_order_pfactor``.
        """
        while True:
            while True:
                N = cls.rand_prime(bits)
                if N % 4 != 3:
                    continue
                x0, y0, A = [cls.random.randrange(1, N) for i in range(3)]
                B = (y0 ** 2 - x0 ** 3 - A * x0) % N
                y0_calc = pow(x0 ** 3 + A * x0 + B, (N + 1) // 4, N)
                if y0 == y0_calc:
                    break
            bp = ECPoint(A, B, N, x0, y0, calc_q = True)
            if bp.q is not None and min(bp.q_ps) >= min_order_pfactor:
                break
        # Sanity check of the heuristic order: (q + 1) * P must equal P.
        assert bp.q > 1 and (bp.q + 1) * bp == bp
        return bp

    def __init__(self, A, B, N, x, y, *, q = 0, prepare = True, calc_q = False):
        """Create a point; with ``prepare`` the inputs are normalised and validated.

        ``prepare=False`` stores the values as-is (used internally for results
        of arithmetic).  ``calc_q=True`` additionally runs ``find_order`` and
        sets ``q`` and ``q_ps``.
        """
        if prepare:
            N = self.Int(N)
            assert (x is None) == (y is None), (x, y)
            # NOTE(review): q is reduced mod N here as well; a group order >= N
            # would be corrupted.  Currently only q=0 reaches this path.
            A, B, x, y, q = [(self.Int(e) % N if e is not None else None) for e in [A, B, x, y, q]]
            # Non-singular curve: discriminant 4A^3 + 27B^2 != 0 (mod N).
            assert (4 * A ** 3 + 27 * B ** 2) % N != 0
            assert N % 4 == 3
            if x is not None:
                assert (y ** 2 - x ** 3 - A * x - B) % N == 0, (hex(N), hex((y ** 2 - x ** 3 - A * x) % N))
                # Only the square root produced by pow(., (N+1)//4, N) is accepted as y.
                assert y == pow(x ** 3 + A * x + B, (N + 1) // 4, N)
        self.A, self.B, self.N, self.x, self.y, self.q = A, B, N, x, y, q
        if calc_q:
            self.q, self.q_ps = self.find_order()

    def copy(self):
        """Return an equal point (no re-validation)."""
        return ECPoint(self.A, self.B, self.N, self.x, self.y, q = self.q, prepare = False)

    def inf(self):
        """Return the point at infinity on the same curve, with the same q."""
        return ECPoint(self.A, self.B, self.N, None, None, q = self.q, prepare = False)

    def find_order(self, *, _m = 1, _ps = []):
        """Heuristically find the order of this point; returns (order, prime factors).

        Starting from r = _m * self, r is multiplied by growing powers of each
        prime p (while the power stays <= 2N).  When a multiplication raises
        InvError (an intermediate sum hit infinity), hi * _m is taken as a
        multiple of the order, and the search recurses with that value.  It
        returns (None, []) once p * p exceeds 4N.  The result is checked by
        ``base_gen``.  ``_ps`` is never mutated, so the shared default is safe.
        """
        if 1:
            try:
                r = _m * self
            except self.InvError:
                return _m, _ps
            B = 2 * self.N
            for p in self.gen_primes():
                if p * p > B * 2:
                    return None, []
                assert _m % p != 0, (_m, p)
                assert p <= B, (p, B)
                # hi == p**cnt after each iteration; r is multiplied by p alongside.
                hi = 1
                try:
                    for cnt in range(1, 1 << 60):
                        hi *= p
                        if hi > B:
                            cnt -= 1
                            break
                        r = p * r
                except self.InvError:
                    return self.find_order(_m = hi * _m, _ps = [p] * cnt + _ps)
        else:
            # Alternative slower way
            r = self
            for i in range(1 << 60):
                try:
                    r = r + self
                except self.InvError:
                    return i + 2, []

    @classmethod
    def gen_primes(cls, *, ps = [2, 3]):
        """Yield 2, 3, 5, 7, ... indefinitely by trial division.

        The mutable default ``ps`` is intentionally shared between calls as a
        cache of the primes found so far.
        """
        yield from ps
        for p in range(ps[-1] + 2, 1 << 60, 2):
            is_prime = True
            for e in ps:
                if e * e > p:
                    break
                if p % e == 0:
                    is_prime = False
                    break
            if is_prime:
                ps.append(p)
                yield ps[-1]

    def __add__(self, other):
        """Affine addition / doubling.

        Raises InvError instead of returning infinity when the result is the
        point at infinity (Px == Qx with Py != Qy, or doubling with y == 0).
        """
        if self.x is None:
            return other.copy()
        if other.x is None:
            return self.copy()
        A, B, N, q = self.A, self.B, self.N, self.q
        Px, Py, Qx, Qy = self.x, self.y, other.x, other.y
        if Px == Qx and Py == Qy:
            # Tangent slope for doubling.
            s = ((Px * Px * 3 + A) * self.inv(Py * 2, N)) % N
        else:
            # Chord slope for distinct points.
            s = ((Py - Qy) * self.inv(Px - Qx, N)) % N
        x = (s * s - Px - Qx) % N
        y = (s * (Px - x) - Py) % N
        return ECPoint(A, B, N, x, y, q = q, prepare = False)

    def __rmul__(self, other):
        """Scalar multiplication ``other * self`` for integer other > 0 (double-and-add)."""
        assert other > 0, other
        if other == 1:
            return self.copy()
        # r already holds 1 * self, so only other - 1 more multiples are added.
        other = self.Int(other - 1)
        r = self
        while True:
            if other & 1:
                r = r + self
                if other == 1:
                    return r
            other >>= 1
            # Rebinding the local name `self` to the doubled point.
            self = self + self

    @classmethod
    def inv(cls, a, n):
        """Return a^-1 mod n; raise InvError(gcd, a, n) if gcd(a, n) != 1."""
        a %= n
        if cls.gmpy2 is None:
            try:
                return pow(a, -1, n)
            except ValueError:
                import math
                raise cls.InvError(math.gcd(a, n), a, n)
        else:
            g, s, t = cls.gmpy2.gcdext(a, n)
            if g != 1:
                raise cls.InvError(g, a, n)
            return s % n

    def __repr__(self):
        return str(dict(x = self.x, y = self.y, A = self.A, B = self.B, N = self.N, q = self.q))

    def __eq__(self, other):
        """Points are equal when coordinates, curve constants and q all match."""
        for i, (a, b) in enumerate([(self.x, other.x), (self.y, other.y), (self.A, other.A), (self.B, other.B), (self.N, other.N), (self.q, other.q)]):
            if a != b:
                return False
        return True
def pollard_rho_ec_log(a, b, bp):
    """Find x with ``x * a == b`` using Pollard's rho for logarithms.

    Every walk point is kept as X = u*a + v*b.  A step is chosen by
    ``(X.x % part_p) % 3``: 0 -> X + b, 1 -> 2X, 2 -> X + a.  A tortoise takes
    one step and a hare two steps per iteration until they meet; then
    x = (u_hare - u_tort) / (v_tort - v_hare) mod bp.q.

    Args:
        a: base point of the logarithm.
        b: target point, assumed to equal x * a.
        bp: point supplying the modulus ``q`` for the coefficients, plus
            ``N``, ``inf()``, ``inv()``, ``rand_prime()`` and ``InvError``.

    Returns:
        The logarithm modulo bp.q.  NOTE: the formula assumes bp.q is the
        order of ``a``; ``main`` validates the result via x * a == x_calc * a.

    Each InvError (a point sum hitting infinity, or v_tort == v_hare) prints
    a message and restarts the walk with a new random partition prime.
    """
    # https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    import math
    q = bp.q
    for itry in range(1 << 60):
        i = -1
        try:
            # A fresh random prime per attempt gives a different partition and hence a different walk.
            part_p = bp.rand_prime(max(3, int(math.log2(bp.N) / 2)))

            def step(x, u, v):
                """One walk step: return (f(x), g(x, u), h(x, v))."""
                sel = ((x.x or 0) % part_p) % 3  # x.x is None at infinity
                if sel == 0:
                    return b + x, u, (v + 1) % q
                if sel == 1:
                    return x + x, (2 * u) % q, (2 * v) % q
                return a + x, (u + 1) % q, v

            tort, tort_u, tort_v = bp.inf(), 0, 0
            hare, hare_u, hare_v = tort, 0, 0
            for i in range(1, 1 << 60):
                tort, tort_u, tort_v = step(tort, tort_u, tort_v)
                # The hare's intermediate point is computed once and reused
                # (3 point additions per iteration instead of 5).
                hare, hare_u, hare_v = step(*step(hare, hare_u, hare_v))
                if tort == hare:
                    return (bp.inv(tort_v - hare_v, q) * (hare_u - tort_u)) % q
        except bp.InvError as ex:
            print(f'Try {itry:>4}, Pollard-Rho failed, invert err at iter {i:>7},', ex.value)
def main():
    """Demo: build a random curve, hide a discrete log in it, then recover it."""
    import math
    import random

    bits = 24
    print('Generating base point, wait...')
    base = ECPoint.base_gen(bits, min_order_pfactor = 10)
    print('order', base.q, '=', ' * '.join(map(str, base.q_ps)))

    # Two independent draws, in this order: the multiplier for the base and the secret log.
    k0 = random.randrange(1, base.q)
    secret = random.randrange(1, base.q)
    a = k0 * base
    b = secret * a

    recovered = pollard_rho_ec_log(a, b, base)
    print('our x', secret, 'found x', recovered)
    print('equal points:', secret * a == recovered * a)


if __name__ == '__main__':
    main()
输出:
Generating base point, wait...
order 5805013 = 19 * 109 * 2803
Try 0, Pollard-Rho failed, invert err at iter 1120, (109, 1411441, 5805013)
Try 1, Pollard-Rho failed, invert err at iter 3992, (19, 5231802, 5805013)
our x 990731 found x 990731
equal points: True
第 2 部分。 C++ 版本。
与上面的代码几乎相同,但用 C++ 重写。
这个 C++ 版本比 Python 版本快得多:在 1 GHz CPU 上,C++ 代码破解 48 位曲线大约需要 1 分钟,而 Python 在 32 位曲线上就要花同样的时间。
提醒一下,复杂度是
$O(\sqrt{n})$($n$ 为曲线的阶)
。这意味着,如果 C++ 在 48 位曲线上($\sqrt{n} \approx 2^{24}$)与 Python 在 32 位曲线上($\sqrt{n} \approx 2^{16}$)花费的时间相同,那么 C++ 版本大约比 Python 版本快 $2^{24}/2^{16} = 2^8 = 256$
倍。
以下版本只能用 Clang 编译,因为它用到了 128 位和 192 位整数。GCC 中也有
__int128
,但没有 192/256 位整数。192 位整数只在 BarrettMod()
函数中使用,因此如果把该函数体替换为 return x % n;
,就不再需要 192 位整数,也就可以用 GCC 编译了。
我实现了 Barrett 约简(Barrett Reduction)算法,用基于乘法、移位和减法的 Barrett 公式来代替基于慢速除法的取模运算(
% N
),使取模运算的速度提高了数倍。
#include <cstdint>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <iomanip>
#include <iostream>
#include <string>
#include <chrono>
#include <cmath>
// Integer aliases.  Word holds every field element / order value; DWord holds the
// product of two Words; TWord (192-bit) is only needed inside BarrettMod().
// NOTE(review): unsigned _ExtInt(N) is a Clang extension (newer Clang spells it
// _BitInt(N)); GCC rejects it, so this file compiles only with Clang as written.
using u64 = uint64_t;
using u128 = unsigned __int128;
using u192 = unsigned _ExtInt(192);
using Word = u64;
using DWord = u128;
using SWord = std::make_signed_t<Word>;
using TWord = u192;
// Throw std::runtime_error naming the failed condition, the source line and an extra message.
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
// Debug helpers: print the current line number, or "expr = value", to g_log.
#define LN { g_log << " LN " << __LINE__ << " " << std::flush; }
#define DUMP(x) { g_log << " " << (#x) << " = " << (x) << " " << std::flush; }
// All diagnostic output goes to standard output.
static auto & g_log = std::cout;
class ECPoint {
public:
class InvError : public std::runtime_error {
public:
InvError(Word const & gcd, Word const & x, Word const & mod)
: std::runtime_error("(gcd " + std::to_string(gcd) + ", x " + std::to_string(x) +
", mod " + std::to_string(mod) + ")") {}
};
static Word pow_mod(Word a, Word b, Word const & c) {
// https://en.wikipedia.org/wiki/Modular_exponentiation
Word r = 1;
while (b != 0) {
if (b & 1)
r = (DWord(r) * a) % c;
a = (DWord(a) * a) % c;
b >>= 1;
}
return r;
}
static Word rand_range(Word const & begin, Word const & end) {
u64 const seed = (u64(std::random_device{}()) << 32) + std::random_device{}();
thread_local std::mt19937_64 rng{seed};
ASSERT(begin < end);
return std::uniform_int_distribution<Word>(begin, end - 1)(rng);
}
static bool fermat_prp(Word const & n, size_t trials = 32) {
// https://en.wikipedia.org/wiki/Fermat_primality_test
if (n <= 16)
return n == 2 || n == 3 || n == 5 || n == 7 || n == 11 || n == 13;
for (size_t i = 0; i < trials; ++i)
if (pow_mod(rand_range(2, n - 2), n - 1, n) != 1)
return false;
return true;
}
static Word rand_prime_range(Word begin, Word end) {
while (true) {
Word const p = rand_range(begin, end) | 1;
if (fermat_prp(p))
return p;
}
}
static Word rand_prime(size_t bits) {
return rand_prime_range(Word(1) << (bits - 1), Word((DWord(1) << bits) - 1));
}
std::tuple<Word, size_t> BarrettRS(Word n) {
size_t constexpr extra = 3;
for (size_t k = 0; k < sizeof(DWord) * 8; ++k) {
if (2 * (k + extra) < sizeof(Word) * 8)
continue;
if ((DWord(1) << k) <= DWord(n))
continue;
k += extra;
ASSERT_MSG(2 * k < sizeof(DWord) * 8, "k " + std::to_string(k));
DWord r = (DWord(1) << (2 * k)) / n;
ASSERT_MSG(DWord(r) < (DWord(1) << (sizeof(Word) * 8)),
"k " + std::to_string(k) + " n " + std::to_string(n));
ASSERT(2 * k >= sizeof(Word) * 8);
return std::make_tuple(Word(r), size_t(2 * k - sizeof(Word) * 8));
}
ASSERT(false);
}
template <bool Adjust>
static Word BarrettMod(DWord const & x, Word const & n, Word const & r, size_t s) {
//return x % n;
DWord const q = DWord(((TWord(x) * r) >> (sizeof(Word) * 8)) >> s);
Word t = Word(DWord(x) - q * n);
if constexpr(Adjust) {
Word const mask = ~Word(SWord(t - n) >> (sizeof(Word) * 8 - 1));
t -= mask & n;
}
return t;
}
static Word Adjust(Word const & a, Word const & n) {
return a >= n ? a - n : a;
}
Word modNn(DWord const & a) const { return BarrettMod<false>(a, N_, N_br_, N_bs_); }
Word modNa(DWord const & a) const { return BarrettMod<true>(a, N_, N_br_, N_bs_); }
Word modQn(DWord const & a) const { return BarrettMod<false>(a, q_, Q_br_, Q_bs_); }
Word modQa(DWord const & a) const { return BarrettMod<true>(a, q_, Q_br_, Q_bs_); }
static Word mod(DWord const & a, Word const & n) { return a % n; }
static ECPoint base_gen(size_t bits = 128, Word min_order_pfactor = 0) {
while (true) {
Word const N = rand_prime(bits);
if (mod(N, 4) != 3)
continue;
Word const
x0 = rand_range(1, N), y0 = rand_range(1, N), A = rand_range(1, N),
B = mod(mod(DWord(y0) * y0, N) + N * 2 - mod(DWord(mod(DWord(x0) * x0, N)) * x0, N) - mod(DWord(A) * x0, N), N),
y0_calc = pow_mod(mod(DWord(y0) * y0, N), (N + 1) >> 2, N);
if (y0 != y0_calc)
continue;
auto const bp = ECPoint(A, B, N, x0, y0, 0, true, true);
auto BpCheckOrder = [&]{
for (auto e: bp.q_ps())
if (e < min_order_pfactor)
return false;
return true;
};
if (!(bp.q() != 0 && !bp.q_ps().empty() && BpCheckOrder()))
continue;
ASSERT(bp.q() > 1 && bp * (bp.q() + 1) == bp);
return bp;
}
ASSERT(false);
}
ECPoint(Word A, Word B, Word N, Word x, Word y, Word q = 0, bool prepare = true, bool calc_q = false) {
if (prepare) {
A = mod(A, N); B = mod(B, N); x = mod(x, N); y = mod(y, N); q = mod(q, N);
ASSERT(mod(4 * mod(DWord(mod(DWord(A) * A, N)) * A, N) + 27 * mod(DWord(B) * B, N), N) != 0);
ASSERT(mod(N, 4) == 3);
if (!(x == 0 && y == 0)) {
ASSERT(mod(mod(DWord(y) * y, N) + 3 * N - mod(DWord(mod(DWord(x) * x, N)) * x, N) - mod(DWord(A) * x, N) - B, N) == 0);
ASSERT(y == pow_mod(mod(DWord(mod(DWord(x) * x, N)) * x, N) + mod(DWord(A) * x, N) + B, (N + 1) >> 2, N));
}
std::tie(N_br_, N_bs_) = BarrettRS(N);
if (q != 0)
std::tie(Q_br_, Q_bs_) = BarrettRS(q);
}
std::tie(A_, B_, N_, x_, y_, q_) = std::tie(A, B, N, x, y, q);
if (calc_q) {
std::tie(q_, q_ps_) = find_order();
if (q_ != 0)
std::tie(Q_br_, Q_bs_) = BarrettRS(q_);
}
}
auto copy() const {
return ECPoint(A_, B_, N_, x_, y_, q_, false);
}
auto inf() const {
return ECPoint(A_, B_, N_, 0, 0, q_, false);
}
static auto const & gen_primes(Word const B) {
thread_local std::vector<Word> ps = {2, 3};
for (Word p = ps.back() + 2; p <= B; p += 2) {
bool is_prime = true;
for (auto const e: ps) {
if (e * e > p)
break;
if (p % e == 0) {
is_prime = false;
break;
}
}
if (is_prime)
ps.push_back(p);
}
return ps;
}
std::tuple<Word, std::vector<Word>> find_order(Word _m = 1, std::vector<Word> _ps = {}) const {
ASSERT(_m <= 2 * N_);
if constexpr(1) {
auto r = *this;
try {
r *= _m;
} catch (InvError const &) {
return std::make_tuple(_m, _ps);
}
Word const B = 2 * N_;
for (Word const p: gen_primes(std::llround(std::cbrt(B) + 1))) {
if (p * p * p > B)
break;
ASSERT(p <= B);
size_t cnt = 0;
Word hi = 1;
try {
for (cnt = 1;; ++cnt) {
if (hi * p > B) {
cnt -= 1;
break;
}
hi *= p;
r *= p;
}
} catch (InvError const & ex) {
_ps.insert(_ps.begin(), cnt, p);
return find_order(hi * _m, _ps);
}
}
} else {
// Alternative slower way
auto r = *this;
for (Word i = 0;; ++i)
try {
r += *this;
} catch (InvError const &) {
_ps.clear();
return std::make_tuple(i + 2, _ps);
}
}
_ps.clear();
return std::make_tuple(Word(0), _ps);
}
static std::tuple<Word, SWord, SWord> EGCD(Word const & a, Word const & b) {
Word ro = 0, r = 0, qu = 0, re = 0;
SWord so = 0, s = 0;
std::tie(ro, r, so, s) = std::make_tuple(a, b, 1, 0);
while (r != 0) {
std::tie(qu, re) = std::make_tuple(ro / r, ro % r);
std::tie(ro, r) = std::make_tuple(r, re);
std::tie(so, s) = std::make_tuple(s, so - s * SWord(qu));
}
SWord const to = (SWord(ro) - SWord(a) * so) / SWord(b);
return std::make_tuple(ro, so, to);
}
Word inv(Word a, Word const & n, size_t any_n_q = 0) const {
ASSERT(n > 0);
a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
auto [gcd, s, t] = EGCD(a, n);
if (gcd != 1)
throw InvError(gcd, a, n);
a = Word(SWord(n) + s);
a = any_n_q == 0 ? mod(a, n) : any_n_q == 1 ? modNa(a) : any_n_q == 2 ? modQa(a) : 0;
return a;
}
Word invN(Word a) const { return inv(a, N_, 1); }
Word invQ(Word a) const { return inv(a, q_, 2); }
ECPoint & operator += (ECPoint const & o) {
if (x_ == 0 && y_ == 0) {
*this = o;
return *this;
}
if (o.x_ == 0 && o.y_ == 0)
return *this;
Word const Px = x_, Py = y_, Qx = o.x_, Qy = o.y_;
Word s = 0;
if ((Adjust(Px, N_) == Adjust(Qx, o.N_)) && (Adjust(Py, N_) == Adjust(Qy, o.N_)))
s = modNn(DWord(modNn(DWord(Px) * Px * 3) + A_) * invN(Py * 2));
else
s = modNn(DWord(Py + 2 * N_ - Qy) * invN(Px + 2 * N_ - Qx));
x_ = modNn(DWord(s) * s + 4 * N_ - Px - Qx);
y_ = modNn(DWord(s) * (Px + 2 * N_ - x_) + 2 * N_ - Py);
return *this;
}
ECPoint operator + (ECPoint const & o) const {
ECPoint c = *this;
c += o;
return c;
}
ECPoint & operator *= (Word k) {
auto const ok = k;
ASSERT(k > 0);
if (k == 1)
return *this;
k -= 1;
auto r = *this, s = *this;
while (true) {
if (k & 1) {
r += s;
if (k == 1)
break;
}
k >>= 1;
s += s;
}
if constexpr(0) {
auto r2 = *this;
for (u64 i = 1; i < ok; ++i)
r2 += *this;
ASSERT(r == r2);
}
*this = r;
return *this;
}
ECPoint operator * (Word k) const {
ECPoint r = *this;
r *= k;
return r;
}
bool operator == (ECPoint const & o) const {
return A_ == o.A_ && B_ == o.B_ && N_ == o.N_ && q_ == o.q_ &&
Adjust(x_, N_) == Adjust(o.x_, o.N_) && Adjust(y_, N_) == Adjust(o.y_, o.N_);
}
Word const & q() const { return q_; }
std::vector<Word> const & q_ps() const { return q_ps_; }
Word const & x() const { return x_; }
private:
Word A_ = 0, B_ = 0, N_ = 0, q_ = 0, x_ = 0, y_ = 0, N_br_ = 0, Q_br_ = 0;
size_t N_bs_ = 0, Q_bs_ = 0;
std::vector<Word> q_ps_;
};
/// Solve b == x * a with Pollard's rho for logarithms; returns x mod bp.q().
///
/// Walk points are kept as X = u*a + v*b.  A step is chosen by (X.x % part_p) % 3:
/// 0 -> X + b, 1 -> 2X, 2 -> X + a.  A tortoise takes one step and a hare two steps per
/// iteration until they meet; then x = (u_hare - u_tort) / (v_tort - v_hare) mod q.
/// Each InvError (a point sum hitting infinity, or v_tort == v_hare) is logged and the
/// walk restarts with a new random partition prime.
/// NOTE: assumes bp.q() is the order of `a`; test() verifies the answer via a * x.
Word pollard_rho_ec_log(ECPoint const & a, ECPoint const & b, ECPoint const & bp) {
    // https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm_for_logarithms#Algorithm
    Word const q = bp.q();
    // Exclusive upper bound for the partition prime.  It is floored at 16 so [8, bound) always
    // contains a prime (11, 13).  For small q the old bound q >> 4 made the range empty (ASSERT)
    // or prime-free (rand_prime_range looped forever).
    Word const part_bound = (q >> 4) > 16 ? (q >> 4) : Word(16);
    auto const ModQ = [q](Word n) { return n >= q ? n - q : n; };
    for (u64 itry = 0;; ++itry) {
        u64 i = 0;
        try {
            Word const part_p = bp.rand_prime_range(8, part_bound);
            // One walk step X -> f(X); (u, v) are updated so that X == u*a + v*b still holds.
            auto const step = [&](ECPoint & x, Word & u, Word & v) {
                switch ((x.x() % part_p) % 3) {
                case 0:
                    x = b + x;
                    v = ModQ(v + 1);
                    break;
                case 1:
                    x = x + x;
                    u = ModQ(2 * u);
                    v = ModQ(2 * v);
                    break;
                default:
                    x = a + x;
                    u = ModQ(u + 1);
                    break;
                }
            };
            ECPoint tort = bp.inf(), hare = bp.inf();
            Word tort_u = 0, tort_v = 0, hare_u = 0, hare_v = 0;
            for (i = 1;; ++i) {
                step(tort, tort_u, tort_v);
                // The hare's intermediate point is computed once (3 point additions per
                // iteration instead of 5: f(x2im2) used to be recomputed three times).
                step(hare, hare_u, hare_v);
                step(hare, hare_u, hare_v);
                if (tort == hare)
                    return bp.modQa(DWord(bp.invQ(q + tort_v - hare_v)) * (q + hare_u - tort_u));
            }
        } catch (ECPoint::InvError const & ex) {
            g_log << "Try " << std::setfill(' ') << std::setw(4) << itry << ", Pollard-Rho failed, invert err at iter "
                  << std::setw(7) << i << ", " << ex.what() << std::endl;
        }
    }
}
void test() {
auto const gtb = std::chrono::high_resolution_clock::now();
auto Time = [&]() -> double {
return std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - gtb).count() / 1'000.0;
};
double tb = 0;
size_t constexpr bits = 36;
g_log << "Generating base point, wait... " << std::flush;
tb = Time();
auto const bp = ECPoint::base_gen(bits, 50);
g_log << "Time " << Time() - tb << " sec" << std::endl;
g_log << "order " << bp.q() << " = ";
for (auto e: bp.q_ps())
g_log << e << " * " << std::flush;
g_log << std::endl;
Word const k0 = ECPoint::rand_range(1, bp.q()),
x = ECPoint::rand_range(1, bp.q());
auto a = bp * k0;
auto b = a * x;
g_log << "Searching discrete logarithm... " << std::endl;
tb = Time();
Word const x_calc = pollard_rho_ec_log(a, b, bp);
g_log << "Time " << Time() - tb << " sec" << std::endl;
g_log << "our x " << x << ", found x " << x_calc << std::endl;
g_log << "equal points: " << std::boolalpha << (a * x == a * x_calc) << std::endl;
}
/// Entry point: run the demo and report any std::exception (including ASSERT failures)
/// instead of letting it terminate the program.
int main() {
    try {
        test();
    } catch (std::exception const & err) {
        g_log << "Exception: " << err.what() << std::endl;
    }
    return 0;
}
输出:
Generating base point, wait... Time 38.932 sec
order 195944962603297 = 401 * 4679 * 9433 * 11071 *
Searching discrete logarithm...
Time 69.791 sec
our x 15520105103514, found x 15520105103514
equal points: true