ni...@lysator.liu.se (Niels Möller) writes:

> So back to the drawing board.

I had to redo the splitting a bit in sqrt_nm1, and write some special
case code for n = 3. But now I have code that survives some testing.

I'm attaching my current algorithm description and code. The code
currently doesn't really exploit cancellation. Also, it computes
divisions like 

  floor (B^{k-1} E / 2H)

by zero-padding, and calling mpn_tdiv_qr. 

To make this fast, we need some variant of divappr_q which doesn't require
any of the uninteresting low limbs. Or alternatively, resurrect the
notion of fraction limbs.

Compare to the bdiv_q functions, which compute N / D mod B^n, where both
inputs and the output are n limbs. A divappr_q function ought to compute
an approximation of N B^n / D, but possibly with a little more
flexibility in the input and output sizes.

The most flexible way is perhaps to mimic the "fraction
limbs"-interface. Say we have inputs N = {np, nn}, D = {dp, dn}, and a
scaling factor k, and compute an approximation of

   Q = floor (N B^k / D)

of qn = nn + dn + 1 + k limbs. Or we could tie things together a bit
more, by requiring that dn = nn = qn + 1 or so.

For the divisions in the sqrt algorithms, I'm not sure of exactly how
the sizes of the numerator E, the denominator H, and the quotient,
relate, but they ought to all be pretty close to k.

Regards,
/Niels

sqrt_nm1 (n, A):

  Input: A = {a_{n-1}, ..., a_0}, n >= 2
  Output: X = {x_{n-2}, ..., x_0} \approx sqrt(B^{n-2} A)

  if (n == 2)
    x_0 <-- floor (sqrt (A))
    return X = {x_0}

  if (n == 3) // Needs normalization
    c <-- count_leading_zeros (a_2) & ~1
    A' <-- 2^c A
    x_1 = floor (sqrt ({a'_2, a'_1}))
    E = A' - B x_1^2
    return X = (B x_1 + floor (E / 2 x_1)) 2^{-c/2}

  k <-- floor (n/2)
  H <-- sqrt_nm1 (n - (k-1), {a_{n-1},...,a_{k-1}}) // n - (k-1) limbs

  if (n odd)    // n = 2k+1
    // We have a (k+1)-limb H, and produce k-1 low limbs.
    E <-- B A - H^2
    X <-- B^{k-1} H + floor (B^{k-1} E / 2H) // XXX

  else          // n = 2k
    // We have a k-limb H, and produce k-1 low limbs
    E <-- A - H^2
    X <-- B^{k-1} H + floor (B^{k-1} E / 2H)

  return X

sqrt_n (n, A):

  Input: A = {a_{n-1}, ..., a_0}, n >= 1
  Output: X = {x_{n-1}, ..., x_0} \approx sqrt(B^{n-1} A)

  if (n == 1)
    x_0 <-- floor (sqrt (a_0))
    return X = {x_0}

  if (n == 2)
    h <-- floor (sqrt(A))
    E <-- A - h*h
    X <-- sqrt(B) h + floor (sqrt(B) E / 2h)
    return X

  k <-- floor (n/2)
  H <-- sqrt_n(k+1, {a_{n-1},...,a_{n-1-k}}) // k+1 limbs

  if (n odd)    // n = 2k+1
    // We have (k+1)-limb H, and produce k low limbs
    E <-- A - H^2
    X <-- B^k H + floor (B^k E / 2 H)

  else          // n = 2k
    // We have (k+1)-limb H, and produce k-1 low limbs
    E <-- B A - H^2
    X <-- B^{k-1} H + floor (B^{k-1} E / 2H)

  return X

Size of remainder:

  Assume X = floor(sqrt(A)) = sqrt(A) - e, where A is n limbs. Define
  the remainder R = A - X^2.

  Then R = A - (sqrt(A) - e)^2
         = 2 e sqrt(A) - e^2
         = e (2 sqrt(A) - e) < 2 sqrt(A)

  If n is even, n = 2k, then sqrt(A) < B^k, so
  A - X^2 < 2 B^k < B^{k+1}

  If n is odd, n = 2k+1, then sqrt(A) < B^{k+1/2}, so
  A - X^2 < 2 B^{k+1/2} < B^{k+1}

  So in all cases, R < B^{floor(n/2) + 1}, and X < B^{ceil(n/2)}, and
  with almost half a limb margin.

Correctness test:

  0 > A - (X+1)^2 = A - X^2 - 2X - 1 = R - 2X - 1

  R < 2X + 1

  R <= 2X

Small cases. Consider sqrtrem, size n.

  n = 1: sqrt_1
  n = 2: sqrt_2
  n = 3: sqrt_n(2)
         -> sqrt_2 + special update
  n = 4: sqrt_nm1(3)
         -> sqrt_nm1(2) -> sqrt_2
  n = 5: sqrt_n (3)
         -> sqrt_n(2)
         -> sqrt_2 + special update
  n = 6: sqrt_nm1 (4)
         -> sqrt_nm1 (3)
         -> sqrt_nm1(2) -> sqrt_2
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include <gmp.h>

/* The code below hard-codes 64-bit limbs (shift counts such as 32, 63
   and 64 - c, and __builtin_clzl on a limb), so refuse to build with
   any other limb size. */
#if GMP_NUMB_BITS != 64
#error Unsupported limb size
#endif

/* The double-limb arithmetic macros below are x86-64 inline assembly. */
#if !defined (__amd64__)
#error Unsupported arch
#endif

/* From longlong.h */
/* Operand type used by the inline assembly below. */
typedef unsigned long int UDItype;

/* Two-limb addition: (sh, sl) = (ah, al) + (bh, bl).  The low limbs are
   added with addq and the high limbs with adcq, so the carry propagates
   from low to high; a carry out of the high limb is not reported. */
#define add_ssaaaa(sh, sl, ah, al, bh, bl) \
  __asm__ ("addq %5,%q1\n\tadcq %3,%q0"					\
	   : "=r" (sh), "=&r" (sl)					\
	   : "0"  ((UDItype)(ah)), "rme" ((UDItype)(bh)),		\
	     "%1" ((UDItype)(al)), "rme" ((UDItype)(bl)))
/* Two-limb subtraction: (sh, sl) = (ah, al) - (bh, bl), using subq on
   the low limbs and sbbq on the high limbs; a borrow out of the high
   limb is not reported. */
#define sub_ddmmss(sh, sl, ah, al, bh, bl) \
  __asm__ ("subq %5,%q1\n\tsbbq %3,%q0"					\
	   : "=r" (sh), "=&r" (sl)					\
	   : "0" ((UDItype)(ah)), "rme" ((UDItype)(bh)),		\
	     "1" ((UDItype)(al)), "rme" ((UDItype)(bl)))
/* Full product of two limbs: (w1, w0) = u * v, with w1 the high limb
   (mulq leaves the low half in rax and the high half in rdx). */
#define umul_ppmm(w1, w0, u, v) \
  __asm__ ("mulq %3"							\
	   : "=a" (w0), "=d" (w1)					\
	   : "%0" ((UDItype)(u)), "rm" ((UDItype)(v)))
/* Two-limb by one-limb division: q = floor ((n1, n0) / dx) and
   r = (n1, n0) mod dx.  divq traps when the quotient does not fit in
   one limb, so callers must ensure n1 < dx. */
#define udiv_qrnnd(q, r, n1, n0, dx) /* d renamed to dx avoiding "=d" */\
  __asm__ ("divq %4"		     /* stringification in K&R C */	\
	   : "=a" (q), "=d" (r)						\
	   : "0" ((UDItype)(n0)), "1" ((UDItype)(n1)), "rm" ((UDItype)(dx)))

/* Evaluates X (typically an mpn call returning a carry or borrow limb)
   and asserts that the result is zero.  X is stored in a local first,
   so it is evaluated exactly once and independently of the assert
   expansion. */
#define assert_no_carry(x) do {			\
    mp_limb_t __assert_cy = (x);		\
    assert (!__assert_cy);			\
  } while (0)

/* Evaluates X exactly once and asserts that the result is non-zero,
   i.e. that the expected carry or borrow did occur. */
#define assert_carry(x) do {			\
    mp_limb_t __assert_cy = (x);		\
    assert (__assert_cy);			\
  } while (0)

/* Returns 1 if all N limbs at XP are zero (trivially so for N <= 0),
   and 0 as soon as a non-zero limb is found.
   NOTE(review): newer GMP releases may declare their own mpn_zero_p in
   gmp.h; confirm there is no clash with the installed header. */
static int
mpn_zero_p (const mp_limb_t *xp, mp_size_t n)
{
  /* Scan from the most significant limb downwards; the order does not
     affect the result. */
  while (n-- > 0)
    {
      if (xp[n] != 0)
	return 0;
    }
  return 1;
}

/* Non-zero enables progress output on stderr; main sets it when any
   command-line argument is given. */
static int verbose = 0;

/* Allocates room for N limbs and returns a pointer to it.  The caller
   owns the storage and releases it with free.  Never returns NULL: the
   program is aborted if the request cannot be satisfied. */
static mp_limb_t *
xalloc_limbs (mp_size_t n)
{
  mp_limb_t *p;
  size_t count;

  assert (n >= 0);

  /* Always request at least one limb, so that a NULL return from malloc
     can only mean failure (malloc (0) is allowed to return NULL). */
  count = n > 0 ? (size_t) n : 1;

  /* Reject requests whose byte size would wrap around in size_t. */
  if (count > (size_t) -1 / sizeof (*p))
    p = NULL;
  else
    p = malloc (count * sizeof (*p));

  if (!p)
    {
      fprintf (stderr, "Virtual memory exhausted!\n");
      abort ();
    }
  return p;
}

/* Borrowed from factor.c */
/* Returns floor (sqrt (a)) for a single non-zero limb.  Starts from a
   power of two that is not below the root and applies the integer
   Newton step x <- (x + a/x) / 2 for as long as it decreases x. */
static mp_limb_t
sqrt_1 (mp_limb_t a)
{
  mp_limb_t cur, next;
  unsigned lz;

  assert (a > 0);

  lz = __builtin_clzl (a);
  /* a < 2^(64 - lz), hence sqrt (a) < 2^((65 - lz) / 2). */
  cur = (mp_limb_t) 1 << ((65 - lz) / 2);

  while ((next = (cur + a / cur) / 2) < cur)
    cur = next;

  return cur;
}

/* Returns floor (sqrt (A)) for the two-limb number A = {ah, al} with
   ah > 0.  An initial estimate that is not below the root is derived
   from sqrt_1 on the (normalized) high limb; Newton steps
   x <- (x + A/x) / 2 are then applied until x stops decreasing. */
static mp_limb_t
sqrt_2 (mp_limb_t ah, mp_limb_t al)
{
  mp_limb_t est;
  unsigned shift;

  assert (ah > 0);

  /* Even normalization count, so that the root shifts by shift/2. */
  shift = __builtin_clzl (ah) & ~1;

  if (shift == 0)
    {
      /* Need to handle these cases separately, to exclude est == ah in
	 the loop (udiv_qrnnd requires ah < est). */
      if (ah == GMP_NUMB_MAX)
	return GMP_NUMB_MAX;
      if (ah == GMP_NUMB_MAX - 1)
	return al ? GMP_NUMB_MAX : GMP_NUMB_MAX - 1;

      est = (sqrt_1 (ah) << 32) | 0xffffffffUL;
    }
  else
    {
      /* Root of the normalized high limb, rounded up, then scaled back
	 to the magnitude of sqrt (A). */
      mp_limb_t top = (ah << shift) + (al >> (64 - shift));
      est = sqrt_1 (top) + 1;
      est <<= (64 - shift) / 2;
    }

  /* Do we need more than one iteration? */
  for (;;)
    {
      mp_limb_t q, r, next, carry;

      udiv_qrnnd (q, r, ah, al, est);
      next = est + q;
      /* est + q may need 65 bits; put the carry back as the top bit of
	 the halved sum. */
      carry = next < est;
      next = (next / 2) | (carry << 63);

      if (next >= est)
	return est;

      est = next;
    }
}

/* Computes an (n-1)-limb approximation of floor(sqrt(B^{n-2} A)), where
   A = {ap, n} and B = 2^64.

   xp  receives the n-1 result limbs.
   ap  the n input limbs; the top limb must be non-zero.
   n   input size, n >= 2.

   n == 2 and n == 3 are handled directly with sqrt_2; larger sizes
   recurse on the high limbs to get H and then add the correction
   floor (B^{k-1} E / 2H), computed with mpn_tdiv_qr on a zero-padded E.

   NOTE(review): the odd-n recursive branch (n >= 5) does not appear to
   be reached from the test driver with MAX_SIZE 6, and it differs from
   the algorithm description earlier in this message; see the notes in
   that branch. */
static void
sqrt_nm1 (mp_limb_t *xp, const mp_limb_t *ap, mp_size_t n)
{
  assert (n >= 2);
  assert (ap[n-1] > 0);
  if (n == 2)
    xp[0] = sqrt_2 (ap[1], ap[0]);
  else if (n == 3)
    {
      unsigned c;
      mp_limb_t ah, al, h, eh, el;
      mp_limb_t t[3];
      ah = ap[2];
      al = ap[1];
      /* Normalize by an even bit count, so the root shifts by c/2. */
      c = __builtin_clzl (ah) & ~1;
      if (c)
	{
	  /* Guarded by c != 0, so the right shifts are by less than 64. */
	  ah = (ah << c) | (al >> (64 - c));
	  al = (al << c) | (ap[0] >> (64 - c));
	}
      /* h = root of the two high normalized limbs;
	 {eh, el} = {ah, al} - h^2. */
      h = sqrt_2 (ah, al);
      umul_ppmm (eh, el, h, h);
      sub_ddmmss (eh, el, ah, al, eh, el);
      /* t = E = A' - B h^2 as three limbs, where A' = 2^c A. */
      t[0] = ap[0] << c;
      t[1] = el;
      t[2] = eh;
      /* t <- floor (E / h), quotient stored in place. */
      mpn_divmod_1 (t, t, 3, h);
      c /= 2;
      /* One bit for the factor 2 in E / 2h, and c bits (c is now half
	 the normalization count) to undo the normalization. */
      mpn_rshift (t, t, 3, c + 1);
      assert (t[2] == 0);
      assert (t[1] <= 1);
      if (c)
	/* X = (B h) 2^{-c} + t, with B h 2^{-c} written as two limbs. */
	add_ssaaaa (xp[1], xp[0], t[1], t[0], h >> c, h << (64 - c));
      else if (t[1] && h == GMP_NUMB_MAX)
	/* h + t[1] would overflow the high limb; saturate instead. */
	xp[0] = xp[1] = GMP_NUMB_MAX;
      else
	{
	  xp[0] = t[0];
	  xp[1] = h + t[1];
	  assert (xp[1] >= h);
	}
    }
  else
    {
      mp_limb_t *ep;
      mp_limb_t *qp;
      mp_size_t k = n/2;

      /* H: root of the n-k+1 high limbs of A, stored as n-k limbs at
	 xp[k-1 .. n-2]. */
      sqrt_nm1 (xp + k-1, ap + k-1, n-k+1);

      if (n & 1)
	{
	  /* NOTE(review): the description earlier in this message uses
	     E <-- B A - H^2 for odd n, while the comment and code below
	     form A - B H^2.  Confirm which is intended before relying on
	     this branch. */
	  ep = xalloc_limbs (2*n + k + 1);  /* n + k + 1 */
	  qp = ep + n + k + 1; /* n */

	  /* E = A - B H^2 */
	  mpn_zero (ep, k-1);
	  ep[k-1] = ap[0];
	  mpn_sqr (ep + k, xp + k-1, k + 1);
	  /* FIXME: Handle unlikely underflow. */
	  assert (ep[n] == 0);
	  assert_no_carry (mpn_sub_n (ep + k, ap+1, ep+k, n-1));
	  /* FIXME: Figure out how much cancellation to expect. */
	  /* The quotient has (n+k-1) - (k+1) + 1 = n-1 limbs; the top
	     limb of qp is cleared for the n-limb shift below. */
	  mpn_tdiv_qr (qp, ep, 0, ep, n + k - 1, xp + k-1, k+1);
	  qp[n-1] = 0;
	}
      else
	{
	  ep = xalloc_limbs (2*n + k - 1); /* 3k-1 = n + k - 1 */
	  qp = ep + n + k - 1;

	  /* E = A - H^2 */
	  /* E is stored at ep + k-1 above k-1 zero limbs, so {ep, n+k-1}
	     holds B^{k-1} E. */
	  mpn_zero (ep, k-1);
	  mpn_sqr (ep + k-1, xp + k-1, k);

	  if (mpn_sub_n (ep + k-1, ap, ep+k-1, n))
	    {
	      /* Unlikely overflow, high part of root too large */
	      assert_no_carry (mpn_sub_1 (xp + k-1, xp + k-1, k, 1));
	      /* FIXME: Update, without re-squaring. */
	      mpn_sqr (ep + k-1, xp + k-1, k);
	      assert_no_carry (mpn_sub_n (ep + k-1, ap, ep+k-1, n));
	    }
	  /* FIXME: Figure out how much cancellation to expect. */
	  /* n-limb quotient floor (B^{k-1} E / H); the remainder
	     overwrites the low limbs of ep. */
	  mpn_tdiv_qr (qp, ep, 0, ep, n + k - 1, xp + k-1, k);
	}

      /* Halve the quotient: floor (B^{k-1} E / H) / 2. */
      mpn_rshift (qp, qp, n, 1);
      assert (qp[n-1] == 0);
      /* Low limbs of X come straight from the quotient; the remaining
	 quotient limbs are added onto H.
	 NOTE(review): for odd n, n-k-1 == k, so this copy also covers
	 xp[k-1], which holds the low limb of H -- confirm. */
      mpn_copyi (xp, qp, n-k-1);
      assert_no_carry (mpn_add_n (xp + n-k-1, xp + n-k-1, qp + n-k-1, k));

      free (ep);
    }
}

/* Computes an n-limb approximation of floor(sqrt(B^{n-1} A)), where
   A = {ap, n} and B = 2^64.

   xp  receives the n result limbs.
   ap  the n input limbs; the top limb must be non-zero.
   n   input size, n >= 1.

   n == 1 and n == 2 are handled directly.  Larger sizes recurse on the
   k+1 high limbs (k = floor (n/2)) to get H, form the error term E, and
   add the correction floor (B^j E / 2H), with j = k for odd n and
   j = k-1 for even n.  The division is done by zero-padding E and
   calling mpn_tdiv_qr.

   Scratch layout in the recursive case: {ep, n+k+1} holds the padded
   square / error term and {qp, n} the quotient.  The two areas are kept
   disjoint, since mpn_tdiv_qr does not permit the quotient to overlap
   the numerator. */
static void
sqrt_n (mp_limb_t *xp, const mp_limb_t *ap, mp_size_t n)
{
  assert (n >= 1);
  assert (ap[n-1] > 0);

  if (n == 1)
    xp[0] = sqrt_1 (ap[0]);

  else if (n == 2)
    {
      mp_limb_t h = sqrt_2 (ap[1], ap[0]);
      mp_limb_t xh, xl, q, r;
      mp_limb_t eh, el;

      /* E = A - h^2, at most 65 bits. */
      umul_ppmm (eh, el, h, h);
      sub_ddmmss (eh, el, ap[1], ap[0], eh, el);
      assert (eh <= 1);

      /* A single Newton step may produce a root which is one too large */
      /* q = floor (2^31 E / h), i.e. floor (sqrt(B) E / 2h); the
	 dividend is E shifted left by 31 bits, split into two limbs. */
      udiv_qrnnd (q, r, (el >> 33) | (eh << 31), el << 31, h);

      /* X = sqrt(B) h + q, as a two-limb sum with carry. */
      xh = h >> 32;
      xl = h << 32;
      xl += q;
      xh += xl < q;

      xp[0] = xl;
      xp[1] = xh;
    }
  else
    {
      mp_limb_t *ep;
      mp_limb_t *qp;
      mp_size_t k = n/2;

      if (n & 1)
	{
	  /* n = 2k+1: H is k+1 limbs at xp + k; k low limbs remain. */
	  sqrt_n (xp + k, ap + k, k+1);

	  /* H^2 occupies ep[k .. n+k], so the numerator area needs
	     n + k + 1 limbs; the quotient goes right after it. */
	  ep = xalloc_limbs (2*n + k + 1); /* n + k + 1 */
	  qp = ep + n + k + 1; /* n */
	  mpn_zero (ep, k);
	  mpn_sqr (ep + k, xp + k, k+1);
	  assert (ep[n+k] <= 1);

	  /* E = A - H^2, stored above k zero limbs so that {ep, n+k}
	     holds B^k E. */
	  if (ep[n+k] > 0 || mpn_sub_n (ep + k, ap, ep + k, n))
	    {
	      /* Unlikely overflow, high part of root too large */
	      assert_no_carry (mpn_sub_1 (xp + k, xp + k, k+1, 1));
	      /* FIXME: Update, without re-squaring. */
	      mpn_sqr (ep + k, xp + k, k+1);
	      assert (ep[k+n] == 0);
	      assert_no_carry (mpn_sub_n (ep + k, ap, ep + k, n));
	    }

	  /* n-limb quotient floor (B^k E / H), then halved. */
	  mpn_tdiv_qr (qp, ep, 0, ep, n+k, xp + k, k+1);
	  mpn_rshift (qp, qp, n, 1);

	  /* Low k limbs come from the quotient; the rest is added to H. */
	  mpn_copyi (xp, qp, k);
	  assert_no_carry (mpn_add_n (xp + k, xp + k, qp + k, n-k));
	}
      else
	{
	  /* n = 2k: H is k+1 limbs at xp + k-1; k-1 low limbs remain. */
	  sqrt_n (xp + k-1, ap + k-1, k+1);

	  /* H^2 occupies ep[k-1 .. n+k], so the numerator area needs
	     n + k + 1 limbs; the quotient goes right after it. */
	  ep = xalloc_limbs (2*n + k + 1); /* n + k + 1 */
	  qp = ep + n + k + 1; /* n */
	  mpn_zero (ep, k-1);
	  mpn_sqr (ep + k-1, xp + k-1, k+1);

	  /* E = B A - H^2.  The low limb of B A is zero, so the low limb
	     of E is the negated low limb of H^2, with a borrow into the
	     higher limbs when it is non-zero. */
	  /* FIXME: Handle unlikely underflow. */
	  assert_no_carry (mpn_sub_n (ep + k, ap, ep + k, n));
	  if (ep[k-1])
	    {
	      ep[k-1] = - ep[k-1];
	      assert_no_carry (mpn_sub_1 (ep + k, ep + k, n, 1));
	    }

	  /* n-limb quotient floor (B^{k-1} E / H), then halved. */
	  mpn_tdiv_qr (qp, ep, 0, ep, n + k, xp + k - 1, k + 1);
	  mpn_rshift (qp, qp, n, 1);

	  /* Low k-1 limbs come from the quotient; the remaining n-k+1
	     quotient limbs, starting at qp[k-1], are added to H. */
	  mpn_copyi (xp, qp, k - 1);
	  assert_no_carry (mpn_add_n (xp + k - 1, xp + k - 1,
				      qp + k - 1, n-k+1));
	}
      free (ep);
    }
}

/* Computes X = floor (sqrt ({ap, n})), ceil(n/2) limbs, and remainder
   R = A - X^2.

   xp  receives the ceil(n/2) limbs of the root.
   rp  receives the remainder; floor(n/2) + 1 limbs are written.
   ap  the n input limbs; the helpers assert that the top limb is
       non-zero.
   n   input size, n > 0.

   Returns the normalized size of the remainder, i.e. the number of
   limbs of rp up to and including the most significant non-zero one
   (0 when A is a perfect square).

   For n >= 3 an approximate root is taken from the high limbs with
   sqrt_n / sqrt_nm1, and then corrected: decremented once if it turns
   out too large, or incremented while the remainder is >= 2X + 1. */
static mp_size_t
sqrtrem (mp_limb_t *xp, mp_limb_t *rp, const mp_limb_t *ap, mp_size_t n)
{
  assert (n > 0);
  if (n == 1)
    {
      mp_limb_t x = sqrt_1 (ap[0]);
      mp_limb_t r = ap[0] - x*x;
      xp[0] = x;
      rp[0] = r;
      return (r > 0);
    }
  else if (n == 2)
    {
      mp_limb_t x = sqrt_2 (ap[1], ap[0]);
      mp_limb_t rh, rl;
      xp[0] = x;
      /* {rh, rl} = A - x^2 */
      umul_ppmm (rh, rl, x, x);
      sub_ddmmss (rh, rl, ap[1], ap[0], rh, rl);
      rp[0] = rl;
      rp[1] = rh;
      if (!rh)
	return (rl > 0);
      return 2;
    }
  else
    {
      mp_size_t k = n/2;
      mp_size_t rn = k + 1;		/* limbs of remainder kept */
      mp_size_t xn = (n+1)/2;		/* limbs of root */
      mp_limb_t *tp;
      mp_limb_t cy;
      mp_limb_t adjust_m;

      if (n & 1)
	/* Compute sqrt based on the (n+1)/2 high limbs */
	sqrt_n (xp, ap + k, k+1);
      else
	/* Compute sqrt based on the n/2+1 high limbs */
	sqrt_nm1 (xp, ap + k-1, k+1);

      /* Scratch for X^2 (2*xn limbs), reused below for 2X. */
      tp = xalloc_limbs (2*xn);

      mpn_sqr (tp, xp, xn);

      /* tp <- A - X^2 on the low n limbs; a borrow means X^2 > A. */
      adjust_m = mpn_sub_n (tp, ap, tp, n);
      if (n & 1)
	{
	  /* The extra limb is usually zero */
	  /* For odd n, X^2 has n+1 limbs; a non-zero top limb also means
	     X^2 > A. */
	  assert (tp[n] <= 1);
	  adjust_m |= tp[n];
	}
      /* Cancellation */
      /* When X is not too large, the remainder must fit in rn limbs. */
      assert (adjust_m || mpn_zero_p (tp + rn, n - rn));
      mpn_copyi (rp, tp, rn);

      /* T = 2X */
      cy = mpn_lshift (tp, xp, xn, 1);
      tp[xn] = cy;
      /* Can't have carry if n is odd */
      assert (! ((n&1) && cy));

      if (adjust_m)
	{
	  /* x is one too large */
	  if (verbose)
	      fprintf (stderr, "-");

	  /* A - (X-1)^2 = A - X^2 + 2X - 1
	     T = 2X - 1 */
	  assert_no_carry (mpn_sub_1 (tp, tp, rn, 1));
	  /* rp holds the negative value A - X^2 modulo B^rn, so adding T
	     must wrap around, and the result must be below T. */
	  assert_carry (mpn_add_n (rp, rp, tp, rn));
	  assert (mpn_cmp (rp, tp, rn) < 0);

	  assert_no_carry (mpn_sub_1 (xp, xp, xn, 1));
	}

      else
	{
	  /* T = 2X + 1 */
	  tp[0] |= 1;
	  /* While R >= 2X + 1, X is too small: step X up by one, which
	     lowers R by 2X + 1 and raises T by 2. */
	  while (mpn_cmp (rp, tp, rn) >= 0)
	    {
	      if (verbose)
		{
		  /* Counts upward adjustments across all calls; aborts
		     if they exceed 100 (only checked in verbose mode). */
		  static int adjust = 0;
		  fprintf (stderr, "+");
		  adjust++;
		  if (adjust > 100)
		    abort ();
		}

	      assert_no_carry (mpn_sub_n (rp, rp, tp, rn));
	      assert_no_carry (mpn_add_1 (xp, xp, xn, 1));
	      assert_no_carry (mpn_add_1 (tp, tp, rn, 2));
	    }
	}
      free (tp);

      /* Strip high zero limbs to get the normalized remainder size. */
      while (rn > 0 && !rp[rn-1])
	rn--;
      return rn;
    }
}

/* Verifies a square root with remainder.  Returns 1 if X is
   floor (sqrt (A)) and R = A - X^2, i.e. if all of

     X^2 <= A < (X+1)^2   and   R == A - X^2

   hold, and 0 otherwise.  The upper bound is strict: if (X+1)^2 == A,
   then X is one too small and the check fails. */
static int
check_sqrt (const mpz_t a, const mpz_t x, const mpz_t r)
{
  mpz_t t;
  int ok;

  mpz_init (t);

  /* A < (X+1)^2 */
  mpz_add_ui (t, x, 1);
  mpz_mul (t, t, t);
  ok = mpz_cmp (t, a) > 0;

  if (ok)
    {
      /* X^2 <= A */
      mpz_mul (t, x, x);
      ok = mpz_cmp (t, a) <= 0;
    }

  if (ok)
    {
      /* R == A - X^2; t still holds X^2 here. */
      mpz_sub (t, a, t);
      ok = mpz_cmp (t, r) == 0;
    }

  mpz_clear (t);
  return ok;
}

/* Largest operand size, in limbs, generated by the test driver.  Also
   determines the sizes of the root and remainder buffers in main. */
#define MAX_SIZE 6

/* Test driver: runs sqrtrem on 10000 random operands of 1 to
   GMP_NUMB_BITS * MAX_SIZE bits and verifies each result with
   check_sqrt, aborting with a diagnostic on the first failure.  Giving
   any command-line argument turns on verbose output on stderr. */
int
main (int argc, char **argv)
{
  gmp_randstate_t rands;
  mp_limb_t x[(MAX_SIZE+1)/2];	/* root: ceil (n/2) limbs */
  mp_limb_t r[MAX_SIZE/2+1];	/* remainder: floor (n/2) + 1 limbs */
  mpz_t a;

  unsigned c;

  if (argc > 1)
    verbose = 1;

  mpz_init (a);
  gmp_randinit_default (rands);
  for (c = 0; c < 10000; c++)
    {
      unsigned bits = 1 + gmp_urandomm_ui (rands, GMP_NUMB_BITS * MAX_SIZE);
      mpz_t t1, t2;
      mp_size_t an, xn, rn;

      mpz_rrandomb (a, rands, bits);
      an = mpz_size (a);
      if (verbose)
	gmp_fprintf (stderr, "%u, size %u, A = %Zd", c, (unsigned) an, a);
      xn = (an + 1)/2;
      rn = sqrtrem (x, r, mpz_limbs_read (a), an);
      /* The returned remainder size must be normalized. */
      assert (rn == 0 || r[rn-1]);

      if (verbose)
	fprintf (stderr, "\n");

      /* t1 and t2 are read-only views of the limb arrays; they own no
	 storage and need no mpz_clear. */
      if (!check_sqrt (a, mpz_roinit_n (t1, x, xn), mpz_roinit_n (t2, r, rn)))
	{
	  gmp_fprintf (stderr, "Test %u failed, size %u:\n"
		       "  a: %Zd\n"
		       "  x: %Nd\n"
		       "  r: %Nd\n",
		       c, (unsigned) an, a, x, xn, r, rn);
	  abort ();
	}
    }

  /* Release the random state as well as the operand. */
  gmp_randclear (rands);
  mpz_clear (a);
  return EXIT_SUCCESS;
}
-- 
Niels Möller. PGP-encrypted email is preferred. Keyid C0B98E26.
Internet email is subject to wholesale government surveillance.
_______________________________________________
gmp-devel mailing list
gmp-devel@gmplib.org
https://gmplib.org/mailman/listinfo/gmp-devel

Reply via email to