Author: David Schneider <[email protected]>
Branch: 
Changeset: r44838:1a44e52c715b
Date: 2011-06-08 19:11 +0200
http://bitbucket.org/pypy/pypy/changeset/1a44e52c715b/

Log:    (arigo, bivab) add checks that the shift operations are never given
        an out-of-range shift count, at least in the C backend: in C,
        shifting by such a count is undefined behaviour.

diff --git a/pypy/translator/c/genc.py b/pypy/translator/c/genc.py
--- a/pypy/translator/c/genc.py
+++ b/pypy/translator/c/genc.py
@@ -900,8 +900,9 @@
     print >> f, '}'
 
 def commondefs(defines):
-    from pypy.rlib.rarithmetic import LONG_BIT
+    from pypy.rlib.rarithmetic import LONG_BIT, LONGLONG_BIT
     defines['PYPY_LONG_BIT'] = LONG_BIT
+    defines['PYPY_LONGLONG_BIT'] = LONGLONG_BIT  # bound for long long shift counts in src/int.h
 
 def add_extra_files(eci):
     srcdir = py.path.local(autopath.pypydir).join('translator', 'c', 'src')
diff --git a/pypy/translator/c/src/int.h b/pypy/translator/c/src/int.h
--- a/pypy/translator/c/src/int.h
+++ b/pypy/translator/c/src/int.h
@@ -73,15 +73,28 @@
 
 /* NB. shifting has same limitations as C: the shift count must be
        >= 0 and < LONG_BITS. */
-#define OP_INT_RSHIFT(x,y,r)    r = Py_ARITHMETIC_RIGHT_SHIFT(long, x, y)
-#define OP_UINT_RSHIFT(x,y,r)   r = (x) >> (y)
-#define OP_LLONG_RSHIFT(x,y,r)  r = Py_ARITHMETIC_RIGHT_SHIFT(PY_LONG_LONG,x,y)
-#define OP_ULLONG_RSHIFT(x,y,r) r = (x) >> (y)
+#define CHECK_SHIFT_RANGE(y, bits) RPyAssert((y) >= 0 && (y) < (bits), \
+              "The shift count is outside of the supported range")  /* such shifts are UB in C */
 
-#define OP_INT_LSHIFT(x,y,r)    r = (x) << (y)
-#define OP_UINT_LSHIFT(x,y,r)   r = (x) << (y)
-#define OP_LLONG_LSHIFT(x,y,r)  r = (x) << (y)
-#define OP_ULLONG_LSHIFT(x,y,r) r = (x) << (y)
+
+#define OP_INT_RSHIFT(x,y,r)    CHECK_SHIFT_RANGE(y, PYPY_LONG_BIT); \
+                                               r = Py_ARITHMETIC_RIGHT_SHIFT(long, x, (y))
+#define OP_UINT_RSHIFT(x,y,r)   CHECK_SHIFT_RANGE(y, PYPY_LONG_BIT); \
+                                               r = (x) >> (y)
+#define OP_LLONG_RSHIFT(x,y,r)  CHECK_SHIFT_RANGE(y, PYPY_LONGLONG_BIT); \
+                                               r = Py_ARITHMETIC_RIGHT_SHIFT(PY_LONG_LONG,x, (y))
+#define OP_ULLONG_RSHIFT(x,y,r) CHECK_SHIFT_RANGE(y, PYPY_LONGLONG_BIT); \
+                                               r = (x) >> (y)
+
+
+#define OP_INT_LSHIFT(x,y,r)    CHECK_SHIFT_RANGE(y, PYPY_LONG_BIT); \
+                                                       r = (x) << (y)
+#define OP_UINT_LSHIFT(x,y,r)   CHECK_SHIFT_RANGE(y, PYPY_LONG_BIT); \
+                                                       r = (x) << (y)
+#define OP_LLONG_LSHIFT(x,y,r)  CHECK_SHIFT_RANGE(y, PYPY_LONGLONG_BIT); \
+                                                       r = (x) << (y)
+#define OP_ULLONG_LSHIFT(x,y,r) CHECK_SHIFT_RANGE(y, PYPY_LONGLONG_BIT); \
+                                                       r = (x) << (y)
 
 #define OP_INT_LSHIFT_OVF(x,y,r) \
        OP_INT_LSHIFT(x,y,r); \
diff --git a/pypy/translator/c/test/test_standalone.py b/pypy/translator/c/test/test_standalone.py
--- a/pypy/translator/c/test/test_standalone.py
+++ b/pypy/translator/c/test/test_standalone.py
@@ -596,6 +596,42 @@
         # The traceback stops at f() because it's the first function that
         # captures the AssertionError, which makes the program abort.
 
+    def test_int_lshift_too_large(self):
+        from pypy.rlib.rarithmetic import LONG_BIT, LONGLONG_BIT
+        def entry_point(argv):
+            a = int(argv[1])
+            b = int(argv[2])
+            print a << b
+            return 0
+
+        t, cbuilder = self.compile(entry_point, debug=True)  # NOTE(review): debug build, presumably so RPyAssert fires
+        out = cbuilder.cmdexec("10 2", expect_crash=False)
+        assert out.strip() == str(10 << 2)  # an in-range count still works
+        cases = [-4, LONG_BIT, LONGLONG_BIT]  # negative, or >= the bit width
+        for x in cases:
+            out, err = cbuilder.cmdexec("%s %s" % (1, x), expect_crash=True)
+            lines = err.strip()
+            assert 'The shift count is outside of the supported range' in lines  # CHECK_SHIFT_RANGE message
+
+    def test_llong_rshift_too_large(self):
+        from pypy.rlib.rarithmetic import LONGLONG_BIT, r_longlong
+        def entry_point(argv):
+            a = r_longlong(int(argv[1]))
+            b = r_longlong(int(argv[2]))
+            print a >> b
+            return 0
+
+        t, cbuilder = self.compile(entry_point, debug=True)
+        out = cbuilder.cmdexec("10 2", expect_crash=False)
+        assert out.strip() == str(10 >> 2)
+        out = cbuilder.cmdexec("%s %s" % (-42, LONGLONG_BIT - 1), expect_crash=False)
+        assert out.strip() == '-1'  # largest legal count; arithmetic shift keeps the sign
+        cases = [-4, LONGLONG_BIT]  # negative, or >= the long long bit width
+        for x in cases:
+            out, err = cbuilder.cmdexec("%s %s" % (1, x), expect_crash=True)
+            lines = err.strip()
+            assert 'The shift count is outside of the supported range' in lines
+
     def test_ll_assert_error_debug(self):
         def entry_point(argv):
             ll_assert(len(argv) != 1, "foobar")
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to