ni...@lysator.liu.se (Niels Möller) writes: > So back to the drawing board.
I had to redo the splitting a bit in sqrt_nm1, and write some special-case code for n = 3. But now I have code that survives some testing. I'm attaching my current algorithm description and code. The code currently doesn't really exploit cancellation. Also, it computes divisions like floor (B^{k-1} E / 2H) by zero-padding, and calling mpn_tdiv_qr. To make this fast, we need some variant of divappr_q which doesn't require any of the uninteresting low limbs. Or alternatively, resurrect the notion of fraction limbs. Compare to the bdiv_q functions, which compute N / D mod B^n, where both inputs and the output are n limbs. A divappr_q function ought to compute an approximation of N B^n / D, but possibly with a little more flexibility in the input and output sizes. The most flexible way is perhaps to mimic the "fraction limbs"-interface. Say we have inputs N = {np, nn}, D = {dp, dn}, and a scaling factor k, and compute an approximation of Q = floor (N B^k / D) of qn = nn - dn + 1 + k limbs. Or we could tie things together a bit more, by requiring that dn = nn = qn + 1 or so. For the divisions in the sqrt algorithms, I'm not sure of exactly how the sizes of the numerator E, the denominator H, and the quotient, relate, but they ought to all be pretty close to k. Regards, /Niels
sqrt_nm1 (n, A): Input: A = {a_{n-1}, ..., a_0}, n >= 2 Output: X = {x_{n-2}, ..., x_0} \approx sqrt(B^{n-2} A) if (n == 2) x_0 <-- floor (sqrt (A)) return X = {x_0} if (n == 3) // Needs normalization c <-- count_leading_zeros (a_2) & ~1 A' <-- 2^c A x_1 = floor (sqrt ({a'_2, a'_1})) E = A' - B x_1^2 return X = (B x_1 + floor (E / 2 x_1)) 2^{-c/2} k <-- floor (n/2) H <-- sqrt_nm1 (n - (k-1), {a_{n-1},...,a_{k-1}}) // n - (k-1) limbs if (n odd) // n = 2k+1 // We have a (k+1)-limb H, and produce k-1 low limbs. E <-- B A - H^2 X <-- B^{k-1} H + floor (B^{k-1} E / 2H) // XXX else // n = 2k // We have a k-limb H, and produce k-1 low limbs E <-- A - H^2 X <-- B^{k-1} H + floor (B^{k-1} E / 2H) return X sqrt_n (n, A): Input: A = {a_{n-1}, ..., a_0}, n >= 1 Output: X = {x_{n-1}, ..., x_0} \approx sqrt(B^{n-1} A) if (n == 1) x_0 <-- floor (sqrt (a_0)) return X = {x_0} if (n == 2) h <-- floor (sqrt(A)) E <-- A - h*h X <-- sqrt(B) h + floor (sqrt(B) E / 2h) k <-- floor (n/2) H <-- sqrt_n(k+1, {a_{n-1},...,a_{n-1-k}}) // k+1 limbs if (n odd) // n = 2k+1 // We have (k+1)-limb H, and produce k low limbs E <-- A - H^2 X <-- B^k H + floor (B^k E / 2 H) else // n = 2k // We have (k+1)-limb H, and produce k-1 low limbs E <-- B A - H^2 X <-- B^{k-1} H + floor (B^{k-1} E / 2H) return X Size of remainder: Assume X = floor(sqrt(A)) = sqrt(A) - e, where A is n limbs. Define the remainder R = A - X^2. Then R = A - (sqrt(A) - e)^2 = 2 e sqrt(A) - e^2 = e (2 sqrt(A) - e) < 2 sqrt(A) If n is even, n = 2k, then A - X^2 < B^{k+1} If n is odd, n = 2k+1, then A - X^2 < B^{k+1} So in all cases, R < B^{floor(n/2) + 1}, and X < B^{ceil(n/2)}, and with almost half a limb margin. Correctness test: 0 > A - (X+1)^2 = A - X^2 - 2X - 1 = R - 2X - 1 R < 2X + 1 R <= 2X Small cases. Consider sqrtrem, size n. 
n = 1: sqrt_1 n = 2: sqrt_2 n = 3: sqrt_n(2) -> sqrt_2 + special update n = 4: sqrt_nm1(3) -> sqrt_nm1(2) -> sqrt_2 n = 5: sqrt_n (3) -> sqrt_n(2) -> sqrt_2 + special update n = 6: sqrt_nm1 (4) -> sqrt_nm1 (3) -> sqrt_nm1(2) -> sqrt_2
#include <assert.h> #include <stdio.h> #include <stdlib.h> #include <gmp.h> #if GMP_NUMB_BITS != 64 #error Unsupported limb size #endif #if !defined (__amd64__) #error Unsupported arch #endif /* From longlong.h */ typedef unsigned long int UDItype; #define add_ssaaaa(sh, sl, ah, al, bh, bl) \ __asm__ ("addq %5,%q1\n\tadcq %3,%q0" \ : "=r" (sh), "=&r" (sl) \ : "0" ((UDItype)(ah)), "rme" ((UDItype)(bh)), \ "%1" ((UDItype)(al)), "rme" ((UDItype)(bl))) #define sub_ddmmss(sh, sl, ah, al, bh, bl) \ __asm__ ("subq %5,%q1\n\tsbbq %3,%q0" \ : "=r" (sh), "=&r" (sl) \ : "0" ((UDItype)(ah)), "rme" ((UDItype)(bh)), \ "1" ((UDItype)(al)), "rme" ((UDItype)(bl))) #define umul_ppmm(w1, w0, u, v) \ __asm__ ("mulq %3" \ : "=a" (w0), "=d" (w1) \ : "%0" ((UDItype)(u)), "rm" ((UDItype)(v))) #define udiv_qrnnd(q, r, n1, n0, dx) /* d renamed to dx avoiding "=d" */\ __asm__ ("divq %4" /* stringification in K&R C */ \ : "=a" (q), "=d" (r) \ : "0" ((UDItype)(n0)), "1" ((UDItype)(n1)), "rm" ((UDItype)(dx))) #define assert_no_carry(x) do { \ mp_limb_t __assert_cy = (x); \ assert (!__assert_cy); \ } while (0) #define assert_carry(x) do { \ mp_limb_t __assert_cy = (x); \ assert (__assert_cy); \ } while (0) static int mpn_zero_p (const mp_limb_t *xp, mp_size_t n) { mp_size_t i; for (i = 0; i < n; i++) if (xp[i]) return 0; return 1; } static int verbose = 0; static mp_limb_t * xalloc_limbs (mp_size_t n) { mp_limb_t *p = malloc (n * sizeof (*p)); if (!p) { fprintf (stderr, "Virtual memory exhasuted!\n"); abort (); } return p; } /* Borrowed from factor.c */ static mp_limb_t sqrt_1 (mp_limb_t a) { mp_limb_t x; unsigned c; assert (a > 0); c = __builtin_clzl (a); x = (mp_limb_t) 1 << ((65 - c)/2); for (;;) { mp_limb_t y = (x + a/x) / 2; if (y >= x) return x; x = y; } } static mp_limb_t sqrt_2 (mp_limb_t ah, mp_limb_t al) { mp_limb_t x; unsigned c; assert (ah > 0); c = __builtin_clzl (ah) & ~1; if (c) { x = sqrt_1 ( (ah << c) + (al >> (64 - c))) + 1; x <<= (64 - c) / 2; } /* Need to handle these cases 
separately, to exclude x == ah in the loop */ else if (ah == GMP_NUMB_MAX) return GMP_NUMB_MAX; else if (ah == GMP_NUMB_MAX - 1) return GMP_NUMB_MAX - (al == 0); else x = (sqrt_1 (ah) << 32) | 0xffffffffUL; /* Do we need more than one iteration? */ for (;;) { mp_limb_t q, r, y; udiv_qrnnd (q, r, ah, al, x); y = x + q; if (y < x) y = (y / 2) | ((mp_limb_t) 1 << 63); else y = y / 2; if (y >= x) return x; x = y; } } /* Computes floor(sqrt(B^{n-2} A), an (n-1)-limb number. */ static void sqrt_nm1 (mp_limb_t *xp, const mp_limb_t *ap, mp_size_t n) { assert (n >= 2); assert (ap[n-1] > 0); if (n == 2) xp[0] = sqrt_2 (ap[1], ap[0]); else if (n == 3) { unsigned c; mp_limb_t ah, al, h, eh, el; mp_limb_t t[3]; ah = ap[2]; al = ap[1]; c = __builtin_clzl (ah) & ~1; if (c) { ah = (ah << c) | (al >> (64 - c)); al = (al << c) | (ap[0] >> (64 - c)); } h = sqrt_2 (ah, al); umul_ppmm (eh, el, h, h); sub_ddmmss (eh, el, ah, al, eh, el); t[0] = ap[0] << c; t[1] = el; t[2] = eh; mpn_divmod_1 (t, t, 3, h); c /= 2; mpn_rshift (t, t, 3, c + 1); assert (t[2] == 0); assert (t[1] <= 1); if (c) add_ssaaaa (xp[1], xp[0], t[1], t[0], h >> c, h << (64 - c)); else if (t[1] && h == GMP_NUMB_MAX) xp[0] = xp[1] = GMP_NUMB_MAX; else { xp[0] = t[0]; xp[1] = h + t[1]; assert (xp[1] >= h); } } else { mp_limb_t *ep; mp_limb_t *qp; mp_size_t k = n/2; sqrt_nm1 (xp + k-1, ap + k-1, n-k+1); if (n & 1) { ep = xalloc_limbs (2*n + k + 1); /* n + k + 1 */ qp = ep + n + k + 1; /* n */ /* E = A - B H^2 */ mpn_zero (ep, k-1); ep[k-1] = ap[0]; mpn_sqr (ep + k, xp + k-1, k + 1); /* FIXME: Handle unlikely underflow. */ assert (ep[n] == 0); assert_no_carry (mpn_sub_n (ep + k, ap+1, ep+k, n-1)); /* FIXME: Figure out how much cancellation to expect. 
*/ mpn_tdiv_qr (qp, ep, 0, ep, n + k - 1, xp + k-1, k+1); qp[n-1] = 0; } else { ep = xalloc_limbs (2*n + k - 1); /* 3k-1 = n + k - 1 */ qp = ep + n + k - 1; /* E = A - H^2 */ mpn_zero (ep, k-1); mpn_sqr (ep + k-1, xp + k-1, k); if (mpn_sub_n (ep + k-1, ap, ep+k-1, n)) { /* Unlikely overflow, high part of root too large */ assert_no_carry (mpn_sub_1 (xp + k-1, xp + k-1, k, 1)); /* FIXME: Update, without re-squaring. */ mpn_sqr (ep + k-1, xp + k-1, k); assert_no_carry (mpn_sub_n (ep + k-1, ap, ep+k-1, n)); } /* FIXME: Figure out how much cancellation to expect. */ mpn_tdiv_qr (qp, ep, 0, ep, n + k - 1, xp + k-1, k); } mpn_rshift (qp, qp, n, 1); assert (qp[n-1] == 0); mpn_copyi (xp, qp, n-k-1); assert_no_carry (mpn_add_n (xp + n-k-1, xp + n-k-1, qp + n-k-1, k)); free (ep); } } /* Computes floor(sqrt(B^{n-1} A), an n-limb bumber. */ static void sqrt_n (mp_limb_t *xp, const mp_limb_t *ap, mp_size_t n) { assert (n >= 1); assert (ap[n-1] > 0); if (n == 1) xp[0] = sqrt_1 (ap[0]); else if (n == 2) { mp_limb_t h = sqrt_2 (ap[1], ap[0]); mp_limb_t xh, xl, q, r; mp_limb_t eh, el; umul_ppmm (eh, el, h, h); sub_ddmmss (eh, el, ap[1], ap[0], eh, el); assert (eh <= 1); /* A single Newton step may produce a root which is one too large */ udiv_qrnnd (q, r, (el >> 33) | (eh << 31), el << 31, h); xh = h >> 32; xl = h << 32; xl += q; xh += xl < q; xp[0] = xl; xp[1] = xh; } else { mp_limb_t *ep; mp_limb_t *qp; mp_size_t k = n/2; if (n & 1) { sqrt_n (xp + k, ap + k, k+1); ep = xalloc_limbs (2*n + 1); /* n + 1 */ qp = ep + n + 1; /* n */ mpn_zero (ep, k); mpn_sqr (ep + k, xp + k, k+1); assert (ep[n+k] <= 1); if (ep[n+k] > 0 || mpn_sub_n (ep + k, ap, ep + k, n)) { /* Unlikely overflow, high part of root too large */ assert_no_carry (mpn_sub_1 (xp + k, xp + k, k+1, 1)); /* FIXME: Update, without re-squaring. 
*/ mpn_sqr (ep + k, xp + k, k+1); assert (ep[k+n] == 0); assert_no_carry (mpn_sub_n (ep + k, ap, ep + k, n)); } mpn_tdiv_qr (qp, ep, 0, ep, n+k, xp + k, k+1); mpn_rshift (qp, qp, n, 1); mpn_copyi (xp, qp, k); assert_no_carry (mpn_add_n (xp + k, xp + k, qp + k, n-k)); } else { sqrt_n (xp + k-1, ap + k-1, k+1); ep = xalloc_limbs (2*n + 1); /* n + 1 */ qp = ep + n + 1; /* n */ mpn_zero (ep, k-1); mpn_sqr (ep + k-1, xp + k-1, k+1); /* FIXME: Handle unlikely underflow. */ assert_no_carry (mpn_sub_n (ep + k, ap, ep + k, n)); if (ep[k-1]) { ep[k-1] = - ep[k-1]; assert_no_carry (mpn_sub_1 (ep + k, ep + k, n, 1)); } mpn_tdiv_qr (qp, ep, 0, ep, n + k, xp + k - 1, k + 1); mpn_rshift (qp, qp, n, 1); mpn_copyi (xp, qp, k - 1); assert_no_carry (mpn_add_n (xp + k - 1, xp + k - 1, qp + k, n-k+1)); } free (ep); } } /* Computes X = floor (sqrt ({ap, n})), ceil(n/2) limbs, and remainder R = A - X^2 */ static mp_size_t sqrtrem (mp_limb_t *xp, mp_limb_t *rp, const mp_limb_t *ap, mp_size_t n) { assert (n > 0); if (n == 1) { mp_limb_t x = sqrt_1 (ap[0]); mp_limb_t r = ap[0] - x*x; xp[0] = x; rp[0] = r; return (r > 0); } else if (n == 2) { mp_limb_t x = sqrt_2 (ap[1], ap[0]); mp_limb_t rh, rl; xp[0] = x; umul_ppmm (rh, rl, x, x); sub_ddmmss (rh, rl, ap[1], ap[0], rh, rl); rp[0] = rl; rp[1] = rh; if (!rh) return (rl > 0); return 2; } else { mp_size_t k = n/2; mp_size_t rn = k + 1; mp_size_t xn = (n+1)/2; mp_limb_t *tp; mp_limb_t cy; mp_limb_t adjust_m; if (n & 1) /* Compute sqrt based on the (n+1)/2 high limbs */ sqrt_n (xp, ap + k, k+1); else /* Compute sqrt based on the n/2+1 high limbs */ sqrt_nm1 (xp, ap + k-1, k+1); tp = xalloc_limbs (2*xn); mpn_sqr (tp, xp, xn); adjust_m = mpn_sub_n (tp, ap, tp, n); if (n & 1) { /* The extra limb is usually zero */ assert (tp[n] <= 1); adjust_m |= tp[n]; } /* Cancellation */ assert (adjust_m || mpn_zero_p (tp + rn, n - rn)); mpn_copyi (rp, tp, rn); /* T = 2X */ cy = mpn_lshift (tp, xp, xn, 1); tp[xn] = cy; /* Can't have carry if n is odd */ assert (! 
((n&1) && cy)); if (adjust_m) { /* x is one too large */ if (verbose) fprintf (stderr, "-"); /* A - (X-1)^2 = A - X^2 + 2X - 1 T = 2X - 1 */ assert_no_carry (mpn_sub_1 (tp, tp, rn, 1)); assert_carry (mpn_add_n (rp, rp, tp, rn)); assert (mpn_cmp (rp, tp, rn) < 0); assert_no_carry (mpn_sub_1 (xp, xp, xn, 1)); } else { /* T = 2X + 1 */ tp[0] |= 1; while (mpn_cmp (rp, tp, rn) >= 0) { if (verbose) { static int adjust = 0; fprintf (stderr, "+"); adjust++; if (adjust > 100) abort (); } assert_no_carry (mpn_sub_n (rp, rp, tp, rn)); assert_no_carry (mpn_add_1 (xp, xp, xn, 1)); assert_no_carry (mpn_add_1 (tp, tp, rn, 2)); } } free (tp); while (rn > 0 && !rp[rn-1]) rn--; return rn; } } static int check_sqrt (const mpz_t a, const mpz_t x, const mpz_t r) { mpz_t t; mpz_init (t); mpz_add_ui (t, x, 1); mpz_mul (t, t, t); if (mpz_cmp (t, a) < 0) { fail: mpz_clear (t); return 0; } mpz_mul (t, x, x); if (mpz_cmp (t, a) > 0) goto fail; mpz_sub (t, a, t); if (mpz_cmp (t, r) != 0) goto fail; mpz_clear (t); return 1; } #define MAX_SIZE 6 int main (int argc, char **argv) { gmp_randstate_t rands; mp_limb_t x[(MAX_SIZE+1)/2]; mp_limb_t r[MAX_SIZE/2+1]; mpz_t a; unsigned c; if (argc > 1) verbose = 1; mpz_init (a); gmp_randinit_default (rands); for (c = 0; c < 10000; c++) { unsigned bits = 1 + gmp_urandomm_ui (rands, GMP_NUMB_BITS * MAX_SIZE); mpz_t t1, t2; mp_size_t an, xn, rn; mpz_rrandomb (a, rands, bits); an = mpz_size (a); if (verbose) gmp_fprintf (stderr, "%u, size %u, A = %Zd", c, (unsigned) an, a); xn = (an + 1)/2; rn = sqrtrem (x, r, mpz_limbs_read (a), an); assert (rn == 0 || r[rn-1]); if (verbose) fprintf (stderr, "\n"); if (!check_sqrt (a, mpz_roinit_n (t1, x, xn), mpz_roinit_n (t2, r, rn))) { gmp_fprintf (stderr, "Test %u failed, size %u:\n" " a: %Zd\n" " x: %Nd\n" " r: %Nd\n", c, (unsigned) an, a, x, xn, r, rn); abort (); } } mpz_clear (a); return EXIT_SUCCESS; }
-- Niels Möller. PGP-encrypted email is preferred. Keyid C0B98E26. Internet email is subject to wholesale government surveillance.
_______________________________________________ gmp-devel mailing list gmp-devel@gmplib.org https://gmplib.org/mailman/listinfo/gmp-devel