This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d623567  ARROW-2341: [Python] Improve pa.union() mode argument behaviour
d623567 is described below

commit d623567da78a11572b3992b26788f76c5c241434
Author: Antoine Pitrou <[email protected]>
AuthorDate: Thu Mar 22 13:49:13 2018 -0400

    ARROW-2341: [Python] Improve pa.union() mode argument behaviour
    
    Also:
    - make UnionType.mode return a string ('sparse' or 'dense')
    - make UnionType indexing return fields, not types (like StructType)
    
    Author: Antoine Pitrou <[email protected]>
    
    Closes #1778 from pitrou/ARROW-2341-python-union and squashes the following commits:
    
    c215f9bd <Antoine Pitrou> ARROW-2341:  Improve pa.union() mode argument behaviour
---
 python/pyarrow/lib.pxd             |  5 ----
 python/pyarrow/scalar.pxi          |  4 ++--
 python/pyarrow/tests/test_types.py | 29 ++++++++++++++++++----
 python/pyarrow/types.pxi           | 49 ++++++++++++++++++++++++++++----------
 4 files changed, 63 insertions(+), 24 deletions(-)

diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index e4d574f..be103b3 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -56,11 +56,6 @@ cdef class DictionaryType(DataType):
         const CDictionaryType* dict_type
 
 
-cdef class UnionType(DataType):
-    cdef:
-        list child_types
-
-
 cdef class TimestampType(DataType):
     cdef:
         const CTimestampType* ts_type
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index 2692ace..a0f8480 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -334,9 +334,9 @@ cdef class UnionValue(ArrayValue):
         cdef int8_t type_id = self.ap.raw_type_ids()[i]
         cdef shared_ptr[CArray] child = self.ap.child(type_id)
         if self.ap.mode() == _UnionMode_SPARSE:
-            return box_scalar(self.type[type_id], child, i)
+            return box_scalar(self.type[type_id].type, child, i)
         else:
-            return box_scalar(self.type[type_id], child,
+            return box_scalar(self.type[type_id].type, child,
                               self.ap.value_offset(i))
 
     def as_py(self):
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index ad683e9..6459496 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -88,10 +88,11 @@ def test_is_nested_or_struct():
 
 
 def test_is_union():
-    assert types.is_union(pa.union([pa.field('a', pa.int32()),
-                                    pa.field('b', pa.int8()),
-                                    pa.field('c', pa.string())],
-                                   pa.lib.UnionMode_SPARSE))
+    for mode in [pa.lib.UnionMode_SPARSE, pa.lib.UnionMode_DENSE]:
+        assert types.is_union(pa.union([pa.field('a', pa.int32()),
+                                        pa.field('b', pa.int8()),
+                                        pa.field('c', pa.string())],
+                                       mode=mode))
     assert not types.is_union(pa.list_(pa.int32()))
 
 
@@ -141,6 +142,26 @@ def test_timestamp_type():
     assert isinstance(pa.timestamp('ns'), pa.TimestampType)
 
 
+def test_union_type():
+    def check_fields(ty, fields):
+        assert ty.num_children == len(fields)
+        assert [ty[i] for i in range(ty.num_children)] == fields
+
+    fields = [pa.field('x', pa.list_(pa.int32())),
+              pa.field('y', pa.binary())]
+    for mode in ('sparse', pa.lib.UnionMode_SPARSE):
+        ty = pa.union(fields, mode=mode)
+        assert ty.mode == 'sparse'
+        check_fields(ty, fields)
+    for mode in ('dense', pa.lib.UnionMode_DENSE):
+        ty = pa.union(fields, mode=mode)
+        assert ty.mode == 'dense'
+        check_fields(ty, fields)
+    for mode in ('unknown', 2):
+        with pytest.raises(ValueError, match='Invalid union mode'):
+            pa.union(fields, mode=mode)
+
+
 def test_types_hashable():
     types = [
         pa.null(),
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 5f96290..1294850 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -189,9 +189,6 @@ cdef class UnionType(DataType):
 
     cdef void init(self, const shared_ptr[CDataType]& type):
         DataType.init(self, type)
-        self.child_types = [
-            pyarrow_wrap_data_type(type.get().child(i).get().type())
-            for i in range(self.num_children)]
 
     property num_children:
 
@@ -202,20 +199,25 @@ cdef class UnionType(DataType):
 
         def __get__(self):
             cdef CUnionType* type = <CUnionType*> self.sp_type.get()
-            return type.mode()
+            cdef int mode = type.mode()
+            if mode == _UnionMode_DENSE:
+                return 'dense'
+            if mode == _UnionMode_SPARSE:
+                return 'sparse'
+            assert 0
 
     def __getitem__(self, i):
-        return self.child_types[i]
+        return pyarrow_wrap_field(self.type.child(i))
 
     def __getstate__(self):
-        children = [pyarrow_wrap_field(self.type.child(i))
-                    for i in range(self.num_children)]
+        children = [self[i] for i in range(self.num_children)]
         return children, self.mode
 
     def __setstate__(self, state):
         cdef DataType reconstituted = union(*state)
         self.init(reconstituted.sp_type)
 
+
 cdef class TimestampType(DataType):
 
     cdef void init(self, const shared_ptr[CDataType]& type):
@@ -1145,6 +1147,16 @@ def struct(fields):
 def union(children_fields, mode):
     """
     Create UnionType from children fields.
+
+    Parameters
+    ----------
+    fields : sequence of Field values
+    mode : str
+        'dense' or 'sparse'
+
+    Returns
+    -------
+    type : DataType
     """
     cdef:
         Field child_field
@@ -1153,16 +1165,27 @@ def union(children_fields, mode):
         shared_ptr[CDataType] union_type
         int i
 
+    if isinstance(mode, int):
+        if mode not in (_UnionMode_SPARSE, _UnionMode_DENSE):
+            raise ValueError("Invalid union mode {0!r}".format(mode))
+    else:
+        if mode == 'sparse':
+            mode = _UnionMode_SPARSE
+        elif mode == 'dense':
+            mode = _UnionMode_DENSE
+        else:
+            raise ValueError("Invalid union mode {0!r}".format(mode))
+
     for i, child_field in enumerate(children_fields):
         type_codes.push_back(i)
         c_fields.push_back(child_field.sp_field)
 
-        if mode == UnionMode_SPARSE:
-            union_type.reset(new CUnionType(c_fields, type_codes,
-                                            _UnionMode_SPARSE))
-        else:
-            union_type.reset(new CUnionType(c_fields, type_codes,
-                                            _UnionMode_DENSE))
+    if mode == UnionMode_SPARSE:
+        union_type.reset(new CUnionType(c_fields, type_codes,
+                                        _UnionMode_SPARSE))
+    else:
+        union_type.reset(new CUnionType(c_fields, type_codes,
+                                        _UnionMode_DENSE))
 
     return pyarrow_wrap_data_type(union_type)
 

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to