jorisvandenbossche commented on code in PR #34883:
URL: https://github.com/apache/arrow/pull/34883#discussion_r1158536564
##########
python/pyarrow/tests/test_extension_type.py:
##########
@@ -1127,3 +1127,89 @@ def test_cpp_extension_in_python(tmpdir):
reconstructed_array = batch.column(0)
assert reconstructed_array.type == uuid_type
assert reconstructed_array == array
+
+
+def test_tensor_type():
+ tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
+ assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
+ assert tensor_type.storage_type == pa.list_(pa.int8(), 6)
+ assert tensor_type.shape == [2, 3]
+ assert tensor_type.dim_names is None
+ assert tensor_type.permutation is None
+
+ tensor_type = pa.fixed_shape_tensor(pa.float64(), [2, 2, 3],
+ permutation=[0, 2, 1])
+ assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
+ assert tensor_type.storage_type == pa.list_(pa.float64(), 12)
+ assert tensor_type.shape == [2, 2, 3]
+ assert tensor_type.dim_names is None
+ assert tensor_type.permutation == [0, 2, 1]
+
+ tensor_type = pa.fixed_shape_tensor(pa.bool_(), [2, 2, 3],
+ dim_names=['C', 'H', 'W'])
+ assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
+ assert tensor_type.storage_type == pa.list_(pa.bool_(), 12)
+ assert tensor_type.shape == [2, 2, 3]
+ assert tensor_type.dim_names == ['C', 'H', 'W']
+ assert tensor_type.permutation is None
+
+
+@pytest.mark.parametrize("numpy_order", ('C', 'F'))
+def test_tensor_class_methods(numpy_order):
+ tensor_type = pa.fixed_shape_tensor(pa.float32(), [2, 3])
+ storage = pa.array([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]],
+ pa.list_(pa.float32(), 6))
+ arr = pa.ExtensionArray.from_storage(tensor_type, storage)
+ expected = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
+
+ result = arr.to_numpy_ndarray()
+ np.testing.assert_array_equal(result, expected)
+
+ arr = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
+ dtype=np.float32, order=numpy_order)
+ tensor_array_from_numpy = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
+ assert isinstance(tensor_array_from_numpy.type, pa.FixedShapeTensorType)
+ assert tensor_array_from_numpy.type.value_type == pa.float32()
+ assert tensor_array_from_numpy.type.shape == [2, 3]
Review Comment:
Hmm, OK, so checking the docstring of `ndarray.flatten()` (which is
currently used under the hood), it actually ensures that C-contiguous
data is returned, and so will make a copy if needed (if the original array
is not C-contiguous). So this will indeed work correctly; it will just not be
zero-copy in that case (while users might expect it to be).
If we keep it, we should at least explicitly document this (that the method
is only zero-copy if the passed data is row-major). Alternatively, in the short
term we could allow only row-major data to avoid surprises (and we can always
expand it later).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]