commit 3e3b44d087ab616089402129b2bc4c4831c6b33a
Author:     Mattias Andrée <[email protected]>
AuthorDate: Wed Mar 16 14:30:29 2016 +0100
Commit:     Mattias Andrée <[email protected]>
CommitDate: Wed Mar 16 14:30:29 2016 +0100

    Optimise zsqr, zmul, zstr, zdivmod, zpow, and zpowu
    
    Signed-off-by: Mattias Andrée <[email protected]>

diff --git a/TODO b/TODO
index 603fabb..5c69340 100644
--- a/TODO
+++ b/TODO
@@ -15,3 +15,7 @@ Test optimisation of zmul:
      - [(Hb - Lb)(Hc - Lc) << m2]
      + [(Lb * Lc) << m2]
      + (Lb * Lc)
+
+Would zmul be faster if we split only one of the
+factors until they are both approximately the same
+size?
diff --git a/src/internals.h b/src/internals.h
index a9d9af5..ed5d7e9 100644
--- a/src/internals.h
+++ b/src/internals.h
@@ -234,25 +234,28 @@ zswap_tainted_unsigned(z_t a, z_t b)
 }
 
 static inline void
-zsplit_fast_large_taint(z_t high, z_t low, z_t a, size_t n)
+zsplit_unsigned_fast_large_taint(z_t high, z_t low, z_t a, size_t n)
 {
        n >>= LB_BITS_PER_CHAR;
-       high->sign = a->sign;
+       high->sign = 1;
        high->used = a->used - n;
        high->chars = a->chars + n;
-       low->sign = a->sign;
+#if 0
+       TRIM_AND_ZERO(high);
+#endif
+       low->sign = 1;
        low->used = n;
        low->chars = a->chars;
        TRIM_AND_ZERO(low);
 }
 
 static inline void
-zsplit_fast_small_tainted(z_t high, z_t low, z_t a, size_t n)
+zsplit_unsigned_fast_small_tainted(z_t high, z_t low, z_t a, size_t n)
 {
        zahl_char_t mask = 1;
        mask = (mask << n) - 1;
 
-       high->sign = a->sign;
+       high->sign = 1;
        high->used = 1;
        high->chars[0] = a->chars[0] >> n;
        if (a->used == 2) {
@@ -261,10 +264,12 @@ zsplit_fast_small_tainted(z_t high, z_t low, z_t a, size_t n)
                n = BITS_PER_CHAR - n;
                high->chars[0] |= (a->chars[1] & mask) << n;
        }
+#if 0
        if (unlikely(!high->chars[high->used - 1]))
                high->sign = 0;
+#endif
 
-       low->sign = a->sign;
+       low->sign = 1;
        low->used = 1;
        low->chars[0] = a->chars[0] & mask;
        if (unlikely(!low->chars[0]))
diff --git a/src/zdivmod.c b/src/zdivmod.c
index 55d0e06..d907450 100644
--- a/src/zdivmod.c
+++ b/src/zdivmod.c
@@ -1,44 +1,17 @@
 /* See LICENSE file for copyright and license details. */
 #include "internals.h"
 
-#define ta   libzahl_tmp_divmod_a
-#define tb   libzahl_tmp_divmod_b
-#define td   libzahl_tmp_divmod_d
-#define tds  libzahl_tmp_divmod_ds
+#define ta          libzahl_tmp_divmod_a
+#define tb          libzahl_tmp_divmod_b
+#define td          libzahl_tmp_divmod_d
+#define tds_proper  libzahl_tmp_divmod_ds
 
 
-void
-zdivmod(z_t a, z_t b, z_t c, z_t d)
+static inline void
+zdivmod_impl(z_t a, z_t b, z_t c, z_t d)
 {
-       size_t c_bits, d_bits, bit;
-       int sign, cmpmag;
-
-       sign = zsignum(c) * zsignum(d);
-
-       if (unlikely(!sign)) {
-               if (zzero(c)) {
-                       if (zzero(d)) {
-                               libzahl_failure(-ZERROR_0_DIV_0);
-                       } else {
-                               SET_SIGNUM(a, 0);
-                               SET_SIGNUM(b, 0);
-                       }
-               } else {
-                       libzahl_failure(-ZERROR_DIV_0);
-               }
-               return;
-       } else if (unlikely((cmpmag = zcmpmag(c, d)) <= 0)) {
-               if (cmpmag == 0) {
-                       zseti(a, sign);
-                       SET_SIGNUM(b, 0);
-                       return;
-               } else {
-                       SET(b, c);
-               }
-               SET_SIGNUM(b, 1);
-               SET_SIGNUM(a, 0);
-               return;
-       }
+       size_t c_bits, d_bits, bit, i;
+       static z_t tds[BITS_PER_CHAR];
 
        c_bits = zbits(c);
        d_bits = zbits(d);
@@ -54,10 +27,10 @@ zdivmod(z_t a, z_t b, z_t c, z_t d)
        SET_SIGNUM(ta, 0);
        zabs(tb, c);
 
-       if (bit < BITS_PER_CHAR) {
+       if (unlikely(bit <= BITS_PER_CHAR)) {
                for (;;) {
                        if (zcmpmag(td, tb) <= 0) {
-                               zsub(tb, tb, td);
+                               zsub_unsigned(tb, tb, td);
                                zbset(ta, ta, bit, 1);
                        }
                        if (!bit-- || zzero(tb))
@@ -65,25 +38,61 @@ zdivmod(z_t a, z_t b, z_t c, z_t d)
                        zrsh(td, td, 1);
                }
        } else {
-               size_t i;
-               for (i = 0; i < BITS_PER_CHAR; i++)
-                       zrsh(tds[i], td, i);
+               for (i = 0; i < BITS_PER_CHAR; i++) {
+                       zrsh(tds_proper[i], td, i);
+                       tds[i]->used = tds_proper[i]->used;
+                       tds[i]->sign = tds_proper[i]->sign;
+                       tds[i]->chars = tds_proper[i]->chars;
+               }
                for (;;) {
                        for (i = 0; i < BITS_PER_CHAR; i++) {
                                if (zcmpmag(tds[i], tb) <= 0) {
-                                       zsub(tb, tb, tds[i]);
+                                       zsub_unsigned(tb, tb, tds[i]);
                                        zbset(ta, ta, bit, 1);
                                }
                                if (!bit-- || zzero(tb))
                                        goto done;
                        }
                        for (i = MIN(bit, BITS_PER_CHAR - 1) + 1; i--;)
-                               zrsh(tds[i], tds[i], BITS_PER_CHAR);
+                               zrsh_taint(tds[i], BITS_PER_CHAR);
                }
        }
 done:
 
        zswap(a, ta);
        zswap(b, tb);
+}
+
+
+void
+zdivmod(z_t a, z_t b, z_t c, z_t d)
+{
+       int sign, cmpmag;
+
+       sign = zsignum(c) * zsignum(d);
+
+       if (unlikely(!sign)) {
+               if (unlikely(!zzero(c))) {
+                       libzahl_failure(-ZERROR_DIV_0);
+               } else if (unlikely(zzero(d))) {
+                       libzahl_failure(-ZERROR_0_DIV_0);
+               } else {
+                       SET_SIGNUM(a, 0);
+                       SET_SIGNUM(b, 0);
+               }
+               return;
+       } else if (cmpmag = zcmpmag(c, d), unlikely(cmpmag <= 0)) {
+               if (unlikely(cmpmag == 0)) {
+                       zseti(a, sign);
+                       SET_SIGNUM(b, 0);
+               } else {
+                       SET(b, c);
+                       SET_SIGNUM(b, 1);
+                       SET_SIGNUM(a, 0);
+               }
+               return;
+       }
+
+       zdivmod_impl(a, b, c, d);
        SET_SIGNUM(a, sign);
 }
diff --git a/src/zmul.c b/src/zmul.c
index 71460d8..6633edd 100644
--- a/src/zmul.c
+++ b/src/zmul.c
@@ -11,7 +11,7 @@ zmul_impl_single_char(z_t a, z_t b, z_t c)
        SET_SIGNUM(a, 1);
 }
 
-static void
+void
 zmul_impl(z_t a, z_t b, z_t c)
 {
        /*
@@ -19,13 +19,18 @@ zmul_impl(z_t a, z_t b, z_t c)
         * 
         * Basically, this is how you were taught to multiply large numbers
         * by hand in school: 4010⋅3020 = (4000 + 10)(3000 + 20) =
-        = 40⋅30⋅10⁴ + (40⋅20 + 30⋅10)⋅10² + 10⋅20, but the middle is
+        * = 40⋅30⋅10⁴ + (40⋅20 + 30⋅10)⋅10² + 10⋅20, but the middle is
         * optimised to only one multiplication:
         * 40⋅20 + 30⋅10 = (40 + 10)(30 + 20) − 40⋅30 − 10⋅20.
+        * This optimisation is crucial. Without it, the algorithm would
+        * run in O(n²).
         */
 
+#define z2 c_low
+#define z1 b_low
+#define z0 a
        size_t m, m2;
-       z_t z0, z1, z2, b_high, b_low, c_high, c_low;
+       z_t b_high, b_low, c_high, c_low;
 
        if (unlikely(zzero1(b, c))) {
                SET_SIGNUM(a, 0);
@@ -43,9 +48,6 @@ zmul_impl(z_t a, z_t b, z_t c)
         m = MAX(m, m2);
        m2 = m >> 1;
 
-       zinit(z0);
-       zinit(z1);
-       zinit(z2);
        zinit(b_high);
        zinit(b_low);
        zinit(c_high);
@@ -66,14 +68,11 @@ zmul_impl(z_t a, z_t b, z_t c)
 
        zlsh(z1, z1, m2);
        m2 <<= 1;
-       zlsh(a, z2, m2);
+       zlsh(z2, z2, m2);
        zadd_unsigned_assign(a, z1);
-       zadd_unsigned_assign(a, z0);
+       zadd_unsigned_assign(a, z2);
 
 
-       zfree(z0);
-       zfree(z1);
-       zfree(z2);
        zfree(b_high);
        zfree(b_low);
        zfree(c_high);
diff --git a/src/zpow.c b/src/zpow.c
index 81f8098..84f4927 100644
--- a/src/zpow.c
+++ b/src/zpow.c
@@ -5,6 +5,9 @@
 #define tc  libzahl_tmp_pow_c
 
 
+extern void zmul_impl(z_t a, z_t b, z_t c);
+extern void zsqr_impl(z_t a, z_t b);
+
 void
 zpow(z_t a, z_t b, z_t c)
 {
@@ -16,6 +19,7 @@ zpow(z_t a, z_t b, z_t c)
 
        size_t i, j, n, bits;
        zahl_char_t x;
+       int neg;
 
        if (unlikely(zsignum(c) <= 0)) {
                if (zzero(c)) {
@@ -36,7 +40,8 @@ zpow(z_t a, z_t b, z_t c)
        bits = zbits(c);
        n = FLOOR_BITS_TO_CHARS(bits);
 
-       zset(tb, b);
+       neg = znegative(b) && zodd(c);
+       zabs(tb, b);
        zset(tc, c);
        zsetu(a, 1);
 
@@ -44,14 +49,17 @@ zpow(z_t a, z_t b, z_t c)
                x = tc->chars[i];
                for (j = BITS_PER_CHAR; j--; x >>= 1) {
                        if (x & 1)
-                               zmul(a, a, tb);
-                       zsqr(tb, tb);
+                               zmul_impl(a, a, tb);
+                       zsqr_impl(tb, tb);
                }
        }
        x = tc->chars[i];
        for (; x; x >>= 1) {
                if (x & 1)
-                       zmul(a, a, tb);
-               zsqr(tb, tb);
+                       zmul_impl(a, a, tb);
+               zsqr_impl(tb, tb);
        }
+
+       if (neg)
+               zneg(a, a);
 }
diff --git a/src/zpowu.c b/src/zpowu.c
index c4a2a64..cf879e0 100644
--- a/src/zpowu.c
+++ b/src/zpowu.c
@@ -4,9 +4,14 @@
 #define tb  libzahl_tmp_pow_b
 
 
+extern void zmul_impl(z_t a, z_t b, z_t c);
+extern void zsqr_impl(z_t a, z_t b);
+
 void
 zpowu(z_t a, z_t b, unsigned long long int c)
 {
+       int neg;
+
        if (unlikely(!c)) {
                if (zzero(b))
                        libzahl_failure(-ZERROR_0_POW_0);
@@ -17,12 +22,16 @@ zpowu(z_t a, z_t b, unsigned long long int c)
                return;
        }
 
-       zset(tb, b);
+       neg = znegative(b) && (c & 1);
+       zabs(tb, b);
        zsetu(a, 1);
 
        for (; c; c >>= 1) {
                if (c & 1)
-                       zmul(a, a, tb);
-               zsqr(tb, tb);
+                       zmul_impl(a, a, tb);
+               zsqr_impl(tb, tb);
        }
+
+       if (neg)
+               zneg(a, a);
 }
diff --git a/src/zsqr.c b/src/zsqr.c
index e9418bf..8a616f0 100644
--- a/src/zsqr.c
+++ b/src/zsqr.c
@@ -11,16 +11,19 @@ zsqr_impl_single_char(z_t a, z_t b)
        SET_SIGNUM(a, 1);
 }
 
-static void
+extern void zmul_impl(z_t a, z_t b, z_t c);
+
+void
 zsqr_impl(z_t a, z_t b)
 {
        /*
         * Karatsuba algorithm, optimised for equal factors.
         */
 
-       z_t z0, z1, z2, high, low;
-       size_t bits;
+#define z2 a
+       z_t z0, z1, high, low;
        zahl_char_t auxchars[3];
+       size_t bits;
 
        bits = zbits(b);
 
@@ -31,39 +34,41 @@ zsqr_impl(z_t a, z_t b)
 
        bits >>= 1;
 
-       zinit(z0);
-       zinit(z1);
-       zinit(z2);
-
+       /* Try to split at a character boundary rather than at a bit boundary.
+        * Such splits are faster, even when the ideal split point falls inside
+        * a character, and need no auxiliary memory; only the bit-level split,
+        * used for small operands, needs a constant amount of auxiliary memory. */
        if (bits < BITS_PER_CHAR) {
                low->chars = auxchars;
                high->chars = auxchars + 1;
-               zsplit_fast_small_tainted(high, low, b, bits);
+               zsplit_unsigned_fast_small_tainted(high, low, b, bits);
        } else {
                bits &= ~(BITS_PER_CHAR - 1);
-               zsplit_fast_large_taint(high, low, b, bits);
+               zsplit_unsigned_fast_large_taint(high, low, b, bits);
        }
 
 
-       zsqr_impl(z2, high);
        if (unlikely(zzero(low))) {
-               SET_SIGNUM(z0, 0);
-               SET_SIGNUM(z1, 0);
+               zsqr_impl(z2, high);
+               zlsh(a, z2, bits << 1);
        } else {
+               zinit(z0);
+               zinit(z1);
+
                zsqr_impl(z0, low);
-               zmul(z1, low, high);
-       }
 
-       zlsh(z1, z1, bits + 1);
-       bits <<= 1;
-       zlsh(a, z2, bits);
-       zadd_unsigned_assign(a, z1);
-       zadd_unsigned_assign(a, z0);
+               zmul_impl(z1, low, high);
+               zlsh(z1, z1, bits + 1);
 
+               zsqr_impl(z2, high);
+               zlsh(a, z2, bits << 1);
 
-       zfree(z0);
-       zfree(z1);
-       zfree(z2);
+               zadd_unsigned_assign(a, z1);
+               zadd_unsigned_assign(a, z0);
+
+               zfree(z0);
+               zfree(z1);
+       }
 }
 
 void
diff --git a/src/zstr.c b/src/zstr.c
index c82ec89..f919dc8 100644
--- a/src/zstr.c
+++ b/src/zstr.c
@@ -12,6 +12,48 @@
  * the cast to unsigned long long must be changed accordingly. */
 
 
+#define S1(P)     P"0"    P"1"    P"2"    P"3"    P"4"    P"5"    P"6"    P"7"    P"8"    P"9"
+#define S2(P)  S1(P"0")S1(P"1")S1(P"2")S1(P"3")S1(P"4")S1(P"5")S1(P"6")S1(P"7")S1(P"8")S1(P"9")
+
+/* Write v (< 10^19) as exactly 19 zero-padded digits + NUL; S2("") is "00" "01" ... "99". */
+static inline O2 void
+sprintint_fix(char *buf, zahl_char_t v)
+{
+       const char *partials = S2("");
+       char *pairs = buf + 1; /* memcpy avoids unaligned, type-punned uint16_t access */
+
+       memcpy(pairs + 16, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 14, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 12, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 10, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 8, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 6, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 4, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 2, partials + 2 * (v % 100), 2), v /= 100;
+       memcpy(pairs + 0, partials + 2 * (v % 100), 2), v /= 100;
+       *buf = '0' + v;
+       buf[19] = 0;
+}
+/* Forward byte copy of n chars; safe for overlapping regions when d <= s. */
+static inline void
+cmemmove(char *d, const char *s, long n)
+{
+       while (n--)
+               *d++ = *s++;
+}
+/* Like sprintint_fix but without leading zeros (v = 0 gives "0"); returns the length. */
+static inline size_t
+sprintint_min(char *buf, zahl_char_t v)
+{
+       long i = 0, j;
+       sprintint_fix(buf, v);
+       for (; i < 18 && buf[i] == '0'; i++); /* always keep the last digit */
+       cmemmove(buf, buf + i, j = 19 - i);
+       buf[j] = 0;
+       return j;
+}
+
+
 char *
 zstr(z_t a, char *b)
 {
@@ -42,12 +84,12 @@ zstr(z_t a, char *b)
        for (;;) {
                zdivmod(num, rem, num, libzahl_const_1e19);
                if (likely(!zzero(num))) {
-                       sprintf(b + n, "%019llu", zzero(rem) ? 0ULL : (unsigned long long)(rem->chars[0]));
+                       sprintint_fix(b + n, zzero(rem) ? 0 : rem->chars[0]);
                        b[n + 19] = overridden;
                        overridden = b[n];
                        n = n > 19 ? (n - 19) : 0;
                } else {
-                       len = (size_t)sprintf(buf, "%llu", (unsigned long long)(rem->chars[0]));
+                       len = sprintint_min(buf, rem->chars[0]);
                        if (overridden)
                                buf[len] = b[n + len];
                        memcpy(b + n, buf, len + 1);

Reply via email to