LLL Algorithm — Yet Another Paper-Reading Problem

Dec 9, 2023

Recently, a paper-reading problem appeared in UCup Stage 12 (2023 ICPC Asia Hefei Regional): SQRT Problem.

Problem Statement

Miss Burger has three positive integers \( n \), \( a \), and \( b \). She wants to find a positive integer solution \( x \) (\( 1 \leq x \leq n - 1 \)) that satisfies the following two conditions:

\[ x^2 \equiv a \pmod{n}, \]

\[ \left\lfloor \sqrt[3]{x^2} \right\rfloor = b. \]

Additionally, it is guaranteed that \( n \) is an odd number and \( \gcd(a, n) = 1 \). Here \( \gcd(x, y) \) denotes the greatest common divisor of \( x \) and \( y \). We also guarantee that there exists a unique solution.
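To make the two conditions \( x^2 \equiv a \pmod{n} \) and \( \left\lfloor \sqrt[3]{x^2} \right\rfloor = b \) concrete, a brute-force reference for tiny \( n \) is handy when testing a faster solution. This is a hypothetical sketch (the names `icbrt` and `brute_force` are mine, not from the problem); it uses exact integer arithmetic, so the cube root never suffers from float rounding.

```python
def icbrt(m: int) -> int:
    """Exact floor of the cube root of a non-negative integer m."""
    lo, hi = 0, 1
    while hi ** 3 <= m:  # grow hi until hi**3 > m
        hi *= 2
    while hi - lo > 1:  # invariant: lo**3 <= m < hi**3
        mid = (lo + hi) // 2
        if mid ** 3 <= m:
            lo = mid
        else:
            hi = mid
    return lo


def brute_force(n: int, a: int, b: int) -> list:
    """All x in [1, n - 1] with x^2 = a (mod n) and floor(cbrt(x^2)) = b."""
    return [x for x in range(1, n) if x * x % n == a and icbrt(x * x) == b]


print(brute_force(35, 4, 5))  # [12]
```

For \( n = 35 \), \( a = 4 \), the congruence alone has four roots (2, 12, 23, 33); the floor condition with \( b = 5 \) singles out 12.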

Input

The first line contains a single integer \( n \) (\( 3 \leq n \leq 10^{100} - 1 \)).

The second line contains a single integer \( a \) (\( 1 \leq a \leq n - 1 \)).

The third line contains a single integer \( b \) (\( 1 \leq b \leq n - 1 \)).

Output

Output a single integer denoting the solution \( x \).

Solution

First, read Prof. Chris Peikert's excellent lecture notes from the University of Michigan, up to Lecture 3. That lecture presents a simple version of Coppersmith's method (Lecture 4 covers the full version, which is highly interesting but not needed for this problem).

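The condition \( \left\lfloor \sqrt[3]{x^2} \right\rfloor = b \) is equivalent to \( b^3 \le x^2 < (b+1)^3 \), i.e. \( b^{\frac{3}{2}} \le x < (b+1)^{\frac{3}{2}} \), so it confines \( x \) to a small interval \( [l, r] \). How small is it? Since \( x \le n - 1 \) and \( (b+1)^{\frac{3}{2}} - b^{\frac{3}{2}} \) grows with \( b \), the interval is widest when \( b \) is close to \( n^{\frac{2}{3}} \), which gives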

\[ r - l \le n - (n^\frac{2}{3} - 1)^{\frac{3}{2}} = \frac{n^2-(n^\frac{2}{3}-1)^3}{n+(n^\frac{2}{3} - 1)^{\frac{3}{2}}} < \frac{3n^\frac{4}{3}-2}{2n} < \frac{3}{2}n^\frac{1}{3}. \]
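The final bound can also be seen in one step with the mean value theorem (an alternative check, not the derivation above): for \( f(t) = t^{\frac{3}{2}} \) the derivative \( f'(t) = \frac{3}{2}t^{\frac{1}{2}} \) is increasing, so for some \( \xi \in (n^\frac{2}{3} - 1, n^\frac{2}{3}) \)

\[ n - (n^\frac{2}{3} - 1)^{\frac{3}{2}} = f(n^\frac{2}{3}) - f(n^\frac{2}{3} - 1) = f'(\xi) = \frac{3}{2}\xi^\frac{1}{2} < \frac{3}{2}n^\frac{1}{3}. \]

The same argument gives \( (b+1)^{\frac{3}{2}} - b^{\frac{3}{2}} < \frac{3}{2}(b+1)^{\frac{1}{2}} \le \frac{3}{2}n^\frac{1}{3} \) for any \( b \le n^\frac{2}{3} - 1 \).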

Therefore, we can write \( x = l + k \) with \( 0 \le k \le r - l < \frac{3}{2}n^\frac{1}{3} \). Substituting this into \( x^2 \equiv a \pmod{n} \) turns the original condition into the quadratic congruence \( k^2 + 2lk + l^2 - a \equiv 0 \pmod{n} \), whose root \( k \) is relatively small.
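As a hypothetical sketch (the names `ceil_sqrt`, `interval`, `width_ok` and `shifted_poly` are mine), \( l \) and \( r \) can also be computed in closed form with `math.isqrt` instead of binary search, since \( \left\lfloor \sqrt[3]{x^2} \right\rfloor = b \iff b^3 \le x^2 < (b+1)^3 \). The width bound can then be checked exactly in integers, and the shifted polynomial is just a coefficient list.

```python
from math import isqrt


def ceil_sqrt(m: int) -> int:
    """Smallest integer s >= 0 with s * s >= m."""
    return 0 if m <= 0 else isqrt(m - 1) + 1


def interval(n: int, b: int) -> tuple:
    """(l, r): all x in [1, n - 1] with b^3 <= x^2 < (b + 1)^3; l > r if none."""
    l = max(1, ceil_sqrt(b ** 3))
    r = min(n - 1, ceil_sqrt((b + 1) ** 3) - 1)
    return l, r


def width_ok(n: int, b: int) -> bool:
    """Exact check of r - l < (3/2) n^(1/3), i.e. (2 (r - l))^3 < 27 n."""
    l, r = interval(n, b)
    return l > r or (2 * (r - l)) ** 3 < 27 * n


def shifted_poly(a: int, l: int) -> list:
    """Coefficients of (l + k)^2 - a as a polynomial in k, lowest degree first."""
    return [l * l - a, 2 * l, 1]


n, a, b = 35, 4, 5
l, r = interval(n, b)
f = shifted_poly(a, l)
print(l, r, f)  # 12 14 [140, 24, 1]
```

For \( n = 35, a = 4, b = 5 \) the interval is \( [12, 14] \), and the constant term \( 140 = 4 \cdot 35 \) vanishes modulo \( n \), matching the root \( x = 12 = l \), i.e. \( k = 0 \).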

We now turn to the simple version of Coppersmith's method. The \( 3 \times 3 \) lattice from Lecture 3 finds a root of this congruence provided \( k < d = \frac{1}{6}n^\frac{1}{3} \), so our range is longer than that by a factor of \( \frac{3}{2} \big/ \frac{1}{6} = 9 \). Nevertheless, we can partition \( [l, r] \) into at most 9 intervals \( \{ [l, l + d), [l + d, l + 2d), \cdots, [l + 8d, l + 9d) \} \) and run LLL once per interval.
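Concretely, following the `coppersmith` function in the solution code, the basis rows are the coefficient vectors of \( f(dx) \), \( n \) and \( n \cdot dx \) for \( f(k) = k^2 + 2lk + l^2 - a \):

\[ \begin{pmatrix} l^2 - a & 2l\,d & d^2 \\ n & 0 & 0 \\ 0 & n\,d & 0 \end{pmatrix} \]

Every integer combination \( v \) of these rows, read back as \( h(x) = \sum_j (v_j / d^j)\, x^j \), satisfies \( h(k) \equiv 0 \pmod{n} \) at the true root, and the determinant is \( n^2 d^3 \); LLL then looks for a combination short enough that \( h(k) = 0 \) holds over the integers. The hypothetical sketch below (helper names are mine; it does not run LLL) builds the basis for a toy instance and checks the first two facts.

```python
from itertools import product


def basis(f: list, n: int, d: int) -> list:
    """Rows f(dx), n, n*(dx) as coefficient vectors, lowest degree first."""
    c0, c1, c2 = f
    return [[c0, c1 * d, c2 * d * d], [n, 0, 0], [0, n * d, 0]]


def det3(m: list) -> int:
    (a1, a2, a3), (b1, b2, b3), (c1, c2, c3) = m
    return a1 * (b2 * c3 - b3 * c2) - a2 * (b1 * c3 - b3 * c1) + a3 * (b1 * c2 - b2 * c1)


def row_to_poly(row: list, d: int) -> list:
    """Undo the column scaling (entry j of every lattice vector is divisible by d**j)."""
    return [v // d ** j for j, v in enumerate(row)]


def evaluate(poly: list, k: int) -> int:
    return sum(c * k ** j for j, c in enumerate(poly))


# Toy instance: n = 101, l = 20, true x = 23, so k = 3 and a = 23^2 mod 101 = 24.
n, l, k, d = 101, 20, 3, 4
a = (l + k) ** 2 % n
M = basis([l * l - a, 2 * l, 1], n, d)
print(det3(M) == n * n * d ** 3)  # True

# Every small integer combination of the rows vanishes at k modulo n.
for coeffs in product(range(-2, 3), repeat=3):
    v = [sum(c * row[j] for c, row in zip(coeffs, M)) for j in range(3)]
    assert evaluate(row_to_poly(v, d), k) % n == 0
```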

I used the LLL code from here with a little modification to implement the solution:

```python
import sys
from fractions import Fraction
from math import isqrt
from typing import List, Sequence


class Vector(list):
    """Row vector with exact rational entries."""

    def __init__(self, x):
        super().__init__(map(Fraction, x))

    def sdot(self) -> Fraction:
        # Squared Euclidean norm.
        return self.dot(self)

    def dot(self, rhs: "Vector") -> Fraction:
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return sum(map(lambda x: x[0] * x[1], zip(self, rhs)))

    def proj_coff(self, rhs: "Vector") -> Fraction:
        # Coefficient of the projection of rhs onto self.
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return self.dot(rhs) / self.sdot()

    def proj(self, rhs: "Vector") -> "Vector":
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return self.proj_coff(rhs) * self

    def __sub__(self, rhs: "Vector") -> "Vector":
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return Vector(x - y for x, y in zip(self, rhs))

    def __mul__(self, rhs: Fraction) -> "Vector":
        return Vector(x * rhs for x in self)

    def __rmul__(self, lhs: Fraction) -> "Vector":
        return Vector(x * lhs for x in self)

    def __repr__(self) -> str:
        return "[{}]".format(", ".join(str(x) for x in self))


def gramschmidt(v: Sequence[Vector]) -> Sequence[Vector]:
    u: List[Vector] = []
    for vi in v:
        ui = Vector(vi)
        for uj in u:
            ui = ui - uj.proj(vi)
        if any(ui):
            u.append(ui)
    return u


def reduction(
    basis: Sequence[Sequence[int]], delta: Fraction = Fraction(3, 4)
) -> Sequence[Sequence[int]]:
    """Textbook LLL in exact arithmetic (slow, but fine for a 3x3 basis)."""
    n = len(basis)
    basis = list(map(Vector, basis))
    ortho = gramschmidt(basis)

    def mu(i: int, j: int) -> Fraction:
        return ortho[j].proj_coff(basis[i])

    k = 1
    while k < n:
        # Size reduction of basis[k] against basis[k-1], ..., basis[0].
        for j in range(k - 1, -1, -1):
            mu_kj = mu(k, j)
            if abs(mu_kj) > Fraction(1, 2):
                basis[k] = basis[k] - basis[j] * round(mu_kj)
                ortho = gramschmidt(basis)
        # Lovasz condition: advance, or swap and step back.
        if ortho[k].sdot() >= (delta - mu(k, k - 1) ** 2) * ortho[k - 1].sdot():
            k += 1
        else:
            basis[k], basis[k - 1] = basis[k - 1], basis[k]
            ortho = gramschmidt(basis)
            k = max(k - 1, 1)
    return [list(map(int, b)) for b in basis]


def icube(x):
    # floor(cbrt(x)) by binary search.
    l, r, ans = 0, x, 0
    while l <= r:
        m = (l + r) // 2
        if m * m * m <= x:
            l, ans = m + 1, m
        else:
            r = m - 1
    return ans


input = sys.stdin.readline
N = int(input())
A = int(input())
B = int(input())


def find_left():
    # Smallest x with floor(cbrt(x^2)) >= B, i.e. the left end l.
    l, r, ans = 0, N, 0
    while l <= r:
        m = (l + r) // 2
        if icube(m * m) < B:
            l = m + 1
        else:
            r, ans = m - 1, m
    return ans


L = find_left()


def is_answer(x):
    return x >= 1 and x <= N - 1 and x * x % N == A and icube(x * x) == B


def coppersmith(poly: List, mod: int, d: int):
    # Rows: f(dx), mod, mod * (dx). Return the reduced first row as
    # polynomial coefficients [c0, c1, c2] (lowest degree first).
    n = len(poly)
    pd = [d**i for i in range(n)]
    mat = [[poly[i] * pd[i] for i in range(n)]] + [
        [pd[i] * mod if i == j else 0 for i in range(n)] for j in range(n - 1)
    ]
    mat = reduction(mat)
    return [mat[0][i] // pd[i] for i in range(n)]


if N <= 10**4:
    # Small inputs: brute force (this also keeps d >= 1 below).
    for x in range(N):
        if is_answer(x):
            sys.stdout.write(f"{x}\n")
            sys.exit(0)
else:
    # Exact integer cube root; the float N ** (1 / 3) may round below the true value.
    d = icube(N) // 6
    while True:
        # Search k in [0, d) with x = L + k.
        poly = [L * L - A, 2 * L, 1]
        reduced = coppersmith(poly=poly, mod=N, d=d)
        a, b, c = reduced[2], reduced[1], reduced[0]
        ans = []
        if a != 0:
            delta = b * b - 4 * a * c
            if delta >= 0:
                delta = isqrt(delta)
                ans = [(-b + delta) // (2 * a), (-b - delta) // (2 * a)]
        elif b != 0:
            ans = [-c // b]
        for x in ans:
            if is_answer(L + x):
                sys.stdout.write(f"{L + x}\n")
                sys.exit(0)
        L = L + d
```
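The root-extraction step at the end of the loop only needs the integer roots of \( ak^2 + bk + c = 0 \). A slightly more defensive variant (a hypothetical helper, not part of the original solution) checks that the discriminant is a perfect square and verifies every candidate, so correctness never depends on floor division landing on the right value.

```python
from math import isqrt


def integer_roots(a: int, b: int, c: int) -> list:
    """Sorted integers k with a*k^2 + b*k + c == 0; [] for a constant polynomial."""
    if a == 0:
        if b == 0:
            return []  # constant polynomial: nothing to report
        return [-c // b] if c % b == 0 else []
    disc = b * b - 4 * a * c
    if disc < 0:
        return []
    s = isqrt(disc)
    if s * s != disc:
        return []  # roots are irrational
    cands = {(-b + s) // (2 * a), (-b - s) // (2 * a)}
    return sorted(k for k in cands if a * k * k + b * k + c == 0)


print(integer_roots(1, -5, 6))  # [2, 3]
```

With the variable names used in the loop, `ans = integer_roots(a, b, c)` would replace the discriminant branch; `is_answer` still does the final check.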

Some Optimization

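In fact, even if the polynomial \( h_1(x) \) given by the shortest vector \( b_1 \) does not solve the problem directly, it is still short enough to be useful over the whole range. Writing \( h_1(x) = \sum_j h_{1,j} x^j \), for \( 0 \le k \le 9d \) each term grows by at most a factor of \( 9^2 = 81 \) compared with \( k = d \), so \( |h_1(k)| \le \sum_j |h_{1,j}| k^j \le 81 \sum_j |h_{1,j}| d^j < 81n \), where the last step is the same shortness guarantee (roughly \( \sum_j |h_{1,j}| d^j < n \)) that makes the simple Coppersmith work for \( k < d \). Since also \( h_1(k) \equiv 0 \pmod{n} \), a single LLL call suffices: we try every equation \( h_1(k) = \pm in \) with \( 0 \le i \le 81 \) and check each integer root.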
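A sketch of this enumeration, assuming `h = [h0, h1, h2]` holds the coefficients (lowest degree first) of the reduced polynomial and `d` is the same block size as before. It collects every integer \( k \in [0, 9d] \) with \( h(k) = in \) for some \( |i| \le 81 \); the function name is hypothetical and roots are found with plain `math.isqrt`.

```python
from math import isqrt


def candidates(h: list, n: int, d: int, bound: int = 81) -> list:
    """Integers k in [0, 9d] with h(k) = i * n for some |i| <= bound."""
    c0, c1, c2 = h
    found = set()
    for i in range(-bound, bound + 1):
        c = c0 - i * n  # solve c2 k^2 + c1 k + (c0 - i n) = 0
        if c2 == 0:
            roots = [-c // c1] if c1 != 0 and c % c1 == 0 else []
        else:
            disc = c1 * c1 - 4 * c2 * c
            s = isqrt(disc) if disc >= 0 else -1
            roots = [(-c1 + t) // (2 * c2) for t in (s, -s)] if s * s == disc else []
        # Keep only exact integer roots inside the searched range.
        found.update(k for k in roots if 0 <= k <= 9 * d and c2 * k * k + c1 * k + c == 0)
    return sorted(found)


print(candidates([376, 40, 1], 101, 1))  # [3]
```

Here \( h(k) = k^2 + 40k + 376 \) and \( h(3) = 505 = 5 \cdot 101 \), so \( k = 3 \) is found at \( i = 5 \); every surviving \( k \) still has to pass the full `is_answer` check.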

The only catch is that \( h_1 \) may degenerate to a trivial polynomial: the constant \( h_1(x) = \pm n \) is a lattice vector, but it says nothing about \( k \). In that case we use the second row \( b_2 \) of the reduced basis and its polynomial \( h_2(x) \) instead. Since \( b_1 = (\pm n, 0, 0) \), we can remove \( b_1 \) and the first column from the reduced matrix \( A \); the remaining matrix \( A' \) still satisfies both LLL conditions (size reduction and the Lovász condition), so \( b_2 \) now plays the role of the short first vector. Therefore, we try every equation \( h_2(k) = \pm in \) with \( 0 \le i \le 81 \) to find a solution.
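The degenerate case is easy to detect: for this basis, a lattice vector whose last two entries vanish must be a multiple of \( (n, 0, 0) \), because the \( x^2 \) column only receives contributions from the \( f \) row, and once that row is excluded the \( x \) column only receives contributions from the \( n x \) row. A minimal sketch (hypothetical helper name, assuming `rows` is the LLL-reduced basis as lists of integers and `d` is the column scaling) that generalizes the two-row check to "first non-constant row":

```python
def first_useful_poly(rows: list, d: int) -> list:
    """Polynomial of the first reduced row that is not a constant (+-n, 0, 0)."""
    for row in rows:
        if any(row[1:]):  # some non-constant coefficient survives
            # Undo the column scaling: entry j is divisible by d**j.
            return [v // d ** j for j, v in enumerate(row)]
    raise ValueError("no non-trivial row")


print(first_useful_poly([[35, 0, 0], [3, 8, 16]], 2))  # [3, 4, 4]
```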

```python
def coppersmith(poly: List, mod: int, d: int):
    n = len(poly)
    pd = [d**i for i in range(n)]
    mat = [[poly[i] * pd[i] for i in range(n)]] + [
        [pd[i] * mod if i == j else 0 for i in range(n)] for j in range(n - 1)
    ]
    mat = reduction(mat)
    # The first row may be the trivial constant (+-mod, 0, 0); use the second.
    if mat[0][1] == 0 and mat[0][2] == 0:
        return [mat[1][i] // pd[i] for i in range(n)]
    return [mat[0][i] // pd[i] for i in range(n)]


if N <= 10**4:
    # Small inputs: brute force (this also keeps d >= 1 below).
    for x in range(N):
        if is_answer(x):
            sys.stdout.write(f"{x}\n")
            sys.exit(0)
else:
    d = icube(N) // 6
    # One LLL call for the whole range k in [0, 9d].
    poly = [L * L - A, 2 * L, 1]
    reduced = coppersmith(poly=poly, mod=N, d=d)
    a, b, c = reduced[2], reduced[1], reduced[0] - 81 * N
    # Try h(k) = i * N for i = 81, 80, ..., -81 (c grows by N each round).
    while True:
        ans = []
        if a != 0:
            delta = b * b - 4 * a * c
            if delta >= 0:
                delta = isqrt(delta)
                ans = [(-b + delta) // (2 * a), (-b - delta) // (2 * a)]
        elif b != 0:
            ans = [-c // b]
        for x in ans:
            if is_answer(L + x):
                sys.stdout.write(f"{L + x}\n")
                sys.exit(0)
        c = c + N
```