rok commented on code in PR #38008:
URL: https://github.com/apache/arrow/pull/38008#discussion_r1427367604


##########
python/pyarrow/types.pxi:
##########
@@ -1591,6 +1591,114 @@ cdef class ExtensionType(BaseExtensionType):
         return ExtensionScalar
 
 
+cdef class VariableShapeTensorType(BaseExtensionType):
+    """
+    Concrete class for variable shape tensor extension type.
+
+    Examples
+    --------
+    Create an instance of variable shape tensor extension type:
+
+    >>> import pyarrow as pa
+    >>> pa.variable_shape_tensor(pa.int32(), 2)
+    VariableShapeTensorType(extension<arrow.variable_shape_tensor>)
+
+    Create an instance of variable shape tensor extension type with
+    permutation:
+
+    >>> tensor_type = pa.variable_shape_tensor(pa.int8(), 3,
+    ...                                     permutation=[0, 2, 1])
+    >>> tensor_type.permutation
+    [0, 2, 1]
+    """
+
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
+        BaseExtensionType.init(self, type)
+        self.tensor_ext_type = <const CVariableShapeTensorType*> type.get()
+
+    @property
+    def value_type(self):
+        """
+        Data type of an individual tensor.
+        """
+        return pyarrow_wrap_data_type(self.tensor_ext_type.value_type())
+
+    @property
+    def ndim(self):
+        """
+        Number of dimensions of the tensors.
+        """
+        return self.tensor_ext_type.ndim()
+
+    @property
+    def dim_names(self):
+        """
+        Explicit names of the dimensions.
+        """
+        list_of_bytes = self.tensor_ext_type.dim_names()
+        if len(list_of_bytes) != 0:
+            return [frombytes(x) for x in list_of_bytes]
+        else:
+            return None
+
+    @property
+    def permutation(self):
+        """
+        Indices of the dimensions ordering.
+        """
+        indices = self.tensor_ext_type.permutation()
+        if len(indices) != 0:
+            return indices
+        else:
+            return None
+
+    @property
+    def uniform_shape(self):
+        """
+        Shape over dimensions that are guaranteed to be constant.
+        """
+        cdef:
+            vector[optional[int64_t]] c_uniform_shape = self.tensor_ext_type.uniform_shape()
+            length = c_uniform_shape.size()
+
+        if length == 0:
+            return None
+
+        uniform_shape = []
+        for i in range(length):
+            if c_uniform_shape[i].has_value():
+                uniform_shape.append(c_uniform_shape[i].value())
+            else:
+                uniform_shape.append(None)
+
+        return uniform_shape
+
+    def __arrow_ext_serialize__(self):
+        """
+        Serialized representation of metadata to reconstruct the type object.
+        """
+        return self.tensor_ext_type.Serialize()
+
+    @classmethod
+    def __arrow_ext_deserialize__(self, storage_type, serialized):
+        """
+        Return an VariableShapeTensor type instance from the storage type and serialized
+        metadata.
+        """
+        return self.tensor_ext_type.Deserialize(storage_type, serialized)
+
+    def __arrow_ext_class__(self):
+        return VariableShapeTensorArray
+
+    def __reduce__(self):
+        return variable_shape_tensor, (self.value_type, self.ndim,
+                                       self.dim_names, self.permutation,
+                                       self.uniform_shape)

Review Comment:
   Naively changing:
   ```diff
   --- a/python/pyarrow/includes/libarrow.pxd
   +++ b/python/pyarrow/includes/libarrow.pxd
   @@ -2672,6 +2672,11 @@ cdef extern from "arrow/extension/variable_shape_tensor.h" namespace "arrow::ext
            const vector[c_string] dim_names()
            const vector[optional[int64_t]] uniform_shape()
    
   +        CResult[shared_ptr[CBuffer]] Serialize() const
   +
   +        @staticmethod
   +        CResult[unique_ptr[CExtensionType]] Deserialize(
   +            const c_string& type_name, const CBuffer& buffer)
    
   diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
   index 5e2200ceb..e7baeb889 100644
   --- a/python/pyarrow/types.pxi
   +++ b/python/pyarrow/types.pxi
   @@ -1677,13 +1677,11 @@ cdef class VariableShapeTensorType(BaseExtensionType):
            return VariableShapeTensorArray
    
   +     def __arrow_ext_serialize__(self):
   +        return self.Serialize()
   +
   +     @classmethod
   +    def __arrow_ext_deserialize__(cls, storage_type, serialized):
   +        return cls.Deserialize(storage_type, serialized)
   ```
   
   Fails with:
   ```pytest
   _____________________ test_variable_shape_tensor_type_is_picklable[builtin_pickle] ______________________
   
   >   alias = _type_aliases[name]
   E   KeyError: 'extension<arrow.variable_shape_tensor[value_type=int32, ndim=2]>'
   
   pyarrow/types.pxi:5221: KeyError
   
   During handling of the above exception, another exception occurred:
   
   pickle_module = <module 'pickle' from '/usr/lib/python3.10/pickle.py'>
   
       def test_variable_shape_tensor_type_is_picklable(pickle_module):
           expected_type = pa.variable_shape_tensor(pa.int32(), 2)
   >       result = pickle_module.loads(pickle_module.dumps(expected_type))
   
   pyarrow/tests/test_extension_type.py:1667: 
   _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _
   
   >   raise ValueError('No type alias for {0}'.format(name))
   E   ValueError: No type alias for extension<arrow.variable_shape_tensor[value_type=int32, ndim=2]>
   
   pyarrow/types.pxi:5223: ValueError
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to