Author: Brian Kearns <[email protected]>
Branch:
Changeset: r73852:2cc2e4c576c5
Date: 2014-10-08 19:34 -0400
http://bitbucket.org/pypy/pypy/changeset/2cc2e4c576c5/
Log: optimize iterator goto if array is contiguous
diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -19,6 +19,7 @@
'strides[*]', 'backstrides[*]', 'order']
start = 0
parent = None
+ flags = 0  # bitmask of NPY.ARRAY_* bits (e.g. ARRAY_C_CONTIGUOUS); presumably updated via flagsobj.enable_flags/clear_flags
# JIT hints that length of all those arrays is a constant
diff --git a/pypy/module/micronumpy/constants.py b/pypy/module/micronumpy/constants.py
--- a/pypy/module/micronumpy/constants.py
+++ b/pypy/module/micronumpy/constants.py
@@ -74,6 +74,9 @@
WRAP = 1
RAISE = 2
+ARRAY_C_CONTIGUOUS = 0x0001  # array `flags` bit: C-order (row-major) contiguous
+ARRAY_F_CONTIGUOUS = 0x0002  # array `flags` bit: F-order (column-major) contiguous
+
LITTLE = '<'
BIG = '>'
NATIVE = '='
diff --git a/pypy/module/micronumpy/flagsobj.py b/pypy/module/micronumpy/flagsobj.py
--- a/pypy/module/micronumpy/flagsobj.py
+++ b/pypy/module/micronumpy/flagsobj.py
@@ -2,6 +2,46 @@
from pypy.interpreter.error import OperationError
from pypy.interpreter.gateway import interp2app
from pypy.interpreter.typedef import TypeDef, GetSetProperty
+from pypy.module.micronumpy import constants as NPY
+
+# Set the given bit(s) in arr.flags in place.
+def enable_flags(arr, flags):
+ arr.flags |= flags
+
+# Clear the given bit(s) in arr.flags in place.
+def clear_flags(arr, flags):
+ arr.flags &= ~flags
+
+# Recompute arr's C- and F-contiguous flag bits from shape, strides and dtype.elsize.
+def _update_contiguous_flags(arr):
+ shape = arr.shape
+ strides = arr.strides
+ # C order: scan axes last-to-first; each stride must equal elsize times the product of all later dims.
+ is_c_contig = True
+ sd = arr.dtype.elsize
+ for i in range(len(shape) - 1, -1, -1):
+ dim = shape[i]
+ if strides[i] != sd:
+ is_c_contig = False
+ break
+ if dim == 0:  # zero-length axis: stop scanning; outer axes' strides are not checked
+ break
+ sd *= dim
+ if is_c_contig:
+ enable_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
+ else:
+ clear_flags(arr, NPY.ARRAY_C_CONTIGUOUS)
+ # F order: the same check, scanning axes first-to-last.
+ sd = arr.dtype.elsize
+ for i in range(len(shape)):
+ dim = shape[i]
+ if strides[i] != sd:
+ clear_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
+ return
+ if dim == 0:  # zero-length axis: stop scanning; later axes' strides are not checked
+ break
+ sd *= dim
+ enable_flags(arr, NPY.ARRAY_F_CONTIGUOUS)
class W_FlagsObject(W_Root):
diff --git a/pypy/module/micronumpy/iterators.py b/pypy/module/micronumpy/iterators.py
--- a/pypy/module/micronumpy/iterators.py
+++ b/pypy/module/micronumpy/iterators.py
@@ -37,8 +37,9 @@ All the calculations happen in next()
 """
 
 from rpython.rlib import jit
-from pypy.module.micronumpy import support
+from pypy.module.micronumpy import support, constants as NPY
 from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.flagsobj import _update_contiguous_flags
 
 
 class PureShapeIter(object):
@@ -86,11 +87,18 @@
 
 
 class ArrayIter(object):
-    _immutable_fields_ = ['array', 'size', 'ndim_m1', 'shape_m1[*]',
+    _immutable_fields_ = ['contiguous', 'array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]', 'factors[*]']
 
     def __init__(self, array, size, shape, strides, backstrides):
         assert len(shape) == len(strides) == len(backstrides)
+        _update_contiguous_flags(array)
+        # goto() uses offset = start + index * elsize when this is set, which
+        # is only right if the iterator walks the array with the array's own
+        # C-contiguous shape and strides (not, e.g., broadcast strides).
+        self.contiguous = ((array.flags & NPY.ARRAY_C_CONTIGUOUS) != 0 and
+                           array.shape == shape and array.strides == strides)
+
         self.array = array
         self.size = size
         self.ndim_m1 = len(shape) - 1
@@ -137,12 +145,14 @@
 
     @jit.unroll_safe
     def goto(self, index):
-        # XXX simplify if self.contiguous (offset = start + index * elsize)
         offset = self.array.start
-        current = index
-        for i in xrange(len(self.shape_m1)):
-            offset += (current / self.factors[i]) * self.strides[i]
-            current %= self.factors[i]
+        if self.contiguous:
+            offset += index * self.array.dtype.elsize
+        else:
+            current = index
+            for i in xrange(len(self.shape_m1)):
+                offset += (current / self.factors[i]) * self.strides[i]
+                current %= self.factors[i]
         return IterState(self, index, None, offset)
 
     def done(self, state):
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -544,19 +544,13 @@
assert result == 10.0
self.check_trace_count(1)
self.check_simple_loop({
- 'arraylen_gc': 2,
'getarrayitem_gc': 1,
'guard_false': 1,
'guard_true': 1,
- 'int_add': 7,
- 'int_and': 1,
- 'int_floordiv': 1,
+ 'int_add': 5,
'int_ge': 1,
'int_lt': 1,
- 'int_mod': 1,
- 'int_mul': 2,
- 'int_rshift': 2,
- 'int_sub': 1,
+ 'int_mul': 1,  # NOTE(review): no int_floordiv/int_mod expected; assumes ArrayIter.goto() takes its contiguous fast path here - confirm
'jump': 1,
'raw_load': 1,
'raw_store': 1,
@@ -576,19 +570,14 @@
assert result == 1.0
self.check_trace_count(1)
self.check_simple_loop({
- 'arraylen_gc': 2,
'getarrayitem_gc': 1,
'guard_not_invalidated': 1,
'guard_true': 2,
- 'int_add': 7,
- 'int_and': 1,
- 'int_floordiv': 1,
+ 'int_add': 5,
'int_gt': 1,
'int_lt': 1,
- 'int_mod': 1,
- 'int_mul': 2,
- 'int_rshift': 2,
- 'int_sub': 2,
+ 'int_mul': 1,  # NOTE(review): no int_floordiv/int_mod expected; assumes ArrayIter.goto() takes its contiguous fast path here - confirm
+ 'int_sub': 1,
'jump': 1,
'raw_load': 1,
'raw_store': 1,
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit