ni...@lysator.liu.se (Niels Möller) writes:

> Correction: Each iteration gets from \ell to 2 \ell-2, and it needs 2
> \ell-1 bits of precision for intermediate values.

Yet another correction, after more work on the implementation.

For a number a = 1 (mod 8) and k >= 3, there are exactly four square
roots mod 2^k. If x is one of them, the other three are

  x + 2^{k-1}
  -x
  -x + 2^{k-1}

So the top bit doesn't matter. This implies that no extra bit is needed
for temporary values: if we start with \ell bits, we can do the
iteration using 2\ell-2 bits for all values. And it doesn't matter
which value we get in the most significant bit after the final right
shift; we get a square root (mod 2^{2\ell - 2}) in either case.
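A brute-force check of the four-roots claim for small k can be sketched in C as below. The helper sqrts_mod_pow2 is hypothetical (not a GMP function); it simply tries every residue, so it is only usable for tiny moduli.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: find every x in [0, 2^k) with x^2 == a (mod 2^k).
   Stores up to max_roots of them in roots[] (increasing order) and
   returns the total count.  Brute force, so only for small k. */
static int
sqrts_mod_pow2 (uint32_t a, unsigned k, uint32_t roots[], int max_roots)
{
  uint32_t mask = (UINT32_C (1) << k) - 1;
  int count = 0;

  assert (k >= 1 && k <= 20);
  for (uint32_t x = 0; x <= mask; x++)
    if ((((uint64_t) x * x - a) & mask) == 0)
      {
        if (count < max_roots)
          roots[count] = x;
        count++;
      }
  return count;
}
```

For a = 9 and k = 10 it finds 3, 509, 515 and 1021, which are x, -x + 2^9, x + 2^9 and -x for x = 3.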

And we can canonicalize the returned (mod 2^k) square root by saying
that it must be = 1 (mod 4), and < 2^{k-1} (i.e., most significant bit
always zero).
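For a single 64-bit word, the iteration and the canonicalization can be sketched as follows. This is only a word-sized illustration under stated assumptions (3 <= k <= 64, a = 1 (mod 8), unsigned 64-bit wraparound arithmetic); bsqrt64 is a hypothetical name, and it starts from x = 1 rather than from a lookup table.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical word-sized bsqrt: returns the square root r of a
   (mod 2^k) with r == 1 (mod 4) and r < 2^{k-1}.
   Requires a == 1 (mod 8) and 3 <= k <= 64. */
static uint64_t
bsqrt64 (uint64_t a, unsigned k)
{
  uint64_t x = 1;   /* a x^2 == 1 (mod 2^3), since a == 1 (mod 8) */
  uint64_t r;
  unsigned bits;

  assert ((a & 7) == 1);
  assert (k >= 3 && k <= 64);

  /* x <-- x - x (a x^2 - 1) / 2 = x (3 - a x^2) / 2.  The product is
     even; after the shift, x is only known mod 2^63, which is harmless
     since (x + 2^63)^2 == x^2 (mod 2^64).  Correct bits go
     3, 4, 6, 10, 18, 34, 66 (i.e., all 64). */
  for (bits = 3; bits < 64; bits = 2 * bits - 2)
    x = x * (3 - a * x * x) / 2;

  r = a * x;        /* r^2 = a (a x^2) == a (mod 2^64) */
  if ((r & 3) != 1)
    r = -r;         /* pick between the roots x and -x */
  /* Reduce mod 2^k and pick between r and r + 2^{k-1}. */
  return r & ((UINT64_C (1) << (k - 1)) - 1);
}
```

For example, bsqrt64 (9, 10) returns 509, the root of 9 that is = 1 (mod 4) and < 2^9.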

I'm attaching my current code for both square root and nth root. Both
appear to work fine. Neither uses any wraparound tricks yet, but both
take some advantage of cancellation.

In the nth-root code, I use the iteration 

     x' <-- x - x * (a^{n-1} x^n - 1) / n

which converges to a^{1/n - 1}, so the nth root is recovered as a * x. The
number of correct bits is *exactly* doubled in each iteration, so I
didn't see much point in using bit counts; instead I use a limb count
for the desired precision. Perhaps this iteration could be useful in the
euclidean case too; I haven't investigated that.

The factor a^{n-1} is a loop invariant, so it can be computed (at the
highest needed precision) before the loop.
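A word-sized version of this iteration, mod 2^64, might look like the sketch below. It reuses the 4-bit start value from mpn_broot and hoists a^{n-1} out of the loop. The names pow64, binv64 and broot64 are hypothetical, and the inverse of n is computed with a plain Newton iteration rather than GMP's binvert_limb.

```c
#include <assert.h>
#include <stdint.h>

/* a^e (mod 2^64), right-to-left binary powering. */
static uint64_t
pow64 (uint64_t a, uint64_t e)
{
  uint64_t r = 1;
  for (; e > 0; e >>= 1, a *= a)
    if (e & 1)
      r *= a;
  return r;
}

/* Inverse of odd n (mod 2^64).  n*n == 1 (mod 8) gives 3 correct bits,
   and each step inv <-- inv (2 - n inv) doubles them: 3, 6, ..., 96. */
static uint64_t
binv64 (uint64_t n)
{
  uint64_t inv = n;
  for (int i = 0; i < 5; i++)
    inv *= 2 - n * inv;
  return inv;
}

/* The n-th root of odd a (mod 2^64), for odd n >= 3. */
static uint64_t
broot64 (uint64_t a, uint64_t n)
{
  uint64_t anm1 = pow64 (a, n - 1);  /* loop invariant, computed once */
  uint64_t ninv = binv64 (n);
  /* 4-bit start value for a^{1/n - 1} (mod 16), as in mpn_broot. */
  uint64_t x = 1 + (((n << 2) & ((a << 1) ^ (a << 2))) & 8);

  assert ((a & 1) && (n & 1) && n >= 3);
  /* x <-- x - x (a^{n-1} x^n - 1) / n = x (n + 1 - a^{n-1} x^n) / n,
     dividing by n via its inverse.  Bits: 4, 8, 16, 32, 64. */
  for (int i = 0; i < 4; i++)
    x = ninv * x * (n + 1 - anm1 * pow64 (x, n));
  return a * x;   /* a * a^{1/n - 1} = a^{1/n} */
}
```

Since x -> x^n is a bijection on odd residues when n is odd, the root is unique; for instance broot64 (27, 3) returns 3.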

For the power x^n, I currently use mpn_powlo. But the least significant
half is known from the previous iteration, so wraparound would be
desirable. It would make some sense to me to have a pow function mod
(2^k - 1), i.e., using mpn_mulmod_bnm1 and mpn_sqrmod_bnm1. Or is there
an easier way to take advantage of wraparound for that computation?
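The reason a 2^k - 1 modulus suits wraparound is that 2^k = 1 (mod 2^k - 1), so the high half of a product is simply added to the low half. The sketch below shows that reduction at word size only, for k <= 32; mpn_mulmod_bnm1 works on limb vectors mod B^rn - 1, and the names mulmod_mersenne and powmod_mersenne here are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* x*y (mod 2^k - 1) for 2 <= k <= 32 and x, y < 2^k - 1.  The high
   half of the product wraps around and is added to the low half. */
static uint64_t
mulmod_mersenne (uint64_t x, uint64_t y, unsigned k)
{
  uint64_t m = (UINT64_C (1) << k) - 1;
  uint64_t p = x * y;              /* < 2^{2k} <= 2^64 */

  p = (p & m) + (p >> k);          /* < 2^{k+1} */
  p = (p & m) + (p >> k);          /* <= 2^k = m + 1 */
  return p >= m ? p - m : p;       /* canonical, in [0, m) */
}

/* x^e (mod 2^k - 1), right-to-left binary powering. */
static uint64_t
powmod_mersenne (uint64_t x, uint64_t e, unsigned k)
{
  uint64_t m = (UINT64_C (1) << k) - 1;
  uint64_t r = 1;

  assert (k >= 2 && k <= 32);
  x %= m;
  for (; e > 0; e >>= 1, x = mulmod_mersenne (x, x, k))
    if (e & 1)
      r = mulmod_mersenne (r, x, k);
  return r;
}
```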

Maybe it would be a good idea to write a general or "abstract" pow
function taking multiplication and squaring functions as arguments?
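One way such an abstract pow could look, sketched at word size with function pointers plus an opaque context. The names pow_generic, mul_func, sqr_func and both instances are hypothetical, not existing GMP interfaces. Left-to-right binary powering is used so that every multiplication is by the fixed base.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef uint64_t (*mul_func) (const void *ctx, uint64_t x, uint64_t y);
typedef uint64_t (*sqr_func) (const void *ctx, uint64_t x);

/* x^e for e >= 1, left-to-right binary powering.  The arithmetic
   (and any reduction) is entirely up to mul and sqr. */
static uint64_t
pow_generic (uint64_t x, uint64_t e, mul_func mul, sqr_func sqr,
             const void *ctx)
{
  uint64_t r = x;
  int bit = 63;

  assert (e >= 1);
  while (!((e >> bit) & 1))
    bit--;
  while (--bit >= 0)
    {
      r = sqr (ctx, r);
      if ((e >> bit) & 1)
        r = mul (ctx, r, x);
    }
  return r;
}

/* Instance 1: arithmetic mod 2^64, a word-sized analogue of powlo. */
static uint64_t
lo_mul (const void *ctx, uint64_t x, uint64_t y)
{
  (void) ctx;
  return x * y;
}

static uint64_t
lo_sqr (const void *ctx, uint64_t x)
{
  (void) ctx;
  return x * x;
}

/* Instance 2: arithmetic mod m, ctx points to m; needs m < 2^32 and
   a base already reduced mod m, so that products fit in 64 bits. */
static uint64_t
mod_mul (const void *ctx, uint64_t x, uint64_t y)
{
  return x * y % *(const uint64_t *) ctx;
}

static uint64_t
mod_sqr (const void *ctx, uint64_t x)
{
  return x * x % *(const uint64_t *) ctx;
}
```

A bignum version would pass limb pointers, sizes and scratch space instead of single words, but the control flow would stay the same.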

We currently have modular exponentiation, powlo and "regular" powering
with no reduction of any kind. I'm suggesting a pow_modbnm1. For
euclidean square root, and for mpfr, it might also be useful with a
pow_high, keeping only the n most significant limbs of each product, and
returning the number of discarded low limbs. Another potential use is
powm with moduli of special form, where the reduction can be done
more cheaply than with Montgomery REDC. Maybe the function could even be used
for more complicated groups, e.g., if implementing elliptic curve
operations on top of gmp. Or maybe it's easy enough to duplicate the
code for each useful case, perhaps sharing some bit extraction macros in
gmp-impl.h.
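A word-sized sketch of the pow_high idea: keep only the top 32 bits of every product and count how many low bits were discarded, so that the returned h and shift s satisfy h 2^s <= x^e. The function names are hypothetical; a real version would keep n limbs and count discarded limbs rather than bits.

```c
#include <assert.h>
#include <stdint.h>

/* Truncate p to below 2^32, adding the number of discarded low bits
   to *shift. */
static uint64_t
keep_high (uint64_t p, unsigned *shift)
{
  while (p >> 32)
    {
      p >>= 1;
      (*shift)++;
    }
  return p;
}

/* Returns h < 2^32 with h * 2^{*s} <= x^e (truncation only ever rounds
   down).  Requires 1 <= x < 2^32, e >= 1, and assumes the discarded
   bit count fits in an unsigned. */
static uint64_t
pow_high (uint64_t x, uint64_t e, unsigned *s)
{
  uint64_t h = x;
  int bit = 63;

  assert (x >= 1 && x >> 32 == 0 && e >= 1);
  *s = 0;
  while (!((e >> bit) & 1))
    bit--;
  while (--bit >= 0)
    {
      *s *= 2;                       /* (h 2^s)^2 = h^2 2^{2s} */
      h = keep_high (h * h, s);      /* h < 2^32, so no overflow */
      if ((e >> bit) & 1)
        h = keep_high (h * x, s);
    }
  return h;
}
```

For 3^20 nothing is discarded (s = 0), while 2^100 comes back as h = 2^31, s = 69.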

Regards,
/Niels

#include <stdio.h>
#include <stdlib.h>

#include "gmp.h"
#include "gmp-impl.h"

/* Computes a^e (mod B). Uses right-to-left binary algorithm, since
   typical use will have e small. */
static mp_limb_t
powlimb (mp_limb_t a, mp_limb_t e)
{
  mp_limb_t r, s;

  for (r = 1, s = a; e > 0; e >>= 1, s *= s)
    if (e & 1)
      r *= s;

  return r;
}

/* Computes the nth root of A, mod B^an. Both A and n must be odd, and n >= 3.

   Uses the iteration

     x' <-- x - x * (a^{n-1} x^n - 1) / n

   converging to a^{1/n - 1}.

   If

     a^{n-1} x^n = 1 (mod 2^\ell),

   then

     a^{n-1} x'^n = 1 (mod 2^{2\ell}).
*/

void
mpn_broot (mp_ptr rp, mp_srcptr ap, mp_size_t an, mp_limb_t n)
{
  mp_size_t sizes[GMP_LIMB_BITS * 2];  
  mp_ptr anm1, tp, xp, xnp, ep;
  mp_limb_t a0, r0, nm1, ninv;
  mp_size_t xn;
  unsigned i;

  TMP_DECL;

  ASSERT (an > 0);
  ASSERT (ap[0] & 1);
  ASSERT (n & 1);
  ASSERT (n >= 3);

  TMP_MARK;
  
  anm1 = TMP_ALLOC_LIMBS (4*an);
  tp = anm1 + an;

  nm1 = n-1;
  mpn_powlo (anm1, ap, &nm1, 1, an, tp); /* 3 an scratch space */

  a0 = ap[0];
  binvert_limb (ninv, n);

  /* 4 bits: a^{1/n - 1} (mod 16), by n mod 4 (rows) and a mod 8
     (columns):

                 a mod 8
                 1 3 5 7
     n mod 4  1  1 1 1 1
              3  1 9 9 1
  */
  r0 = 1 + (((n << 2) & ((a0 << 1) ^ (a0 << 2))) & 8);
  r0 = ninv * r0 * (n+1 - anm1[0] * powlimb (r0, n)); /* 8 bits */
  r0 = ninv * r0 * (n+1 - anm1[0] * powlimb (r0, n)); /* 16 bits */
  r0 = ninv * r0 * (n+1 - anm1[0] * powlimb (r0, n)); /* 32 bits */
#if GMP_NUMB_BITS > 32
  r0 = ninv * r0 * (n+1 - anm1[0] * powlimb (r0, n)); /* 64 bits */
#endif

  if (an == 1)
    {
      TMP_FREE;
      rp[0] = r0 * a0;
      return;
    }

  xp = TMP_ALLOC_LIMBS (3*an);
  xnp = xp + an;
  ep = xp + 2*an;

  /* FIXME: Possible to do this on the fly with some bit fiddling. */
  for (i = 0; an > 1; an = (an + 1)/2)
    sizes[i++] = an;

  xp[0] = r0;
  xn = 1;

  while (i-- > 0)
    {
      /* Compute x^n. What's the best way to handle the doubled
         precision? Could do the complete powering using
         wraparound. */
      MPN_ZERO (xp + xn, sizes[i] - xn);
      mpn_powlo (xnp, xp, &n, 1, sizes[i], tp); 

      /* Multiply by a^{n-1}. Can use wraparound; low part is
         000...01. */

      mpn_mullo_n (ep, xnp, anm1, sizes[i]);
      ASSERT (ep[0] == 1);
      ASSERT (xn == 1 || mpn_zero_p (ep + 1, xn - 1));

      ASSERT (sizes[i] <= 2*xn);
      mpn_pi1_bdiv_q_1 (ep, ep + xn, sizes[i] - xn, n, ninv, 0);

      /* Multiply by x, plain mullo. */
      mpn_mullo_n (xp + xn, ep, xp, sizes[i] - xn);

      /* FIXME: Avoid negation, e.g., by using a bdiv_q_1 variant
         returning -q. */
      mpn_neg (xp + xn, xp + xn, sizes[i] - xn);

      xn = sizes[i];
    }
  mpn_mullo_n (rp, ap, xp, xn);
  TMP_FREE;
}

#define MAX_LIMBS 150

int
main (int argc, char **argv)
{
  mp_limb_t a[MAX_LIMBS];
  mp_limb_t r[MAX_LIMBS];
  mp_limb_t t[4*MAX_LIMBS];

  gmp_randstate_t rands;
  unsigned i;
  
  gmp_randinit_default (rands);

  for (i = 0; i < 5000; i++)
    {
      mp_size_t s;
      mp_limb_t n;
      int c;

      s = 1 + gmp_urandomm_ui (rands, MAX_LIMBS);

      if (i & 1)
        mpn_random2 (a, s);
      else
        mpn_random (a, s);

      a[0] |= 1;

      if (i < 100)
        n = 3 + 2*i;
      else
        {
          mpn_random (&n, 1);
          if (n < 3)
            n = 3;
          else
            n |= 1;
        }

      mpn_broot (r, a, s, n);
      mpn_powlo (t, r, &n, 1, s, t + s);

      MPN_CMP (c, t, a, s);
      if (c != 0)
        {
          gmp_fprintf (stderr,
                       "mpn_broot returned bad result: %d limbs\n",
                       (int) s);
          gmp_fprintf (stderr, "n   = %Mx\n", n);
          gmp_fprintf (stderr, "a   = %Nx\n", a, s);
          gmp_fprintf (stderr, "r   = %Nx\n", r, s);
          gmp_fprintf (stderr, "r^n = %Nx\n", t, s);
          abort ();
        }
    }
  return EXIT_SUCCESS;
}
#include <stdio.h>
#include <stdlib.h>

#include "gmp.h"
#include "gmp-impl.h"


static void
mpn_mullo (mp_ptr rp, mp_srcptr ap, mp_size_t an, mp_srcptr bp, mp_size_t bn)
{
  ASSERT (an >= bn);
  ASSERT (bn > 0);
  if (an == bn)
    mpn_mullo_n (rp, ap, bp, an);
  else
    {
      mp_size_t fn = an - bn;
      mp_ptr tp;
      TMP_DECL;
      TMP_MARK;
      tp = TMP_ALLOC_LIMBS (bn);
      if (fn >= bn)
        mpn_mul (rp, ap, fn, bp, bn);
      else
        mpn_mul (rp, bp, bn, ap, fn);
      mpn_mullo_n (tp, ap + fn, bp, bn);
      mpn_add_n (rp + fn, rp + fn, tp, bn);
      TMP_FREE;
    }
}

/* Element i is (8i+1)^{-1/2} (mod 2^{10}) */
static const unsigned short bsqrt_table[0x80] =
  {
    0x1,0x155,0x159,0xcd,0x171,0x5,0x49,0xfd,
    0x1e1,0x1b5,0x39,0x2d,0x151,0x65,0x129,0x5d,
    0x1c1,0x15,0x119,0x18d,0x131,0xc5,0x9,0x1bd,
    0x1a1,0x75,0x1f9,0xed,0x111,0x125,0xe9,0x11d,
    0x181,0xd5,0xd9,0x4d,0xf1,0x185,0x1c9,0x7d,
    0x161,0x135,0x1b9,0x1ad,0xd1,0x1e5,0xa9,0x1dd,
    0x141,0x195,0x99,0x10d,0xb1,0x45,0x189,0x13d,
    0x121,0x1f5,0x179,0x6d,0x91,0xa5,0x69,0x9d,
    0x101,0x55,0x59,0x1cd,0x71,0x105,0x149,0x1fd,
    0xe1,0xb5,0x139,0x12d,0x51,0x165,0x29,0x15d,
    0xc1,0x115,0x19,0x8d,0x31,0x1c5,0x109,0xbd,
    0xa1,0x175,0xf9,0x1ed,0x11,0x25,0x1e9,0x1d,
    0x81,0x1d5,0x1d9,0x14d,0x1f1,0x85,0xc9,0x17d,
    0x61,0x35,0xb9,0xad,0x1d1,0xe5,0x1a9,0xdd,
    0x41,0x95,0x199,0xd,0x1b1,0x145,0x89,0x3d,
    0x21,0xf5,0x79,0x16d,0x191,0x1a5,0x169,0x19d,
  };

/* Computes R such that R^2 = A (mod 2^b), if one exists. Of the four
   square roots, returns the one which is < 2^{b-1} and == 1 (mod 4).
   Size of both R and A is ceil(b / GMP_NUMB_BITS) limbs. Currently, A
   must be odd. Returns 1 if a square root exists, otherwise 0. */

/* Uses the iteration

     x' <-- x - x * (a x^2 - 1) / 2

   converging to x = a^{-1/2}. If

     a x^2 = 1 (mod 2^k),

   then

     a x'^2 = 1 (mod 2^{2k - 2})

   Note that when dividing by two, the new value of the high bit
   doesn't matter. */

int
mpn_bsqrt(mp_ptr rp, mp_srcptr ap, mp_bitcnt_t b)
{
  mp_bitcnt_t c, k;
  mp_size_t an, n;
  mp_limb_t a0, r0;
  mp_size_t sizes[GMP_LIMB_BITS * 2];  
  mp_ptr tp, xp, ep;
  unsigned i;
  TMP_DECL;

  ASSERT (b > 0);

  a0 = ap[0];
  ASSERT (a0 & 1);

  if (b == 1)
    {
      rp[0] = 1;
      return 1;
    }
  else if (b == 2)
    {
      if ( (a0 & 3) != 1)
        return 0;

      rp[0] = 1;
      return 1;
    }
  if ( (a0 & 7)  != 1)
    return 0;

  r0 = bsqrt_table[(a0 >> 3) & 0x7f]; /* 10 bits */
  r0 = r0*(3 - a0*r0*r0) / 2; /* 18 bits */
  r0 = r0*(3 - a0*r0*r0) / 2; /* 34 bits, assuming 35 bit arithmetic */
  if (GMP_LIMB_BITS >= 34)
    {
      r0 = r0*(3 - a0*r0*r0) / 2; /* 66 bits, assuming 67 bit
                                 arithmetic */
    }
  c = GMP_LIMB_BITS - 1;

  if (b <= c)
    {
      rp[0] = (r0*a0) & (((mp_limb_t) 1 << (b-1)) - 1);
      return 1;
    }

  for (k = b, i = 0; k > c; k = (k+1)/2 + 1)
    {
      /* k is the desired number of bits in each iteration, and also
         the size needed for intermediate values. */
      sizes[i++] = (k + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS;
    }

  TMP_MARK;

  tp = TMP_ALLOC_LIMBS (3*sizes[0] + 1);/* sizes[0] + 1 */
  ep = tp + sizes[0] + 1;               /* sizes[0] */
  xp = ep + sizes[0];                   /* sizes[0] */

  xp[0] = r0;
  n = 1;
  an = (b + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS;

  while (i-- > 0)
    {
      mp_size_t zn, en;

      /* Iterate x <-- x - x (a x^2 - 1)/2
         First compute e = (a x^2 - 1) / 2 */

      /* Adjustment needed in the case that we used only a single bit of
         precision in the highest limb in the previous iteration. */
      if (2*n == sizes[i] + 2)
        n--;

      ASSERT (2*n >= sizes[i]);
      ASSERT (2*n <= sizes[i] + 1);
      /* FIXME: Could maybe use wraparound, low half known from
         previous iteration. */
      mpn_sqr (tp, xp, n);
      
      /* FIXME: Could use wraparound; result is 1 (mod 2^c) */
      ASSERT (sizes[i] <= an);
      mpn_mullo_n (ep, tp, ap, sizes[i]);

      /* Minimum number of low zero limbs in (a x^2 - 1) */
      zn = (c - 1) / GMP_NUMB_BITS;
      
      if (zn == 0)
        ASSERT_CARRY (mpn_rshift (ep, ep, sizes[i], 1));
      else
        {
          ASSERT (ep[0] == 1 && mpn_zero_p (ep+1, zn-1));
          ASSERT_NOCARRY (mpn_rshift (ep+zn, ep+zn, sizes[i]-zn, 1));
        }

      /* Need the low sizes[i] limbs of the product x * e. Low zn limbs are known zero. */
      en = sizes[i] - zn;
      if (en >= n)
        mpn_mullo (tp, ep+zn, en, xp, n);
      else
        mpn_mullo (tp, xp, n, ep+zn, en);
      mpn_neg (tp, tp, en);
      ASSERT (zn <= n);
      if (zn < n)
        mpn_add (xp + zn, tp, en, xp + zn, n - zn);
      else
        MPN_COPY (xp + zn, tp, en);

      n = sizes[i];
      c = MIN (2*c - 2, n * GMP_NUMB_BITS - 1);
    }

  n = (b + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS;
  mpn_mullo_n (rp, xp, ap, n);
  TMP_FREE;

  i = (b-1) % GMP_NUMB_BITS;
  if (i > 0)
    rp[n-1] &= GMP_NUMB_MAX >> (GMP_NUMB_BITS - i);
  else
    /* FIXME: Avoid writing the unnecessary limb in this case. */
    rp[n-1] = 0;

  return 1;
}

#define MAX_BIT_SIZE 10000
#define MAX_LIMBS ((MAX_BIT_SIZE + GMP_NUMB_BITS - 1)/GMP_NUMB_BITS)

int
main (int argc, char **argv)
{
  mp_limb_t a[MAX_LIMBS];
  mp_limb_t r[MAX_LIMBS];
  mp_limb_t t[MAX_LIMBS];

  gmp_randstate_t rands;
  unsigned i, n_root;
  
  gmp_randinit_default (rands);
  
  for (i = n_root = 0; i < 10000; i++)
    {
      mp_bitcnt_t size = 1 + gmp_urandomm_ui (rands, MAX_BIT_SIZE);
      mp_size_t n = (size + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS;
      if (i & 1)
        mpn_random2 (a, n);
      else
        mpn_random (a, n);

      a[0] |= 1;

      if (i & 6)
        {
          /* Ensure root exists. */
          if (size > 1)
            {
              a[0] &= ~(mp_limb_t) 2;
              if (size > 2)
                a[0] &= ~(mp_limb_t) 4;
            }
        }

      if (mpn_bsqrt (r, a, size))
        {
          mp_limb_t mask;
          int c;

          mpn_mullo_n (t, r, r, n);

          /* Check t == a mod 2^size */
          mask = GMP_NUMB_MAX >> (n*GMP_NUMB_BITS - size);

          if ((r[0] & 3) != 1)
            {
              fprintf (stderr,
                       "mpn_bsqrt returned != 1 (mod 4): %d bits, %d limbs\n",
                       (int) size, (int) n);
              goto fail;
            }
          if (size > 1 && (r[n-1] & (mask - (mask >> 1))))
            {   
              fprintf (stderr,
                       "mpn_bsqrt returned with high bit set: %d bits, %d limbs\n",
                       (int) size, (int) n);
              goto fail;
            }
      
          MPN_CMP (c, t, a, n-1);
          if (c != 0 || ((t[n-1] - a[n-1]) & mask))
            {
              fprintf (stderr,
                       "mpn_bsqrt returned bad result: %d bits, %d limbs\n",
                       (int) size, (int) n);
            fail:
              gmp_fprintf (stderr, "a   = %Nx\n", a, n);
              gmp_fprintf (stderr, "r   = %Nx\n", r, n);
              gmp_fprintf (stderr, "r^2 = %Nx\n", t, n);
              abort ();
            }
          n_root ++;
        }
      else
        {
          mp_limb_t a0 = a[0];
          if ( (a0 & 7) == 1
               || (size == 2 && (a0 & 3) == 1)
               || (size == 1 && (a0 & 1)))
            {
              fprintf (stderr,
                       "mpn_bsqrt returned zero: %d bits, %d limbs\n",
                       (int) size, (int) n);
              gmp_fprintf (stderr, "a   = %Nx\n", a, n);
              abort ();
            }
        }
    }
  fprintf (stderr, "%d tests, %d with square root, %d with none.\n",
           i, n_root, i - n_root);
  return EXIT_SUCCESS;
}
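The bsqrt_table in mpn_bsqrt can be regenerated by brute force, which is one way to double-check it. A minimal sketch with a hypothetical function name: for each i it picks the root of (8i+1)^{-1} (mod 2^10) that is 1 (mod 4) and below 2^9, which appears to be the convention the table follows.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical generator: entry i is the x with (8i+1) x^2 == 1
   (mod 2^10), x == 1 (mod 4) and x < 2^9.  Exactly one such x exists,
   since the four roots are x, x + 2^9, -x and -x + 2^9. */
static void
make_bsqrt_table (uint16_t table[0x80])
{
  for (unsigned i = 0; i < 0x80; i++)
    {
      uint32_t a = 8 * i + 1;
      table[i] = 0;
      for (uint32_t x = 1; x < 0x200; x += 4)  /* x == 1 (mod 4) */
        if ((a * x * x & 0x3ff) == 1)
          {
            table[i] = (uint16_t) x;
            break;
          }
    }
}
```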
-- 
Niels Möller. PGP-encrypted email is preferred. Keyid C0B98E26.
Internet email is subject to wholesale government surveillance.
_______________________________________________
gmp-devel mailing list
gmp-devel@gmplib.org
http://gmplib.org/mailman/listinfo/gmp-devel
