jorisvandenbossche commented on code in PR #33948:
URL: https://github.com/apache/arrow/pull/33948#discussion_r1093104515


##########
python/pyarrow/tests/test_extension_type.py:
##########
@@ -1079,3 +1082,279 @@ def test_array_constructor_from_pandas():
         pd.Series([1, 2, 3], dtype="category"), type=IntegerType()
     )
     assert result.equals(expected)
+
+
+class TensorType(pa.ExtensionType):
+    """
+    Canonical extension type class for fixed shape tensors.
+
+    Parameters
+    ----------
+    value_type : DataType or Field
+        The data type of an individual tensor
+    shape : tuple
+        shape of the tensors
+    is_row_major : bool
+        boolean indicating the order of elements
+        in memory
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> tensor_type = TensorType(pa.int32(), (2, 2), 'C')
+    >>> tensor_type
+    TensorType(FixedSizeListType(fixed_size_list<item: int32>[4]))
+    >>> pa.register_extension_type(tensor_type)
+    """
+
+    def __init__(self, value_type, shape, is_row_major):
+        self._value_type = value_type
+        self._shape = shape
+        self._is_row_major = is_row_major
+        size = math.prod(shape)
+        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
+                                  'arrow.fixed_size_tensor')
+
+    @property
+    def dtype(self):
+        """
+        Data type of an individual tensor.
+        """
+        return self._value_type
+
+    @property
+    def shape(self):
+        """
+        Shape of the tensors.
+        """
+        return self._shape
+
+    @property
+    def is_row_major(self):
+        """
+        Boolean indicating the order of elements in memory.
+        """
+        return self._is_row_major
+
+    def __arrow_ext_serialize__(self):
+        metadata = {"shape": str(self._shape),
+                    "is_row_major": self._is_row_major}
+        return json.dumps(metadata).encode()
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        # return an instance of this subclass given the serialized
+        # metadata.
+        assert serialized.decode().startswith('{"shape":')
+        metadata = json.loads(serialized.decode())
+        shape = ast.literal_eval(metadata['shape'])
+        order = metadata["is_row_major"]
+
+        return TensorType(storage_type.value_type, shape, order)
+
+    def __arrow_ext_class__(self):
+        return TensorArray
+
+
+class TensorArray(pa.ExtensionArray):
+    """
+    Canonical extension array class for fixed shape tensors.
+
+    Examples
+    --------
+    Define and register extension type for tensor array
+
+    >>> import pyarrow as pa
+    >>> tensor_type = TensorType(pa.int32(), (2, 2), 'C')
+    >>> pa.register_extension_type(tensor_type)
+
+    Create an extension array
+
+    >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+    >>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
+    >>> pa.ExtensionArray.from_storage(tensor_type, storage)
+    <__main__.TensorArray object at 0x1491a5a00>
+    [
+      [
+        1,
+        2,
+        3,
+        4
+      ],
+      [
+        10,
+        20,
+        30,
+        40
+      ],
+      [
+        100,
+        200,
+        300,
+        400
+      ]
+    ]
+    """
+
+    def to_numpy_tensor_list(self):

Review Comment:
   I meant to change what the method does as well: it should not return an 
"array of arrays", but a single ndarray with one more dimension than the shape 
of the individual tensors.
   
   (So for a TensorArray of length 10 with a tensor shape of (2, 3), return a 
single array of shape (10, 2, 3).)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to