Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-multidim
Changeset: r48510:61f36db28f06
Date: 2011-10-27 17:37 +0200
http://bitbucket.org/pypy/pypy/changeset/61f36db28f06/
Log: setitem with slice - part one
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -10,9 +10,11 @@
numpy_driver = jit.JitDriver(greens = ['signature'],
reds = ['result_size', 'i', 'self', 'result'])
-all_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
'dtype'])
-any_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
'dtype'])
-slice_driver = jit.JitDriver(greens=['signature'], reds=['i', 'j', 'step',
'stop', 'source', 'dest'])
+all_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
+ 'dtype'])
+any_driver = jit.JitDriver(greens=['signature'], reds=['i', 'size', 'self',
+ 'dtype'])
+slice_driver = jit.JitDriver(greens=['signature'], reds=['i', 'self',
'source'])
class BaseArray(Wrappable):
_attrs_ = ["invalidates", "signature"]
@@ -304,55 +306,26 @@
def descr_setitem(self, space, w_idx, w_value):
self.invalidated()
- if self._single_item_at_index(space, w_idx):
+ if self._single_item_result(space, w_idx):
item = self._single_item_at_index(space, w_idx)
self.get_concrete().setitem_w(space, item, w_value)
return
- xxx
- if space.isinstance_w(w_idx, space.w_tuple):
- length = space.len_w(w_idx)
- if length > 1: # only one dimension for now.
- raise OperationError(space.w_IndexError,
- space.wrap("invalid index"))
- if length == 0:
- w_idx = space.newslice(space.wrap(0),
- space.wrap(self.find_size()),
- space.wrap(1))
- else:
- w_idx = space.getitem(w_idx, space.wrap(0))
- start, stop, step, slice_length = space.decode_index4(w_idx,
- self.find_size())
- if step == 0:
- # Single index
- self.get_concrete().setitem_w(space, start, w_value)
+ concrete = self.get_concrete()
+ if isinstance(w_value, BaseArray):
+ # for now we just copy if setting part of an array from
+ # part of itself. can be improved.
+ if (concrete.get_root_storage() ==
+ w_value.get_concrete().get_root_storage()):
+ w_value = space.call_function(space.gettypefor(BaseArray),
w_value)
+ assert isinstance(w_value, BaseArray)
else:
- concrete = self.get_concrete()
- if isinstance(w_value, BaseArray):
- # for now we just copy if setting part of an array from
- # part of itself. can be improved.
- if (concrete.get_root_storage() ==
- w_value.get_concrete().get_root_storage()):
- w_value = space.call_function(space.gettypefor(BaseArray),
w_value)
- assert isinstance(w_value, BaseArray)
- else:
- w_value = convert_to_array(space, w_value)
- concrete.setslice(space, start, stop, step,
- slice_length, w_value)
+ w_value = convert_to_array(space, w_value)
+ view = self._create_slice(space, w_idx)
+ view.setslice(space, w_value)
def descr_mean(self, space):
return
space.wrap(space.float_w(self.descr_sum(space))/self.find_size())
- def _sliceloop(self, start, stop, step, source, dest):
- i = start
- j = 0
- while (step > 0 and i < stop) or (step < 0 and i > stop):
- slice_driver.jit_merge_point(signature=source.signature, step=step,
- stop=stop, i=i, j=j, source=source,
- dest=dest)
- dest.setitem(i, source.eval(j).convert_to(dest.find_dtype()))
- j += 1
- i += step
-
def convert_to_array(space, w_obj):
if isinstance(w_obj, BaseArray):
return w_obj
@@ -557,13 +530,23 @@
def find_dtype(self):
return self.parent.find_dtype()
- def setslice(self, space, start, stop, step, slice_length, arr):
- xxx
- start = self.calc_index(start)
- if stop != -1:
- stop = self.calc_index(stop)
- step = self.step * step
- self._sliceloop(start, stop, step, arr, self.parent)
+ def setslice(self, space, w_value):
+ assert isinstance(w_value, NDimArray)
+ if self.shape != w_value.shape:
+ raise OperationError(space.w_TypeError, space.wrap(
+ "wrong assignment"))
+ self._sliceloop(w_value)
+
+ def _sliceloop(self, source):
+ i = 0
+ while i < self.size:
+ slice_driver.jit_merge_point(signature=source.signature, i=i,
+ self=self, source=source)
+ self.setitem(i, source.eval(i).convert_to(self.find_dtype()))
+ i += 1
+
+ def setitem(self, item, value):
+ self.parent.setitem(self.calc_index(item), value)
def len_of_shape(self):
return self.parent.len_of_shape() - self.shape_reduction
@@ -644,9 +627,6 @@
self.invalidated()
self.dtype.setitem(self.storage, item, value)
- def setslice(self, space, start, stop, step, slice_length, arr):
- self._sliceloop(start, stop, step, arr, self)
-
def __del__(self):
lltype.free(self.storage, flavor='raw', track_allocation=False)
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -658,6 +658,12 @@
assert a[0][1][1] == 13
assert a[1][2][1] == 15
+ def test_setitem_slice(self):
+ import numpy
+ a = numpy.zeros((3, 4))
+ a[1] = [1, 2, 3, 4]
+ assert a[1, 2] == 3
+
class AppTestSupport(object):
def setup_class(cls):
import struct
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit