jorisvandenbossche commented on code in PR #33948:
URL: https://github.com/apache/arrow/pull/33948#discussion_r1107104715


##########
python/pyarrow/tests/test_extension_type.py:
##########
@@ -1079,3 +1082,282 @@ def test_array_constructor_from_pandas():
         pd.Series([1, 2, 3], dtype="category"), type=IntegerType()
     )
     assert result.equals(expected)
+
+
+class FixedShapeTensorType(pa.ExtensionType):
+    """
+    Canonical extension type class for fixed shape tensors.
+
+    Parameters
+    ----------
+    value_type : DataType or Field
+        The data type of an individual tensor
+    shape : tuple
+        Shape of the tensors
+    dim_names : tuple, default: None
+        Explicit names of the dimensions.
+    permutation : tuple, default: None
+        Indices of the dimensions ordering.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2))
+    >>> tensor_type
+    FixedShapeTensorType(FixedSizeListType(fixed_size_list<item: int32>[4]))
+    >>> pa.register_extension_type(tensor_type)
+    """
+
+    def __init__(self, value_type, shape, dim_names=None, permutation=None):
+        self._value_type = value_type
+        self._shape = shape
+        size = math.prod(shape)
+        self._dim_names = dim_names
+        self._permutation = permutation
+        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
+                                  'arrow.fixed_size_tensor')
+
+    @property
+    def value_type(self):
+        """
+        Data type of an individual tensor.
+        """
+        return self._value_type
+
+    @property
+    def shape(self):
+        """
+        Shape of the tensors.
+        """
+        return self._shape
+
+    @property
+    def dim_names(self):
+        """
+        Explicit names of the dimensions.
+        """
+        return self._dim_names
+
+    @property
+    def permutation(self):
+        """
+        Indices of the dimensions ordering.
+        """
+        return self._permutation
+
+    def __arrow_ext_serialize__(self):
+        metadata = {"shape": str(self._shape),
+                    "dim_names": str(self._dim_names),
+                    "permutation": str(self._permutation)}

Review Comment:
   You will need to protect this against `dim_names` or `permutation` being 
None. We should only include them in the JSON metadata if they are _not_ None, 
because serializing `"dim_names": "None"` violates the specification (if 
present, it should be an array of the same length as `shape`).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to