This is an automated email from the ASF dual-hosted git repository.
alenka pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 76f987eb94 GH-35623: [C++][Python] FixedShapeTensorType.ToString()
should print the type's parameters (#36496)
76f987eb94 is described below
commit 76f987eb94d2832bc959de455847a03387cab41a
Author: Alenka Frim <[email protected]>
AuthorDate: Mon Oct 9 05:50:22 2023 +0200
GH-35623: [C++][Python] FixedShapeTensorType.ToString() should print the
type's parameters (#36496)
### Rationale for this change
The string representation of two different `FixedShapeTensorType` objects
is currently the same: `extension<arrow.fixed_shape_tensor>`.
### What changes are included in this PR?
Override the generic `ToString()` method for `FixedShapeTensorType`. The proposed
string representation of the type is as follows (`permutation=[*]` and `dim_names=[*]`
are appended only when they are set):
```
extension<arrow.fixed_shape_tensor[value_type=*, shape=[*]]>
```
### Are these changes tested?
Yes, in Python and in C++.
### Are there any user-facing changes?
No.
* Closes: #35623
Lead-authored-by: AlenkaF <[email protected]>
Co-authored-by: Alenka Frim <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: AlenkaF <[email protected]>
---
cpp/src/arrow/extension/fixed_shape_tensor.cc | 18 ++++++++++++++
cpp/src/arrow/extension/fixed_shape_tensor.h | 1 +
cpp/src/arrow/extension/fixed_shape_tensor_test.cc | 28 ++++++++++++++++++++++
cpp/src/arrow/util/print.h | 26 ++++++++++++++++++++
docs/source/python/extending_types.rst | 4 ++--
python/pyarrow/tests/test_extension_type.py | 19 +++++++++++++++
python/pyarrow/types.pxi | 6 ++---
7 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index e4195ea9e6..af8305a025 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -26,7 +26,9 @@
#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
+#include "arrow/util/print.h"
#include "arrow/util/sort.h"
+#include "arrow/util/string.h"
#include <rapidjson/document.h>
#include <rapidjson/writer.h>
@@ -104,6 +106,22 @@ bool FixedShapeTensorType::ExtensionEquals(const
ExtensionType& other) const {
permutation_equivalent;
}
+// Renders the type together with its parameters, e.g.
+//   extension<arrow.fixed_shape_tensor[value_type=int32, shape=[3,4,7], permutation=[1,0,2]]>
+// The permutation and dim_names fields are only emitted when they are non-empty.
+std::string FixedShapeTensorType::ToString() const {
+  std::ostringstream out;
+  out << "extension<" << extension_name();
+  out << "[value_type=" << value_type_->ToString();
+  out << ", shape=" << ::arrow::internal::PrintVector{shape_, ","};
+  if (!permutation_.empty()) {
+    out << ", permutation=" << ::arrow::internal::PrintVector{permutation_, ","};
+  }
+  if (!dim_names_.empty()) {
+    // Dimension names are plain strings, so they are joined rather than ToChars'd.
+    out << ", dim_names=[" << ::arrow::internal::JoinStrings(dim_names_, ",") << "]";
+  }
+  // Close the parameter list and the extension<...> wrapper.
+  out << "]>";
+  return out.str();
+}
+
std::string FixedShapeTensorType::Serialize() const {
rj::Document document;
document.SetObject();
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h
b/cpp/src/arrow/extension/fixed_shape_tensor.h
index 93837f1300..fcfb1ebbab 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.h
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -61,6 +61,7 @@ class ARROW_EXPORT FixedShapeTensorType : public
ExtensionType {
dim_names_(dim_names) {}
std::string extension_name() const override { return
"arrow.fixed_shape_tensor"; }
+ /// \brief Return a string that includes the type parameters, e.g.
+ /// "extension<arrow.fixed_shape_tensor[value_type=int32, shape=[2,2]]>".
+ /// ", permutation=[...]" and ", dim_names=[...]" are appended when non-empty.
+ std::string ToString() const override;
/// Number of dimensions of tensor elements
size_t ndim() { return shape_.size(); }
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
index c3c97bc6e5..b8be1edc49 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
@@ -434,4 +434,32 @@ TEST_F(TestExtensionType, ComputeStrides) {
ASSERT_EQ(ext_type_7->Serialize(),
R"({"shape":[3,4,7],"permutation":[2,0,1]})");
}
+// ToString() must expose value_type and shape, plus permutation and dim_names
+// (in that order) whenever they are set.
+TEST_F(TestExtensionType, ToString) {
+  auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int16(), {3, 4, 7}));
+  auto ext_type_2 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int32(), {3, 4, 7}, {1, 0, 2}));
+  auto ext_type_3 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"C", "H", "W"}));
+  // Both optional fields set: checks that permutation precedes dim_names.
+  auto ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+      fixed_shape_tensor(int8(), {3, 4, 7}, {2, 0, 1}, {"C", "H", "W"}));
+
+  std::string result_1 = ext_type_1->ToString();
+  std::string expected_1 =
+      "extension<arrow.fixed_shape_tensor[value_type=int16, shape=[3,4,7]]>";
+  ASSERT_EQ(expected_1, result_1);
+
+  std::string result_2 = ext_type_2->ToString();
+  std::string expected_2 =
+      "extension<arrow.fixed_shape_tensor[value_type=int32, shape=[3,4,7], "
+      "permutation=[1,0,2]]>";
+  ASSERT_EQ(expected_2, result_2);
+
+  std::string result_3 = ext_type_3->ToString();
+  std::string expected_3 =
+      "extension<arrow.fixed_shape_tensor[value_type=int64, shape=[3,4,7], "
+      "dim_names=[C,H,W]]>";
+  ASSERT_EQ(expected_3, result_3);
+
+  std::string result_4 = ext_type_4->ToString();
+  std::string expected_4 =
+      "extension<arrow.fixed_shape_tensor[value_type=int8, shape=[3,4,7], "
+      "permutation=[2,0,1], dim_names=[C,H,W]]>";
+  ASSERT_EQ(expected_4, result_4);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/print.h b/cpp/src/arrow/util/print.h
index d11aa443a9..82cea473c5 100644
--- a/cpp/src/arrow/util/print.h
+++ b/cpp/src/arrow/util/print.h
@@ -18,6 +18,9 @@
#pragma once
#include <tuple>
+// Provides arrow::internal::ToChars, used by PrintVector below. No global
+// using-declaration is needed: PrintVector lives in arrow::internal, where
+// unqualified lookup already finds ToChars, and a using at global scope in a
+// public header would leak the name into every includer.
+#include "arrow/util/string.h"
namespace arrow {
namespace internal {
@@ -47,5 +50,28 @@ void PrintTuple(OStream* os, const std::tuple<Args&...>&
tup) {
detail::TuplePrinter<OStream, std::tuple<Args&...>,
sizeof...(Args)>::Print(os, tup);
}
+/// Stream adaptor that writes a range as "[e0<sep>e1<sep>...]".
+/// Example: `os << PrintVector{shape, ","}` writes "[3,4,7]".
+template <typename Range, typename Separator>
+struct PrintVector {
+  // Held by reference only: the range and separator must outlive this object
+  // (in practice it is constructed and streamed within one expression).
+  const Range& range_;
+  const Separator& separator_;
+
+  template <typename Os>  // template to dodge inclusion of <ostream>
+  friend Os& operator<<(Os& os, PrintVector printer) {
+    os << "[";
+    bool need_separator = false;
+    for (const auto& item : printer.range_) {
+      if (need_separator) {
+        os << printer.separator_;
+      }
+      need_separator = true;
+      os << ToChars(item);  // ToChars keeps the output independent of the locale
+    }
+    os << "]";
+    return os;
+  }
+};
+template <typename Range, typename Separator>
+PrintVector(const Range&, const Separator&) -> PrintVector<Range, Separator>;
} // namespace internal
} // namespace arrow
diff --git a/docs/source/python/extending_types.rst
b/docs/source/python/extending_types.rst
index 53ce70e13b..87f04f37dc 100644
--- a/docs/source/python/extending_types.rst
+++ b/docs/source/python/extending_types.rst
@@ -419,8 +419,8 @@ Extension arrays can be used as columns in
``pyarrow.Table`` or
f0: int8
f1: string
f2: bool
- tensors_int: extension<arrow.fixed_size_tensor>
- tensors_float: extension<arrow.fixed_size_tensor>
+ tensors_int: extension<arrow.fixed_shape_tensor[value_type=int32,
shape=[2,2]]>
+ tensors_float: extension<arrow.fixed_shape_tensor[value_type=float,
shape=[2,2]]>
----
f0: [[1,2,3]]
f1: [["foo","bar",null]]
diff --git a/python/pyarrow/tests/test_extension_type.py
b/python/pyarrow/tests/test_extension_type.py
index 1eb7d5fa76..ce575d984e 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1351,3 +1351,22 @@ def test_tensor_type_is_picklable(pickle_module):
result = pickle_module.loads(pickle_module.dumps(expected_arr))
assert result == expected_arr
+
+
[email protected](("tensor_type", "text"), [
+ (
+ pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
+ 'fixed_shape_tensor[value_type=int8, shape=[2,2,3]]'
+ ),
+ (
+ pa.fixed_shape_tensor(pa.int32(), [2, 2, 3], permutation=[0, 2, 1]),
+ 'fixed_shape_tensor[value_type=int32, shape=[2,2,3],
permutation=[0,2,1]]'
+ ),
+ (
+ pa.fixed_shape_tensor(pa.int64(), [2, 2, 3], dim_names=['C', 'H',
'W']),
+ 'fixed_shape_tensor[value_type=int64, shape=[2,2,3],
dim_names=[C,H,W]]'
+ )
+])
+def test_tensor_type_str(tensor_type, text):
+ tensor_type_str = tensor_type.__str__()
+ assert text in tensor_type_str
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 9f8b347d56..bd34726adb 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1558,7 +1558,7 @@ cdef class FixedShapeTensorType(BaseExtensionType):
>>> import pyarrow as pa
>>> pa.fixed_shape_tensor(pa.int32(), [2, 2])
- FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
+ FixedShapeTensorType(extension<arrow.fixed_shape_tensor[value_type=int32,
shape=[2,2]]>)
Create an instance of fixed shape tensor extension type with
permutation:
@@ -4746,7 +4746,7 @@ def fixed_shape_tensor(DataType value_type, shape,
dim_names=None, permutation=N
>>> import pyarrow as pa
>>> tensor_type = pa.fixed_shape_tensor(pa.int32(), [2, 2])
>>> tensor_type
- FixedShapeTensorType(extension<arrow.fixed_shape_tensor>)
+ FixedShapeTensorType(extension<arrow.fixed_shape_tensor[value_type=int32,
shape=[2,2]]>)
Inspect the data type:
@@ -4762,7 +4762,7 @@ def fixed_shape_tensor(DataType value_type, shape,
dim_names=None, permutation=N
>>> tensor = pa.ExtensionArray.from_storage(tensor_type, storage)
>>> pa.table([tensor], names=["tensor_array"])
pyarrow.Table
- tensor_array: extension<arrow.fixed_shape_tensor>
+ tensor_array: extension<arrow.fixed_shape_tensor[value_type=int32,
shape=[2,2]]>
----
tensor_array: [[[1,2,3,4],[10,20,30,40],[100,200,300,400]]]