rok commented on code in PR #38008:
URL: https://github.com/apache/arrow/pull/38008#discussion_r1427141819
##########
python/pyarrow/types.pxi:
##########
@@ -1591,6 +1591,114 @@ cdef class ExtensionType(BaseExtensionType):
return ExtensionScalar
+cdef class VariableShapeTensorType(BaseExtensionType):
+ """
+ Concrete class for variable shape tensor extension type.
+
+ Examples
+ --------
+ Create an instance of variable shape tensor extension type:
+
+ >>> import pyarrow as pa
+ >>> pa.variable_shape_tensor(pa.int32(), 2)
+ VariableShapeTensorType(extension<arrow.variable_shape_tensor>)
+
+ Create an instance of variable shape tensor extension type with
+ permutation:
+
+ >>> tensor_type = pa.variable_shape_tensor(pa.int8(), 3,
+ ... permutation=[0, 2, 1])
+ >>> tensor_type.permutation
+ [0, 2, 1]
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ BaseExtensionType.init(self, type)
+ self.tensor_ext_type = <const CVariableShapeTensorType*> type.get()
+
+ @property
+ def value_type(self):
+ """
+ Data type of an individual tensor.
+ """
+ return pyarrow_wrap_data_type(self.tensor_ext_type.value_type())
+
+ @property
+ def ndim(self):
+ """
+ Number of dimensions of the tensors.
+ """
+ return self.tensor_ext_type.ndim()
+
+ @property
+ def dim_names(self):
+ """
+ Explicit names of the dimensions.
+ """
+ list_of_bytes = self.tensor_ext_type.dim_names()
+ if len(list_of_bytes) != 0:
+ return [frombytes(x) for x in list_of_bytes]
+ else:
+ return None
+
+ @property
+ def permutation(self):
+ """
+ Indices of the dimensions ordering.
+ """
+ indices = self.tensor_ext_type.permutation()
+ if len(indices) != 0:
+ return indices
+ else:
+ return None
+
+ @property
+ def uniform_shape(self):
+ """
+ Shape over dimensions that are guaranteed to be constant.
+ """
+ cdef:
+ vector[optional[int64_t]] c_uniform_shape = self.tensor_ext_type.uniform_shape()
+ length = c_uniform_shape.size()
+
+ if length == 0:
+ return None
+
+ uniform_shape = []
+ for i in range(length):
+ if c_uniform_shape[i].has_value():
+ uniform_shape.append(c_uniform_shape[i].value())
+ else:
+ uniform_shape.append(None)
+
+ return uniform_shape
+
+ def __arrow_ext_serialize__(self):
+ """
+ Serialized representation of metadata to reconstruct the type object.
+ """
+ return self.tensor_ext_type.Serialize()
+
+ @classmethod
+ def __arrow_ext_deserialize__(self, storage_type, serialized):
+ """
+ Return an VariableShapeTensor type instance from the storage type and serialized
+ metadata.
+ """
+ return self.tensor_ext_type.Deserialize(storage_type, serialized)
+
+ def __arrow_ext_class__(self):
+ return VariableShapeTensorArray
+
+ def __reduce__(self):
+ return variable_shape_tensor, (self.value_type, self.ndim,
+ self.dim_names, self.permutation,
+ self.uniform_shape)
Review Comment:
Is this the proposed solution?
https://github.com/apache/arrow/issues/39094#issuecomment-1845619410
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]