This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 548e194 ARROW-4673: [C++] Implement Scalar::Equals and Datum::Equals
548e194 is described below
commit 548e1949d527717d7821a4ab2f09ff7c39882152
Author: François Saint-Jacques <[email protected]>
AuthorDate: Thu Mar 14 20:03:44 2019 -0500
ARROW-4673: [C++] Implement Scalar::Equals and Datum::Equals
Handy for validating kernels.
Author: François Saint-Jacques <[email protected]>
Author: Wes McKinney <[email protected]>
Closes #3875 from fsaintjacques/ARROW-4673-datum-equal and squashes the
following commits:
3fff08785 <Wes McKinney> Add common base class for some primitive scalar, a
little DRY
093e1bd55 <François Saint-Jacques> Fix struct Scalar warning
66cae36d8 <François Saint-Jacques> Fix warnings.
7a7c0d6a1 <François Saint-Jacques> ARROW-4673: Implement Scalar::Equals
and Datum::Equals
---
cpp/src/arrow/compare.cc | 91 +++++++++++++++++++++++++
cpp/src/arrow/compare.h | 6 ++
cpp/src/arrow/compute/kernel.h | 46 +++++++++++++
cpp/src/arrow/compute/kernels/aggregate-test.cc | 4 +-
cpp/src/arrow/scalar-test.cc | 15 ++++
cpp/src/arrow/scalar.cc | 9 ++-
cpp/src/arrow/scalar.h | 32 +++++++--
cpp/src/arrow/testing/gtest_util.cc | 6 ++
cpp/src/arrow/testing/gtest_util.h | 8 +++
cpp/src/arrow/type_fwd.h | 2 +
cpp/src/arrow/util/memory.h | 12 ++++
cpp/src/arrow/visitor.cc | 39 +++++++++++
cpp/src/arrow/visitor.h | 32 +++++++++
cpp/src/arrow/visitor_inline.h | 18 +++++
14 files changed, 308 insertions(+), 12 deletions(-)
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index fcb16b5..aca6094 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -30,6 +30,7 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
+#include "arrow/scalar.h"
#include "arrow/sparse_tensor.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
@@ -38,6 +39,7 @@
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
#include "arrow/visitor_inline.h"
namespace arrow {
@@ -717,6 +719,78 @@ class TypeEqualsVisitor {
bool result_;
};
+// Compares the visited ("left") scalar against the `right` scalar handed to the
+// constructor and records the outcome, readable through result(). Every Visit
+// downcasts `right_` to the visited class with checked_cast, so the caller must
+// first make sure both scalars share the same type (ScalarEquals does this).
+class ScalarEqualsVisitor {
+ public:
+  explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}
+
+  // Null scalars carry no payload: any two of them are equal.
+  Status Visit(const NullScalar& left) {
+    result_ = true;
+    return Status::OK();
+  }
+
+  // Boolean, numeric and temporal scalars (internal::PrimitiveScalar subclasses)
+  // are compared through their `value` member with operator==.
+  template <typename T>
+  typename std::enable_if<std::is_base_of<internal::PrimitiveScalar, T>::value,
+                          Status>::type
+  Visit(const T& left) {
+    const auto& right = checked_cast<const T&>(right_);
+    result_ = left.value == right.value;
+    return Status::OK();
+  }
+
+  // BinaryScalar and its subclasses: compare the held `value` pointers through
+  // SharedPtrEquals (same pointer, or both non-null and `Equals`).
+  template <typename T>
+  typename std::enable_if<std::is_base_of<BinaryScalar, T>::value, Status>::type Visit(
+      const T& left) {
+    const auto& lhs = checked_cast<const BinaryScalar&>(left);
+    const auto& rhs = checked_cast<const BinaryScalar&>(right_);
+    result_ = internal::SharedPtrEquals(lhs.value, rhs.value);
+    return Status::OK();
+  }
+
+  Status Visit(const Decimal128Scalar& left) {
+    const auto& right = checked_cast<const Decimal128Scalar&>(right_);
+    result_ = left.value == right.value;
+    return Status::OK();
+  }
+
+  // List scalars: compare the held `value` pointers through SharedPtrEquals.
+  Status Visit(const ListScalar& left) {
+    const auto& right = checked_cast<const ListScalar&>(right_);
+    result_ = internal::SharedPtrEquals(left.value, right.value);
+    return Status::OK();
+  }
+
+  // Struct scalars are equal when they hold the same number of children and
+  // each pair of children is equal; the scan stops at the first mismatch.
+  Status Visit(const StructScalar& left) {
+    const auto& right = checked_cast<const StructScalar&>(right_);
+    const auto& lhs_children = left.value;
+    const auto& rhs_children = right.value;
+    result_ = lhs_children.size() == rhs_children.size();
+    for (size_t i = 0; result_ && i < lhs_children.size(); ++i) {
+      result_ = internal::SharedPtrEquals(lhs_children[i], rhs_children[i]);
+    }
+    return Status::OK();
+  }
+
+  // Not supported yet: these report NotImplemented and leave result() untouched.
+  Status Visit(const UnionScalar& left) { return Status::NotImplemented("union"); }
+
+  Status Visit(const DictionaryScalar& left) {
+    return Status::NotImplemented("dictionary");
+  }
+
+  Status Visit(const ExtensionScalar& left) {
+    return Status::NotImplemented("extension");
+  }
+
+  bool result() const { return result_; }
+
+ protected:
+  const Scalar& right_;
+  bool result_;
+};
+
} // namespace internal
bool ArrayEquals(const Array& left, const Array& right) {
@@ -915,4 +989,21 @@ bool TypeEquals(const DataType& left, const DataType&
right, bool check_metadata
return are_equal;
}
+// Returns true when `left` and `right` hold equal values. Scalars whose types
+// differ, or whose validity differs, are never equal. Two null scalars of the
+// same type are equal regardless of their payload fields.
+bool ScalarEquals(const Scalar& left, const Scalar& right) {
+  if (&left == &right) {
+    return true;
+  }
+  if (!left.type->Equals(right.type) || left.is_valid != right.is_valid) {
+    return false;
+  }
+  if (!left.is_valid) {
+    // Both are null. The `value` member of a null scalar is unspecified
+    // (e.g. `ScalarType(0, false)` vs `ScalarType(42, false)`, or null
+    // StringScalars wrapping different buffers), so it must not be compared.
+    return true;
+  }
+  internal::ScalarEqualsVisitor visitor(right);
+  auto error = VisitScalarInline(left, &visitor);
+  // Unsupported types (union, dictionary, extension) report NotImplemented.
+  // In release builds the DCHECK is a no-op and the visitor's initial `false`
+  // result is returned.
+  DCHECK_OK(error);
+  return visitor.result();
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index 4bb2de4..a0c24c9 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -30,6 +30,7 @@ class Array;
class DataType;
class Tensor;
class SparseTensor;
+struct Scalar;
/// Returns true if the arrays are exactly equal
bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right);
@@ -56,6 +57,11 @@ bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const
Array& right,
bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
bool check_metadata = true);
+/// \brief Returns true if scalars are equal
+///
+/// Scalars whose types differ, or whose validity (`is_valid`) differs, are
+/// never considered equal.
+/// \param[in] left a Scalar
+/// \param[in] right a Scalar
+/// \return whether `left` and `right` compare equal
+bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right);
+
} // namespace arrow
#endif // ARROW_COMPARE_H
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index 387715d..56d7d68 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -27,6 +27,7 @@
#include "arrow/scalar.h"
#include "arrow/table.h"
#include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
#include "arrow/util/variant.h" // IWYU pragma: export
#include "arrow/util/visibility.h"
@@ -59,6 +60,10 @@ class ARROW_EXPORT OpKernel {
virtual std::shared_ptr<DataType> out_type() const = 0;
};
+struct Datum;
+// Declared ahead of Datum because Datum::Equals calls it when comparing
+// COLLECTION datums; the definition follows the kernel classes in this header.
+static inline bool CollectionEquals(const std::vector<Datum>& left,
+ const std::vector<Datum>& right);
+
/// \class Datum
/// \brief Variant type for various Arrow C++ data structures
struct ARROW_EXPORT Datum {
@@ -153,6 +158,14 @@ struct ARROW_EXPORT Datum {
return util::get<std::shared_ptr<ChunkedArray>>(this->value);
}
+  /// \brief The held RecordBatch; call only when kind() == Datum::RECORD_BATCH
+  std::shared_ptr<RecordBatch> record_batch() const {
+    using RecordBatchPtr = std::shared_ptr<RecordBatch>;
+    return util::get<RecordBatchPtr>(value);
+  }
+
+  /// \brief The held Table; call only when kind() == Datum::TABLE
+  std::shared_ptr<Table> table() const {
+    using TablePtr = std::shared_ptr<Table>;
+    return util::get<TablePtr>(value);
+  }
+
const std::vector<Datum> collection() const {
return util::get<std::vector<Datum>>(this->value);
}
@@ -182,6 +195,29 @@ struct ARROW_EXPORT Datum {
}
return NULLPTR;
}
+
+  /// \brief Deep equality between two datums.
+  ///
+  /// Both datums must hold the same kind. The held objects are then compared
+  /// with their own Equals via SharedPtrEquals, so a null pointer equals only a
+  /// null pointer. Collections are compared element by element.
+  bool Equals(const Datum& other) const {
+    const auto this_kind = kind();
+    if (this_kind != other.kind()) {
+      return false;
+    }
+    switch (this_kind) {
+      case Datum::NONE:
+        return true;
+      case Datum::SCALAR:
+        return internal::SharedPtrEquals(scalar(), other.scalar());
+      case Datum::ARRAY:
+        return internal::SharedPtrEquals(make_array(), other.make_array());
+      case Datum::CHUNKED_ARRAY:
+        return internal::SharedPtrEquals(chunked_array(), other.chunked_array());
+      case Datum::RECORD_BATCH:
+        return internal::SharedPtrEquals(record_batch(), other.record_batch());
+      case Datum::TABLE:
+        return internal::SharedPtrEquals(table(), other.table());
+      case Datum::COLLECTION:
+        return CollectionEquals(collection(), other.collection());
+      default:
+        return false;
+    }
+  }
};
/// \class UnaryKernel
@@ -214,6 +250,16 @@ class ARROW_EXPORT BinaryKernel : public OpKernel {
Datum* out) = 0;
};
+// Element-wise equality of two Datum collections: same length, and every pair
+// at the same position is equal under Datum::Equals.
+static inline bool CollectionEquals(const std::vector<Datum>& left,
+                                    const std::vector<Datum>& right) {
+  if (left.size() != right.size()) {
+    return false;
+  }
+  auto right_it = right.begin();
+  for (const Datum& item : left) {
+    if (!item.Equals(*right_it)) {
+      return false;
+    }
+    ++right_it;
+  }
+  return true;
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/aggregate-test.cc
b/cpp/src/arrow/compute/kernels/aggregate-test.cc
index cbe91a2..fd3b6d9 100644
--- a/cpp/src/arrow/compute/kernels/aggregate-test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate-test.cc
@@ -267,10 +267,10 @@ void ValidateCount(FunctionContext* ctx, const Array&
input, CountPair expected)
Datum result;
ASSERT_OK(Count(ctx, all, input, &result));
- DatumEqual<Int64Type>::EnsureEqual(result, Datum(expected.first));
+ AssertDatumsEqual(result, Datum(expected.first));
ASSERT_OK(Count(ctx, nulls, input, &result));
- DatumEqual<Int64Type>::EnsureEqual(result, Datum(expected.second));
+ AssertDatumsEqual(result, Datum(expected.second));
}
template <typename ArrowType>
diff --git a/cpp/src/arrow/scalar-test.cc b/cpp/src/arrow/scalar-test.cc
index 580f480..67af4fb 100644
--- a/cpp/src/arrow/scalar-test.cc
+++ b/cpp/src/arrow/scalar-test.cc
@@ -76,8 +76,12 @@ TYPED_TEST(TestNumericScalar, Basics) {
ASSERT_TRUE(scalar_val->type->Equals(*expected_type));
T other_value = static_cast<T>(2);
+ auto scalar_other = std::make_shared<ScalarType>(other_value);
+ ASSERT_FALSE(scalar_other->Equals(scalar_val));
+
scalar_val->value = other_value;
ASSERT_EQ(other_value, scalar_val->value);
+ ASSERT_TRUE(scalar_other->Equals(scalar_val));
ScalarType stack_val = ScalarType(0, false);
ASSERT_FALSE(stack_val.is_valid);
@@ -106,6 +110,13 @@ TEST(TestBinaryScalar, Basics) {
ASSERT_TRUE(value2.is_valid);
ASSERT_TRUE(value2.type->Equals(*utf8()));
+ // Same buffer, different type.
+ ASSERT_FALSE(value2.Equals(value));
+
+ StringScalar value3(buf);
+ // Same buffer, same type.
+ ASSERT_TRUE(value2.Equals(value3));
+
StringScalar null_value2(nullptr, false);
ASSERT_FALSE(null_value2.is_valid);
}
@@ -182,6 +193,10 @@ TEST(TestTimestampScalars, Basics) {
ASSERT_TRUE(ts_val1.is_valid);
ASSERT_FALSE(ts_null.is_valid);
ASSERT_TRUE(ts_null.type->Equals(*type1));
+
+ ASSERT_FALSE(ts_val1.Equals(ts_val2));
+ ASSERT_FALSE(ts_val1.Equals(ts_null));
+ ASSERT_FALSE(ts_val2.Equals(ts_null));
}
// TODO test HalfFloatScalar
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index a9d7c5d..9888598 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -21,6 +21,7 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
+#include "arrow/compare.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/decimal.h"
@@ -30,21 +31,23 @@ namespace arrow {
using internal::checked_cast;
+// Member convenience wrapper around the free function ScalarEquals.
+bool Scalar::Equals(const Scalar& other) const {
+  const Scalar& self = *this;
+  return ScalarEquals(self, other);
+}
+
+// Base is internal::PrimitiveScalar, letting ScalarEquals compare `value`
+// directly. The DCHECK requires `type` to be a TIME32 type.
Time32Scalar::Time32Scalar(int32_t value, const std::shared_ptr<DataType>&
type,
bool is_valid)
- : Scalar{type, is_valid}, value(value) {
+ : internal::PrimitiveScalar{type, is_valid}, value(value) {
DCHECK_EQ(Type::TIME32, type->id());
}
+// Base is internal::PrimitiveScalar, letting ScalarEquals compare `value`
+// directly. The DCHECK requires `type` to be a TIME64 type.
Time64Scalar::Time64Scalar(int64_t value, const std::shared_ptr<DataType>&
type,
bool is_valid)
- : Scalar{type, is_valid}, value(value) {
+ : internal::PrimitiveScalar{type, is_valid}, value(value) {
DCHECK_EQ(Type::TIME64, type->id());
}
+// Base is internal::PrimitiveScalar, letting ScalarEquals compare `value`
+// directly. The DCHECK requires `type` to be a TIMESTAMP type.
TimestampScalar::TimestampScalar(int64_t value, const
std::shared_ptr<DataType>& type,
bool is_valid)
- : Scalar{type, is_valid}, value(value) {
+ : internal::PrimitiveScalar{type, is_valid}, value(value) {
DCHECK_EQ(Type::TIMESTAMP, type->id());
}
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index b90b26e..ce714bf 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -48,6 +48,12 @@ struct ARROW_EXPORT Scalar {
/// \brief Whether the value is valid (not null) or not
bool is_valid;
+  /// \brief Value equality; forwards to ScalarEquals (arrow/compare.h)
+  bool Equals(const Scalar& other) const;
+  /// \brief Same as Equals(const Scalar&), except a null `other` is never equal
+  bool Equals(const std::shared_ptr<Scalar>& other) const {
+    return other ? Equals(*other) : false;
+  }
+
protected:
Scalar(const std::shared_ptr<DataType>& type, bool is_valid)
: type(type), is_valid(is_valid) {}
@@ -59,14 +65,22 @@ struct ARROW_EXPORT NullScalar : public Scalar {
NullScalar() : Scalar{null(), false} {}
};
-struct ARROW_EXPORT BooleanScalar : public Scalar {
+namespace internal {
+
+/// \brief Common base for scalars holding a single primitive `value` member
+/// (boolean, numeric and temporal scalars). ScalarEquals dispatches on this
+/// base to compare `value` directly. Scalar's constructors are inherited.
+struct ARROW_EXPORT PrimitiveScalar : public Scalar {
+ using Scalar::Scalar;
+};
+
+} // namespace internal
+
+/// \brief Scalar holding a bool `value`; its type is always boolean().
+struct ARROW_EXPORT BooleanScalar : public internal::PrimitiveScalar {
bool value;
explicit BooleanScalar(bool value, bool is_valid = true)
- : Scalar{boolean(), is_valid}, value(value) {}
+ : internal::PrimitiveScalar{boolean(), is_valid}, value(value) {}
};
template <typename Type>
-struct NumericScalar : public Scalar {
+struct NumericScalar : public internal::PrimitiveScalar {
using T = typename Type::c_type;
T value;
@@ -75,7 +89,7 @@ struct NumericScalar : public Scalar {
protected:
explicit NumericScalar(T value, const std::shared_ptr<DataType>& type, bool
is_valid)
- : Scalar{type, is_valid}, value(value) {}
+ : internal::PrimitiveScalar{type, is_valid}, value(value) {}
};
struct ARROW_EXPORT BinaryScalar : public Scalar {
@@ -109,21 +123,21 @@ class ARROW_EXPORT Date64Scalar : public
NumericScalar<Date64Type> {
using NumericScalar<Date64Type>::NumericScalar;
};
-class ARROW_EXPORT Time32Scalar : public Scalar {
+/// \brief Scalar for TIME32 types, holding an int32_t `value`. The time unit
+/// is presumably carried by `type` (TODO confirm); the constructor DCHECKs
+/// that `type` is TIME32.
+class ARROW_EXPORT Time32Scalar : public internal::PrimitiveScalar {
public:
int32_t value;
Time32Scalar(int32_t value, const std::shared_ptr<DataType>& type,
bool is_valid = true);
};
-class ARROW_EXPORT Time64Scalar : public Scalar {
+/// \brief Scalar for TIME64 types, holding an int64_t `value`. The time unit
+/// is presumably carried by `type` (TODO confirm); the constructor DCHECKs
+/// that `type` is TIME64.
+class ARROW_EXPORT Time64Scalar : public internal::PrimitiveScalar {
public:
int64_t value;
Time64Scalar(int64_t value, const std::shared_ptr<DataType>& type,
bool is_valid = true);
};
-class ARROW_EXPORT TimestampScalar : public Scalar {
+class ARROW_EXPORT TimestampScalar : public internal::PrimitiveScalar {
public:
int64_t value;
TimestampScalar(int64_t value, const std::shared_ptr<DataType>& type,
@@ -149,4 +163,8 @@ struct ARROW_EXPORT StructScalar : public Scalar {
std::vector<std::shared_ptr<Scalar>> value;
};
+// NOTE(review): placeholder scalar classes with no payload yet. Presumably they
+// exist so that visitors generated with ARROW_GENERATE_FOR_ALL_TYPES can name a
+// Scalar class for every type. The equality visitor returns NotImplemented
+// for all three.
+class ARROW_EXPORT UnionScalar : public Scalar {};
+class ARROW_EXPORT DictionaryScalar : public Scalar {};
+class ARROW_EXPORT ExtensionScalar : public Scalar {};
+
} // namespace arrow
diff --git a/cpp/src/arrow/testing/gtest_util.cc
b/cpp/src/arrow/testing/gtest_util.cc
index 1db631b..4811954 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -36,6 +36,7 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
+#include "arrow/compute/kernel.h"
#include "arrow/ipc/json-simple.h"
#include "arrow/pretty_print.h"
#include "arrow/status.h"
@@ -104,6 +105,11 @@ void AssertSchemaEqual(const Schema& lhs, const Schema&
rhs) {
}
}
+// gtest assertion that `actual` equals `expected` under Datum::Equals.
+// On failure, it reports the kinds of both datums so that a kind mismatch
+// is visible.
+void AssertDatumsEqual(const Datum& expected, const Datum& actual) {
+  // TODO: also print the datums' contents on failure.
+  ASSERT_TRUE(actual.Equals(expected))
+      << "Datums are not equal (expected kind "
+      << static_cast<int>(expected.kind()) << ", actual kind "
+      << static_cast<int>(actual.kind()) << ")";
+}
+
std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>& type,
const std::string& json) {
std::shared_ptr<Array> out;
diff --git a/cpp/src/arrow/testing/gtest_util.h
b/cpp/src/arrow/testing/gtest_util.h
index 7f46dfd..88f3d12 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -105,6 +105,12 @@ class ChunkedArray;
class Column;
class Table;
+namespace compute {
+struct Datum;
+}
+
+using Datum = compute::Datum;
+
using ArrayVector = std::vector<std::shared_ptr<Array>>;
#define ASSERT_PP_EQUAL(LEFT, RIGHT)
\
@@ -137,6 +143,8 @@ ARROW_EXPORT void PrintColumn(const Column& col,
std::stringstream* ss);
ARROW_EXPORT void AssertTablesEqual(const Table& expected, const Table& actual,
bool same_chunk_layout = true);
+/// \brief gtest assertion that `actual` equals `expected` according to
+/// Datum::Equals.
+ARROW_EXPORT void AssertDatumsEqual(const Datum& expected, const Datum&
actual);
+
template <typename C_TYPE>
void AssertNumericDataEqual(const C_TYPE* raw_data,
const std::vector<C_TYPE>& expected_values) {
diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h
index b995b88..9a8d3ef 100644
--- a/cpp/src/arrow/type_fwd.h
+++ b/cpp/src/arrow/type_fwd.h
@@ -89,6 +89,7 @@ struct Decimal128Scalar;
class UnionType;
class UnionArray;
+class UnionScalar;
template <typename TypeClass>
class NumericArray;
@@ -154,6 +155,7 @@ class IntervalScalar;
class ExtensionType;
class ExtensionArray;
+class ExtensionScalar;
// ----------------------------------------------------------------------
// (parameter-free) Factory functions
diff --git a/cpp/src/arrow/util/memory.h b/cpp/src/arrow/util/memory.h
index 62641e3..2d2a105 100644
--- a/cpp/src/arrow/util/memory.h
+++ b/cpp/src/arrow/util/memory.h
@@ -19,6 +19,9 @@
#define ARROW_UTIL_MEMORY_H
#include <cstdint>
+#include <memory>
+
+#include "arrow/util/macros.h"
namespace arrow {
namespace internal {
@@ -28,6 +31,15 @@ namespace internal {
void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
uintptr_t block_size, int num_threads);
+// A helper function for checking if two wrapped objects implementing `Equals`
+// are equal.
+// Equality helper for shared_ptr-wrapped objects that implement `Equals`.
+// Returns true when both pointers refer to the same object (including both
+// being null), or when both are non-null and `left->Equals(*right)` holds.
+// A null pointer never equals a non-null one.
+template <typename T>
+bool SharedPtrEquals(const std::shared_ptr<T>& left, const std::shared_ptr<T>& right) {
+  if (left == right) {
+    return true;
+  }
+  return left != NULLPTR && right != NULLPTR && left->Equals(*right);
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/visitor.cc b/cpp/src/arrow/visitor.cc
index 76e5287..5a601c8 100644
--- a/cpp/src/arrow/visitor.cc
+++ b/cpp/src/arrow/visitor.cc
@@ -21,6 +21,7 @@
#include "arrow/array.h"
#include "arrow/extension_type.h"
+#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/type.h"
@@ -101,4 +102,42 @@ TYPE_VISITOR_DEFAULT(ExtensionType)
#undef TYPE_VISITOR_DEFAULT
+// ----------------------------------------------------------------------
+// Default implementations of ScalarVisitor methods
+
+// Each SCALAR_VISITOR_DEFAULT(X) line below defines ScalarVisitor::Visit(const X&)
+// to return Status::NotImplemented naming X. Subclasses override only the
+// overloads they support. This list mirrors the declarations in visitor.h.
+#define SCALAR_VISITOR_DEFAULT(SCALAR_CLASS) \
+ Status ScalarVisitor::Visit(const SCALAR_CLASS& scalar) { \
+ return Status::NotImplemented( \
+ "ScalarVisitor not implemented for " ARROW_STRINGIFY(SCALAR_CLASS)); \
+ }
+
+SCALAR_VISITOR_DEFAULT(NullScalar)
+SCALAR_VISITOR_DEFAULT(BooleanScalar)
+SCALAR_VISITOR_DEFAULT(Int8Scalar)
+SCALAR_VISITOR_DEFAULT(Int16Scalar)
+SCALAR_VISITOR_DEFAULT(Int32Scalar)
+SCALAR_VISITOR_DEFAULT(Int64Scalar)
+SCALAR_VISITOR_DEFAULT(UInt8Scalar)
+SCALAR_VISITOR_DEFAULT(UInt16Scalar)
+SCALAR_VISITOR_DEFAULT(UInt32Scalar)
+SCALAR_VISITOR_DEFAULT(UInt64Scalar)
+SCALAR_VISITOR_DEFAULT(HalfFloatScalar)
+SCALAR_VISITOR_DEFAULT(FloatScalar)
+SCALAR_VISITOR_DEFAULT(DoubleScalar)
+SCALAR_VISITOR_DEFAULT(StringScalar)
+SCALAR_VISITOR_DEFAULT(BinaryScalar)
+SCALAR_VISITOR_DEFAULT(FixedSizeBinaryScalar)
+SCALAR_VISITOR_DEFAULT(Date64Scalar)
+SCALAR_VISITOR_DEFAULT(Date32Scalar)
+SCALAR_VISITOR_DEFAULT(Time32Scalar)
+SCALAR_VISITOR_DEFAULT(Time64Scalar)
+SCALAR_VISITOR_DEFAULT(TimestampScalar)
+SCALAR_VISITOR_DEFAULT(IntervalScalar)
+SCALAR_VISITOR_DEFAULT(Decimal128Scalar)
+SCALAR_VISITOR_DEFAULT(ListScalar)
+SCALAR_VISITOR_DEFAULT(StructScalar)
+SCALAR_VISITOR_DEFAULT(DictionaryScalar)
+
+#undef SCALAR_VISITOR_DEFAULT
+
} // namespace arrow
diff --git a/cpp/src/arrow/visitor.h b/cpp/src/arrow/visitor.h
index d1e3e3a..9806eff 100644
--- a/cpp/src/arrow/visitor.h
+++ b/cpp/src/arrow/visitor.h
@@ -92,6 +92,38 @@ class ARROW_EXPORT TypeVisitor {
virtual Status Visit(const ExtensionType& type);
};
+/// \brief Base class for virtual-dispatch visitors over concrete Scalar classes.
+///
+/// Every Visit overload has a default definition in visitor.cc that returns
+/// Status::NotImplemented. Subclasses override the overloads they handle.
+/// NOTE(review): there are no overloads yet for UnionScalar or ExtensionScalar.
+class ARROW_EXPORT ScalarVisitor {
+ public:
+ virtual ~ScalarVisitor() = default;
+
+ virtual Status Visit(const NullScalar& scalar);
+ virtual Status Visit(const BooleanScalar& scalar);
+ virtual Status Visit(const Int8Scalar& scalar);
+ virtual Status Visit(const Int16Scalar& scalar);
+ virtual Status Visit(const Int32Scalar& scalar);
+ virtual Status Visit(const Int64Scalar& scalar);
+ virtual Status Visit(const UInt8Scalar& scalar);
+ virtual Status Visit(const UInt16Scalar& scalar);
+ virtual Status Visit(const UInt32Scalar& scalar);
+ virtual Status Visit(const UInt64Scalar& scalar);
+ virtual Status Visit(const HalfFloatScalar& scalar);
+ virtual Status Visit(const FloatScalar& scalar);
+ virtual Status Visit(const DoubleScalar& scalar);
+ virtual Status Visit(const StringScalar& scalar);
+ virtual Status Visit(const BinaryScalar& scalar);
+ virtual Status Visit(const FixedSizeBinaryScalar& scalar);
+ virtual Status Visit(const Date64Scalar& scalar);
+ virtual Status Visit(const Date32Scalar& scalar);
+ virtual Status Visit(const Time32Scalar& scalar);
+ virtual Status Visit(const Time64Scalar& scalar);
+ virtual Status Visit(const TimestampScalar& scalar);
+ virtual Status Visit(const IntervalScalar& scalar);
+ virtual Status Visit(const Decimal128Scalar& scalar);
+ virtual Status Visit(const ListScalar& scalar);
+ virtual Status Visit(const StructScalar& scalar);
+ virtual Status Visit(const DictionaryScalar& scalar);
+};
+
} // namespace arrow
#endif // ARROW_VISITOR_H
diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h
index e8b8c49..5e20e78 100644
--- a/cpp/src/arrow/visitor_inline.h
+++ b/cpp/src/arrow/visitor_inline.h
@@ -22,6 +22,7 @@
#include "arrow/array.h"
#include "arrow/extension_type.h"
+#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
#include "arrow/type.h"
@@ -232,6 +233,23 @@ struct ArrayDataVisitor<T, enable_if_fixed_size_binary<T>>
{
}
};
+// Expands to one switch case that downcasts `scalar` to the Scalar class
+// matching TYPE_CLASS and forwards it to `visitor->Visit`.
+#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
+  case TYPE_CLASS##Type::type_id:       \
+    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar));
+
+/// \brief Statically dispatch `visitor->Visit(...)` on the concrete class of `scalar`.
+///
+/// VISITOR must provide a Visit overload for every Scalar class named by
+/// ARROW_GENERATE_FOR_ALL_TYPES. Type ids not covered by that list yield
+/// Status::NotImplemented.
+template <typename VISITOR>
+inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) {
+  switch (scalar.type->id()) {
+    ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE);
+    default:
+      break;
+  }
+  return Status::NotImplemented("Scalar visitor for type not implemented ",
+                                scalar.type->ToString());
+}
+
+// Keep the helper macro local to this header.
+#undef SCALAR_VISIT_INLINE
+
} // namespace arrow
#endif // ARROW_VISITOR_INLINE_H