Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r69345:3b0fac02a353
Date: 2014-02-23 22:49 -0500
http://bitbucket.org/pypy/pypy/changeset/3b0fac02a353/

Log:    fix changing field names on record dtypes

diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -232,21 +232,23 @@
                           space.newtuple([subdtype, space.wrap(offset)]))
         return w_d
 
-    def descr_set_fields(self, space, w_fields):
+    def descr_set_fields(self, space, w_fieldnames, w_fields):
         if w_fields == space.w_None:
             self.fields = None
         else:
+            self.fieldnames = []
             self.fields = {}
             size = 0
-            for key in space.listview(w_fields):
-                value = space.getitem(w_fields, key)
+            for w_name in space.fixedview(w_fieldnames):
+                name = space.str_w(w_name)
+                value = space.getitem(w_fields, w_name)
 
                 dtype = space.getitem(value, space.wrap(0))
                 assert isinstance(dtype, W_Dtype)
+                offset = space.int_w(space.getitem(value, space.wrap(1)))
 
-                offset = space.int_w(space.getitem(value, space.wrap(1)))
-                self.fields[space.str_w(key)] = offset, dtype
-
+                self.fieldnames.append(name)
+                self.fields[name] = offset, dtype
                 size += dtype.get_size()
             self.itemtype = types.RecordType()
             self.size = size
@@ -257,20 +259,27 @@
         return space.newtuple([space.wrap(name) for name in self.fieldnames])
 
     def descr_set_names(self, space, w_names):
+        if len(self.fieldnames) == 0:
+            raise oefmt(space.w_ValueError, "there are no fields defined")
+        if not space.issequence_w(w_names) or \
+                space.len_w(w_names) != len(self.fieldnames):
+            raise oefmt(space.w_ValueError,
+                        "must replace all names at once "
+                        "with a sequence of length %d",
+                        len(self.fieldnames))
         fieldnames = []
-        if w_names != space.w_None:
-            iter = space.iter(w_names)
-            while True:
-                try:
-                    name = space.str_w(space.next(iter))
-                except OperationError, e:
-                    if not e.match(space, space.w_StopIteration):
-                        raise
-                    break
-                if name in fieldnames:
-                    raise OperationError(space.w_ValueError, space.wrap(
-                        "Duplicate field names given."))
-                fieldnames.append(name)
+        for w_name in space.fixedview(w_names):
+            if not space.isinstance_w(w_name, space.w_str):
+                raise oefmt(space.w_ValueError,
+                            "item #%d of names is of type %T and not string",
+                            len(fieldnames), w_name)
+            fieldnames.append(space.str_w(w_name))
+        fields = {}
+        for i in range(len(self.fieldnames)):
+            if fieldnames[i] in fields:
+                raise oefmt(space.w_ValueError, "Duplicate field names given.")
+            fields[fieldnames[i]] = self.fields[self.fieldnames[i]]
+        self.fields = fields
         self.fieldnames = fieldnames
 
     def descr_del_names(self, space):
@@ -353,11 +362,9 @@
             endian = NPY.NATIVE
         self.byteorder = endian
 
-        fieldnames = space.getitem(w_data, space.wrap(3))
-        self.descr_set_names(space, fieldnames)
-
-        fields = space.getitem(w_data, space.wrap(4))
-        self.descr_set_fields(space, fields)
+        w_fieldnames = space.getitem(w_data, space.wrap(3))
+        w_fields = space.getitem(w_data, space.wrap(4))
+        self.descr_set_fields(space, w_fieldnames, w_fields)
 
     @unwrap_spec(new_order=str)
     def descr_newbyteorder(self, space, new_order=NPY.SWAP):
diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -48,6 +48,10 @@
         assert dtype('bool') is d
         assert dtype('|b1') is d
         assert repr(type(d)) == "<type 'numpy.dtype'>"
+        exc = raises(ValueError, "d.names = []")
+        assert exc.value[0] == "there are no fields defined"
+        exc = raises(ValueError, "d.names = None")
+        assert exc.value[0] == "there are no fields defined"
 
         assert dtype('int8').num == 1
         assert dtype('int8').name == 'int8'
@@ -1006,21 +1010,34 @@
         from numpypy import dtype, void
 
         raises(ValueError, "dtype([('x', int), ('x', float)])")
-        d = dtype([("x", "int32"), ("y", "int32"), ("z", "int32"), ("value", float)])
-        assert d.fields['x'] == (dtype('int32'), 0)
-        assert d.fields['value'] == (dtype(float), 12)
-        assert d['x'] == dtype('int32')
-        assert d.name == "void160"
+        d = dtype([("x", "<i4"), ("y", "<f4"), ("z", "<u2"), ("v", "<f8")])
+        assert d.fields['x'] == (dtype('<i4'), 0)
+        assert d.fields['v'] == (dtype('<f8'), 10)
+        assert d['x'] == dtype('<i4')
+        assert d.name == "void144"
         assert d.num == 20
-        assert d.itemsize == 20
+        assert d.itemsize == 18
         assert d.kind == 'V'
         assert d.base == d
         assert d.type is void
         assert d.char == 'V'
-        assert d.names == ("x", "y", "z", "value")
-        d.names = ('a', '', 'c', 'd')
-        assert d.names == ('a', '', 'c', 'd')
-        d.names = ('a', 'b', 'c', 'd')
+        exc = raises(ValueError, "d.names = None")
+        assert exc.value[0] == 'must replace all names at once with a sequence of length 4'
+        exc = raises(ValueError, "d.names = (a for a in 'xyzv')")
+        assert exc.value[0] == 'must replace all names at once with a sequence of length 4'
+        exc = raises(ValueError, "d.names = ('a', 'b', 'c', 4)")
+        assert exc.value[0] == 'item #3 of names is of type int and not string'
+        exc = raises(ValueError, "d.names = ('a', 'b', 'c', u'd')")
+        assert exc.value[0] == 'item #3 of names is of type unicode and not string'
+        assert d.names == ("x", "y", "z", "v")
+        d.names = ('x', '', 'v', 'z')
+        assert d.names == ('x', '', 'v', 'z')
+        assert d.fields['v'] == (dtype('<u2'), 8)
+        assert d.fields['z'] == (dtype('<f8'), 10)
+        assert [a[0] for a in d.descr] == ['x', '', 'v', 'z']
+        exc = raises(ValueError, "d.names = ('a', 'b', 'c')")
+        assert exc.value[0] == 'must replace all names at once with a sequence of length 4'
+        d.names = ['a', 'b', 'c', 'd']
         assert d.names == ('a', 'b', 'c', 'd')
         exc = raises(ValueError, "d.names = ('a', 'b', 'c', 'c')")
         assert exc.value[0] == "Duplicate field names given."
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to