commit 0d409e30fd712507216d5b4bd32ab4b6cb5fae28
Author:     Mattias Andrée <[email protected]>
AuthorDate: Tue Mar 15 22:38:08 2016 +0100
Commit:     Mattias Andrée <[email protected]>
CommitDate: Tue Mar 15 22:38:08 2016 +0100

    Optimise zsqr, and optimise zmul a little
    
    Signed-off-by: Mattias Andrée <[email protected]>

diff --git a/src/internals.h b/src/internals.h
index 456a2df..bcb4c3f 100644
--- a/src/internals.h
+++ b/src/internals.h
@@ -232,3 +232,42 @@ zswap_tainted_unsigned(z_t a, z_t b)
        b->chars = a->chars;
        a->chars = t->chars;
 }
+/* Split a at bit n into high/low without copying: both point into a->chars ("tainted"), so they must not be zfree'd. */
+static inline void
+zsplit_fast_large_taint(z_t high, z_t low, z_t a, size_t n)
+{
+       n >>= LB_BITS_PER_CHAR; /* bit offset -> char offset (assuming LB_BITS_PER_CHAR is log2(BITS_PER_CHAR)); the caller in zsqr_impl rounds n down to a multiple of BITS_PER_CHAR first */
+       high->sign = a->sign;
+       high->used = a->used - n; /* NOTE(review): requires n < a->used (in chars) — true for the zsqr_impl caller */
+       high->chars = a->chars + n; /* high = chars [n, a->used) of a */
+       low->sign = a->sign;
+       low->used = n;
+       low->chars = a->chars; /* low = chars [0, n) of a */
+       TRIM_AND_ZERO(low); /* presumably drops leading zero chars and zeroes the sign if nothing remains — confirm */
+       TRIM_AND_ZERO(high);
+}
+/* Split a (one or two chars) at bit n, 0 < n < BITS_PER_CHAR, into caller-provided storage: high->chars[0..1] and low->chars[0]. */
+static inline void
+zsplit_fast_small_tainted(z_t high, z_t low, z_t a, size_t n)
+{
+       zahl_char_t mask = 1; /* start in zahl_char_t so the shift below is done at that width */
+       mask = (mask << n) - 1; /* the low n bits set */
+       /* high = a >> n */
+       high->sign = a->sign;
+       high->used = 1;
+       high->chars[0] = a->chars[0] >> n;
+       if (a->used == 2) { /* chars beyond the second are ignored: a must have at most two */
+               high->chars[1] = a->chars[1] >> n;
+               high->used += !!high->chars[1]; /* keep the second char only if it is nonzero */
+               n = BITS_PER_CHAR - n; /* n must be nonzero, else this shifts by BITS_PER_CHAR below (UB if that is zahl_char_t's width) */
+               high->chars[0] |= (a->chars[1] & mask) << n; /* carry the low (original n) bits of chars[1] into the top of chars[0] */
+       }
+       if (unlikely(!high->chars[high->used - 1])) /* top char is zero only when the whole value is zero */
+               high->sign = 0;
+       /* low = the low n bits of a, always a single char */
+       low->sign = a->sign;
+       low->used = 1;
+       low->chars[0] = a->chars[0] & mask;
+       if (unlikely(!low->chars[0]))
+               low->sign = 0;
+}
diff --git a/src/zmul.c b/src/zmul.c
index ab41213..71460d8 100644
--- a/src/zmul.c
+++ b/src/zmul.c
@@ -2,8 +2,17 @@
 #include "internals.h"
 
 
-void
-zmul(z_t a, z_t b, z_t c)
+static inline void
+zmul_impl_single_char(z_t a, z_t b, z_t c)
+{
+       ENSURE_SIZE(a, 1); /* single-char fast path of zmul_impl */
+       a->chars[0] = c->chars[0] * b->chars[0]; /* fits: zmul_impl only calls this when m + m2 <= BITS_PER_CHAR */
+       a->used = 1;
+       SET_SIGNUM(a, 1); /* the factors are nonzero and nonnegative here: zmul strips their signs */
+}
+/* Karatsuba core of zmul. b and c are expected to be nonnegative (zmul strips their signs first); zmul fixes up the sign of a afterwards. */
+static void
+zmul_impl(z_t a, z_t b, z_t c)
 {
        /*
         * Karatsuba algorithm
@@ -17,12 +26,8 @@ zmul(z_t a, z_t b, z_t c)
 
        size_t m, m2;
        z_t z0, z1, z2, b_high, b_low, c_high, c_low;
-       int b_sign, c_sign;
-
-       b_sign = zsignum(b);
-       c_sign = zsignum(c);
 
-       if (unlikely(!b_sign || !c_sign)) {
+       if (unlikely(zzero1(b, c))) { /* either factor is zero (replaces the old !b_sign || !c_sign test) */
                SET_SIGNUM(a, 0);
                return;
        }
@@ -31,18 +36,10 @@ zmul(z_t a, z_t b, z_t c)
        m2 = b == c ? m : zbits(c);
 
        if (m + m2 <= BITS_PER_CHAR) {
-               /* zsetu(a, b->chars[0] * c->chars[0]); { */
-               ENSURE_SIZE(a, 1);
-               a->used = 1;
-               a->chars[0] = b->chars[0] * c->chars[0];
-               /* } */
-               SET_SIGNUM(a, b_sign * c_sign);
+               zmul_impl_single_char(a, b, c); /* the product fits in one char */
                return;
        }
 
-       SET_SIGNUM(b, 1);
-       SET_SIGNUM(c, 1);
-
         m = MAX(m, m2);
        m2 = m >> 1;
 
@@ -58,11 +55,11 @@ zmul(z_t a, z_t b, z_t c)
        zsplit(c_high, c_low, c, m2);
 
 
-       zmul(z0, b_low, c_low);
-       zmul(z2, b_high, c_high);
+       zmul_impl(z0, b_low, c_low); /* z0 = b_low * c_low */
        zadd_unsigned_assign(b_low, b_high);
        zadd_unsigned_assign(c_low, c_high);
-       zmul(z1, b_low, c_low);
+       zmul_impl(z1, b_low, c_low); /* z1 = (b_low + b_high) * (c_low + c_high), assuming the adds above accumulate into their first argument */
+       zmul_impl(z2, b_high, c_high); /* z2 = b_high * c_high */
 
        zsub_nonnegative_assign(z1, z0);
        zsub_nonnegative_assign(z1, z2);
@@ -81,8 +78,16 @@ zmul(z_t a, z_t b, z_t c)
        zfree(b_low);
        zfree(c_high);
        zfree(c_low);
+}
 
-       SET_SIGNUM(b, b_sign);
-       SET_SIGNUM(c, c_sign);
-       SET_SIGNUM(a, b_sign * c_sign);
+void
+zmul(z_t a, z_t b, z_t c)
+{
+       int b_sign, c_sign; /* original signs; NOTE(review): the tricks below assume they are -1, 0 or 1 */
+       b_sign = b->sign, b->sign *= b_sign; /* b->sign becomes b_sign * b_sign, i.e. |b_sign| */
+       c_sign = c->sign, c->sign *= c_sign; /* read AFTER b is updated: if c aliases b, c_sign is already nonnegative, so restoring c below is a no-op */
+       zmul_impl(a, b, c); /* both operands are nonnegative here */
+       c->sign *= c_sign;
+       b->sign *= b_sign; /* b and c carry their original signs again, also when they alias each other or a */
+       SET_SIGNUM(a, zsignum(b) * zsignum(c)); /* uses the restored signs, so this also holds when a aliases b or c */
 }
diff --git a/src/zsqr.c b/src/zsqr.c
index 68480ba..e9418bf 100644
--- a/src/zsqr.c
+++ b/src/zsqr.c
@@ -2,54 +2,61 @@
 #include "internals.h"
 
 
-void
-zsqr(z_t a, z_t b)
+static inline void
+zsqr_impl_single_char(z_t a, z_t b)
+{
+       ENSURE_SIZE(a, 1); /* single-char fast path of zsqr_impl */
+       a->chars[0] = b->chars[0] * b->chars[0]; /* fits: zsqr_impl only calls this when zbits(b) <= BITS_PER_CHAR / 2 */
+       a->used = 1;
+       SET_SIGNUM(a, 1); /* a square of a nonzero value is positive */
+}
+/* Karatsuba squaring, a = b * b. b must be nonzero: there is no zero check here (zsqr and the zero-low branch below guard it). */
+static void
+zsqr_impl(z_t a, z_t b)
 {
        /*
         * Karatsuba algorithm, optimised for equal factors.
         */
 
-       size_t m2;
        z_t z0, z1, z2, high, low;
-       int sign;
+       size_t bits; /* bit length of b, then the split point, then twice the split point */
+       zahl_char_t auxchars[3]; /* storage for the small split: low uses [0], high uses [1..2] */
 
-       if (unlikely(zzero(b))) {
-               SET_SIGNUM(a, 0);
-               return;
-       }
-
-       m2 = zbits(b);
+       bits = zbits(b);
 
-       if (m2 <= BITS_PER_CHAR / 2) {
-               /* zsetu(a, b->chars[0] * b->chars[0]); { */
-               ENSURE_SIZE(a, 1);
-               a->used = 1;
-               a->chars[0] = b->chars[0] * b->chars[0];
-               /* } */
-               SET_SIGNUM(a, 1);
+       if (bits <= BITS_PER_CHAR / 2) { /* b * b fits in one char */
+               zsqr_impl_single_char(a, b);
                return;
        }
 
-       sign = zsignum(b);
-       SET_SIGNUM(b, 1);
-       m2 >>= 1;
+       bits >>= 1; /* split point: b = high * 2^bits + low */
 
        zinit(z0);
        zinit(z1);
        zinit(z2);
-       zinit(high);
-       zinit(low);
 
-       zsplit(high, low, b, m2);
+       if (bits < BITS_PER_CHAR) { /* b has at most two chars: split inside a char, copying into auxchars */
+               low->chars = auxchars;
+               high->chars = auxchars + 1;
+               zsplit_fast_small_tainted(high, low, b, bits);
+       } else {
+               bits &= ~(BITS_PER_CHAR - 1); /* round down to a char boundary (assumes BITS_PER_CHAR is a power of two) so high/low can alias b->chars */
+               zsplit_fast_large_taint(high, low, b, bits); /* high/low are never zinit'd, so they are not zfree'd below */
+       }
 
 
-       zsqr(z0, low);
-       zsqr(z2, high);
-       zmul(z1, low, high);
+       zsqr_impl(z2, high); /* z2 = high^2; high is nonzero since it holds b's top bits */
+       if (unlikely(zzero(low))) { /* low can be zero; zsqr_impl must not be given a zero operand */
+               SET_SIGNUM(z0, 0);
+               SET_SIGNUM(z1, 0);
+       } else {
+               zsqr_impl(z0, low); /* z0 = low^2 */
+               zmul(z1, low, high); /* z1 = low * high; zmul handles the operands' signs */
+       }
 
-       zlsh(z1, z1, m2 + 1);
-       m2 <<= 1;
-       zlsh(a, z2, m2);
+       zlsh(z1, z1, bits + 1); /* z1 = 2 * low * high * 2^bits */
+       bits <<= 1;
+       zlsh(a, z2, bits); /* a = high^2 * 2^(2 * split point) */
        zadd_unsigned_assign(a, z1);
        zadd_unsigned_assign(a, z0);
 
@@ -57,9 +64,15 @@ zsqr(z_t a, z_t b)
        zfree(z0);
        zfree(z1);
        zfree(z2);
-       zfree(high);
-       zfree(low);
+}
 
-       SET_SIGNUM(b, sign);
-       SET_SIGNUM(a, 1);
+void
+zsqr(z_t a, z_t b)
+{
+       if (unlikely(zzero(b))) {
+               SET_SIGNUM(a, 0); /* 0^2 = 0; zsqr_impl has no zero check of its own */
+               return;
+       }
+       zsqr_impl(a, b);
+       SET_SIGNUM(a, 1); /* a square of a nonzero value is positive */
 }

Reply via email to