jorisvandenbossche commented on code in PR #33948:
URL: https://github.com/apache/arrow/pull/33948#discussion_r1093430820


##########
python/pyarrow/tests/test_extension_type.py:
##########
@@ -1079,3 +1082,279 @@ def test_array_constructor_from_pandas():
         pd.Series([1, 2, 3], dtype="category"), type=IntegerType()
     )
     assert result.equals(expected)
+
+
+class FixedShapeTensorType(pa.ExtensionType):
+    """
+    Canonical extension type class for fixed shape tensors.
+
+    Parameters
+    ----------
+    value_type : DataType or Field
+        The data type of an individual tensor
+    shape : tuple
+        shape of the tensors
+    is_row_major : bool
+        boolean indicating the order of elements
+        in memory
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2), 'C')
+    >>> tensor_type
+    FixedShapeTensorType(FixedSizeListType(fixed_size_list<item: int32>[4]))
+    >>> pa.register_extension_type(tensor_type)
+    """
+
+    def __init__(self, value_type, shape, is_row_major):
+        self._value_type = value_type
+        self._shape = shape
+        self._is_row_major = is_row_major
+        size = math.prod(shape)
+        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
+                                  'arrow.fixed_size_tensor')
+
+    @property
+    def value_type(self):
+        """
+        Data type of an individual tensor.
+        """
+        return self._value_type
+
+    @property
+    def shape(self):
+        """
+        Shape of the tensors.
+        """
+        return self._shape
+
+    @property
+    def is_row_major(self):
+        """
+        Boolean indicating the order of elements in memory.
+        """
+        return self._is_row_major
+
+    def __arrow_ext_serialize__(self):
+        metadata = {"shape": str(self._shape),
+                    "is_row_major": self._is_row_major}
+        return json.dumps(metadata).encode()
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        # return an instance of this subclass given the serialized
+        # metadata.
+        assert serialized.decode().startswith('{"shape":')
+        metadata = json.loads(serialized.decode())
+        shape = ast.literal_eval(metadata['shape'])
+        order = metadata["is_row_major"]
+
+        return FixedShapeTensorType(storage_type.value_type, shape, order)
+
+    def __arrow_ext_class__(self):
+        return FixedShapeTensorArray
+
+
+class FixedShapeTensorArray(pa.ExtensionArray):
+    """
+    Canonical extension array class for fixed shape tensors.
+
+    Examples
+    --------
+    Define and register extension type for tensor array
+
+    >>> import pyarrow as pa
+    >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2), 'C')
+    >>> pa.register_extension_type(tensor_type)
+
+    Create an extension array
+
+    >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+    >>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
+    >>> pa.ExtensionArray.from_storage(tensor_type, storage)
+    <__main__.FixedShapeTensorArray object at ...>
+    [
+      [
+        1,
+        2,
+        3,
+        4
+      ],
+      [
+        10,
+        20,
+        30,
+        40
+      ],
+      [
+        100,
+        200,
+        300,
+        400
+      ]
+    ]
+    """
+
+    def to_numpy_tensor(self):
+        """
+        Convert tensor extension array to a numpy array (with dim+1).
+
+        Examples
+        --------
+
+        """
+        np_flat = np.array([])
+        for tensor in self.storage:

Review Comment:
   It should be possible to do this without iterating over each element. If you 
call `to_numpy()` on the storage array's child values (`arr.values`), you get the 
full column as one flat numpy array, which we can then reshape to the correct shape:
   
   Given a storage array of fixed-size list type:
   
   ```
   In [29]: arr = pa.array([[1, 2], [3, 4]], pa.list_(pa.int64(), 2))
   
   In [30]: arr
   Out[30]: 
   <pyarrow.lib.FixedSizeListArray object at 0x7ffb663b1000>
   [
     [
       1,
       2
     ],
     [
       3,
       4
     ]
   ]
   
   In [32]: arr.values
   Out[32]: 
   <pyarrow.lib.Int64Array object at 0x7ffb8b0848e0>
   [
     1,
     2,
     3,
     4
   ]
   
   In [33]: arr.values.to_numpy()
   Out[33]: array([1, 2, 3, 4])
   
   In [34]: arr.values.to_numpy().reshape(2, 2)
   Out[34]: 
   array([[1, 2],
          [3, 4]])
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to