Hi
I've come across an infinite loop in the polynomial modulus function
BN_GF2m_mod_arr when the degree of the modulus is a multiple of 32 (on
32-bit machines) or 64 (on 64-bit machines).
A simple example which illustrates this (in the attached test case) is:
(z^65 + z^5 + z^4 + z^2 + z + 1) mod (z^64 + z^4 + z^3 + z + 1) = 1
Instead of returning 1, however, this call hangs indefinitely.
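For reference, the expected result follows from how a(z) is built: it is one more than a multiple of p(z). In the test case a(z) is entered as the hex string 20000000000000037, i.e. bit 65 plus 0x37 (bits 5, 4, 2, 1 and 0).

```latex
\begin{aligned}
z\,p(z) &= z^{65} + z^{5} + z^{4} + z^{2} + z \\
a(z)    &= z\,p(z) + 1 = z^{65} + z^{5} + z^{4} + z^{2} + z + 1 \\
a(z) \bmod p(z) &= 1
\end{aligned}
```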
The problem occurs when the degree n of the modulus is a multiple of BN_BITS2. I believe that in this case the top d1 bits of z[dN] are never cleared by the following code:
d0 = n % BN_BITS2;
d1 = BN_BITS2 - d0;
if (d0) z[dN] = (z[dN] << d1) >> d1; /* clear up the top d1 bits */
Note that when n = 32 (on a 32-bit machine), d0 = 0 and d1 = 32, so the if (d0) branch is skipped and z[dN] is never cleared. Changing the above to (see attached patch):
if (d0) z[dN] = (z[dN] << d1) >> d1; else z[dN] = 0; /* clear up the top d1 bits */
fixes the problem.
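This seems intuitively correct, since it is exactly what the unconditional statement
z[dN] = (z[dN] << d1) >> d1;
would compute if shifts by BN_BITS2 or more produced 0 instead of wrapping around (here 0 <= d0 < BN_BITS2 and 1 <= d1 <= BN_BITS2, so only d1 = BN_BITS2 is affected). In C, shifting a word by its full width is undefined behaviour, which is why the special case is needed at all.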
Attached are the following files:
test_BN_GF2m_mod_arr.cpp : My test case
BN_GF2m_mod_arr.patch : Patch to fix the bug
Regards
Robert
--- openssl-0.9.8g/crypto/bn/bn_gf2m.c.orig 2008-06-17 18:47:22.000000000 +0100
+++ openssl-0.9.8g/crypto/bn/bn_gf2m.c 2008-06-18 10:51:36.000000000 +0100
@@ -384,7 +384,7 @@
if (zz == 0) break;
d1 = BN_BITS2 - d0;
- if (d0) z[dN] = (z[dN] << d1) >> d1; /* clear up the top d1 bits */
+ if (d0) z[dN] = (z[dN] << d1) >> d1; else z[dN] = 0; /* clear up the top d1 bits */
z[0] ^= zz; /* reduction t^0 component */
for (k = 1; p[k] != 0; k++)
#include <stdio.h>
#include <openssl/bn.h>
int main()
{
    // Primitive pentanomial p(z) = z^64 + z^4 + z^3 + z + 1
    // (OpenSSL 0.9.8 takes the exponent list as unsigned int[], terminated by 0)
    const unsigned int p[] = {64, 4, 3, 1, 0};
    // Polynomial a(z) = z p(z) + 1 = z^65 + z^5 + z^4 + z^2 + z + 1
    BIGNUM *a = BN_new();
    if (a == NULL || !BN_hex2bn(&a, "20000000000000037")) return 1;
    // Print polynomial a(z)
    printf("a = "); BN_print_fp(stdout, a); printf("\n");
    // Calculate r(z) = a(z) mod p(z) = 1 (hangs without the patch)
    BIGNUM *r = BN_new();
    if (r == NULL || !BN_GF2m_mod_arr(r, a, p)) return 1;
    // Print result r(z)
    printf("r = "); BN_print_fp(stdout, r); printf("\n");
    BN_free(r);
    BN_free(a);
    return 0;
}