This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 548e194  ARROW-4673: [C++] Implement Scalar::Equals and Datum::Equals
548e194 is described below

commit 548e1949d527717d7821a4ab2f09ff7c39882152
Author: François Saint-Jacques <[email protected]>
AuthorDate: Thu Mar 14 20:03:44 2019 -0500

    ARROW-4673: [C++] Implement Scalar::Equals and Datum::Equals
    
    Handy for validating kernels.
    
    Author: François Saint-Jacques <[email protected]>
    Author: Wes McKinney <[email protected]>
    
    Closes #3875 from fsaintjacques/ARROW-4673-datum-equal and squashes the following commits:
    
    3fff08785 <Wes McKinney> Add common base class for some primitive scalar, a little DRY
    093e1bd55 <François Saint-Jacques> Fix struct Scalar warning
    66cae36d8 <François Saint-Jacques> Fix warnings.
    7a7c0d6a1 <François Saint-Jacques> ARROW-4673:  Implement Scalar::Equals and Datum::Equals
---
 cpp/src/arrow/compare.cc                        | 91 +++++++++++++++++++++++++
 cpp/src/arrow/compare.h                         |  6 ++
 cpp/src/arrow/compute/kernel.h                  | 46 +++++++++++++
 cpp/src/arrow/compute/kernels/aggregate-test.cc |  4 +-
 cpp/src/arrow/scalar-test.cc                    | 15 ++++
 cpp/src/arrow/scalar.cc                         |  9 ++-
 cpp/src/arrow/scalar.h                          | 32 +++++++--
 cpp/src/arrow/testing/gtest_util.cc             |  6 ++
 cpp/src/arrow/testing/gtest_util.h              |  8 +++
 cpp/src/arrow/type_fwd.h                        |  2 +
 cpp/src/arrow/util/memory.h                     | 12 ++++
 cpp/src/arrow/visitor.cc                        | 39 +++++++++++
 cpp/src/arrow/visitor.h                         | 32 +++++++++
 cpp/src/arrow/visitor_inline.h                  | 18 +++++
 14 files changed, 308 insertions(+), 12 deletions(-)

diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index fcb16b5..aca6094 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -30,6 +30,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/scalar.h"
 #include "arrow/sparse_tensor.h"
 #include "arrow/status.h"
 #include "arrow/tensor.h"
@@ -38,6 +39,7 @@
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
 #include "arrow/visitor_inline.h"
 
 namespace arrow {
@@ -717,6 +719,77 @@ class TypeEqualsVisitor {
   bool result_;
 };
 
+// Compares a left-hand scalar, dispatched through VisitScalarInline, against
+// right_. ScalarEquals only visits after checking that both types are equal,
+// so right_ can be downcast to the same concrete scalar class as the left one.
+class ScalarEqualsVisitor {
+ public:
+  explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}
+
+  Status Visit(const NullScalar&) {
+    result_ = true;
+    return Status::OK();
+  }
+
+  template <typename T>
+  typename std::enable_if<std::is_base_of<internal::PrimitiveScalar, T>::value,
+                          Status>::type
+  Visit(const T& left) {
+    const auto& right = checked_cast<const T&>(right_);
+    result_ = right.value == left.value;
+    return Status::OK();
+  }
+
+  template <typename T>
+  typename std::enable_if<std::is_base_of<BinaryScalar, T>::value, Status>::type Visit(
+      const T& left_scalar) {
+    const auto& left = checked_cast<const BinaryScalar&>(left_scalar);
+    const auto& right = checked_cast<const BinaryScalar&>(right_);
+    result_ = internal::SharedPtrEquals(left.value, right.value);
+    return Status::OK();
+  }
+
+  Status Visit(const Decimal128Scalar& left) {
+    const auto& right = checked_cast<const Decimal128Scalar&>(right_);
+    result_ = left.value == right.value;
+    return Status::OK();
+  }
+
+  Status Visit(const ListScalar& left) {
+    const auto& right = checked_cast<const ListScalar&>(right_);
+    result_ = internal::SharedPtrEquals(left.value, right.value);
+    return Status::OK();
+  }
+
+  Status Visit(const StructScalar& left) {
+    const auto& right = checked_cast<const StructScalar&>(right_);
+
+    if (right.value.size() != left.value.size()) {
+      result_ = false;
+    } else {
+      bool all_equals = true;
+      for (size_t i = 0; i < left.value.size() && all_equals; i++) {
+        all_equals &= internal::SharedPtrEquals(left.value[i], right.value[i]);
+      }
+      result_ = all_equals;
+    }
+
+    return Status::OK();
+  }
+
+  Status Visit(const UnionScalar&) { return Status::NotImplemented("union"); }
+
+  Status Visit(const DictionaryScalar&) { return Status::NotImplemented("dictionary"); }
+
+  Status Visit(const ExtensionScalar&) { return Status::NotImplemented("extension"); }
+
+  bool result() const { return result_; }
+
+ protected:
+  const Scalar& right_;
+  bool result_;
+};
+
 }  // namespace internal
 
 bool ArrayEquals(const Array& left, const Array& right) {
@@ -915,4 +988,25 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata
   return are_equal;
 }
 
+bool ScalarEquals(const Scalar& left, const Scalar& right) {
+  bool are_equal = false;
+  if (&left == &right) {
+    are_equal = true;
+  } else if (!left.type->Equals(right.type)) {
+    are_equal = false;
+  } else if (left.is_valid != right.is_valid) {
+    are_equal = false;
+  } else if (!left.is_valid) {
+    // Two nulls of the same type are equal: the value stored in a null scalar
+    // is not meaningful, so it must not take part in the comparison.
+    are_equal = true;
+  } else {
+    internal::ScalarEqualsVisitor visitor(right);
+    auto status = VisitScalarInline(left, &visitor);
+    DCHECK_OK(status);
+    are_equal = visitor.result();
+  }
+  return are_equal;
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index 4bb2de4..a0c24c9 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -30,6 +30,7 @@ class Array;
 class DataType;
 class Tensor;
 class SparseTensor;
+struct Scalar;
 
 /// Returns true if the arrays are exactly equal
 bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right);
@@ -56,6 +57,13 @@ bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
 bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
                              bool check_metadata = true);
 
+/// \brief Returns true if the scalars are equal
+///
+/// Scalars whose types or validity differ never compare equal.
+/// \param[in] left a Scalar
+/// \param[in] right a Scalar
+bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right);
+
 }  // namespace arrow
 
 #endif  // ARROW_COMPARE_H
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index 387715d..56d7d68 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -27,6 +27,7 @@
 #include "arrow/scalar.h"
 #include "arrow/table.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/memory.h"
 #include "arrow/util/variant.h"  // IWYU pragma: export
 #include "arrow/util/visibility.h"
 
@@ -59,6 +60,10 @@ class ARROW_EXPORT OpKernel {
   virtual std::shared_ptr<DataType> out_type() const = 0;
 };
 
+struct Datum;
+static inline bool CollectionEquals(const std::vector<Datum>& left,
+                                    const std::vector<Datum>& right);
+
 /// \class Datum
 /// \brief Variant type for various Arrow C++ data structures
 struct ARROW_EXPORT Datum {
@@ -153,6 +158,14 @@ struct ARROW_EXPORT Datum {
     return util::get<std::shared_ptr<ChunkedArray>>(this->value);
   }
 
+  std::shared_ptr<RecordBatch> record_batch() const {
+    return util::get<std::shared_ptr<RecordBatch>>(this->value);
+  }
+
+  std::shared_ptr<Table> table() const {
+    return util::get<std::shared_ptr<Table>>(this->value);
+  }
+
   const std::vector<Datum> collection() const {
     return util::get<std::vector<Datum>>(this->value);
   }
@@ -182,6 +195,33 @@ struct ARROW_EXPORT Datum {
     }
     return NULLPTR;
   }
+
+  /// \brief Returns true if both Datums hold the same kind of value and the
+  /// held values compare equal
+  bool Equals(const Datum& other) const {
+    if (this->kind() != other.kind()) return false;
+
+    switch (this->kind()) {
+      case Datum::NONE:
+        return true;
+      case Datum::SCALAR:
+        return internal::SharedPtrEquals(this->scalar(), other.scalar());
+      case Datum::ARRAY:
+        return internal::SharedPtrEquals(this->make_array(), other.make_array());
+      case Datum::CHUNKED_ARRAY:
+        return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array());
+      case Datum::RECORD_BATCH:
+        return internal::SharedPtrEquals(this->record_batch(), other.record_batch());
+      case Datum::TABLE:
+        return internal::SharedPtrEquals(this->table(), other.table());
+      case Datum::COLLECTION:
+        // Compare the held vectors in place: collection() returns a copy.
+        return CollectionEquals(util::get<std::vector<Datum>>(this->value),
+                                util::get<std::vector<Datum>>(other.value));
+      default:
+        return false;
+    }
+  }
 };
 
 /// \class UnaryKernel
@@ -214,6 +254,17 @@ class ARROW_EXPORT BinaryKernel : public OpKernel {
                       Datum* out) = 0;
 };
 
+// True if both vectors have the same length and pairwise Datum::Equals elements.
+static inline bool CollectionEquals(const std::vector<Datum>& left,
+                                    const std::vector<Datum>& right) {
+  if (left.size() != right.size()) return false;
+
+  for (size_t i = 0; i < left.size(); i++)
+    if (!left[i].Equals(right[i])) return false;
+
+  return true;
+}
+
 }  // namespace compute
 }  // namespace arrow
 
diff --git a/cpp/src/arrow/compute/kernels/aggregate-test.cc b/cpp/src/arrow/compute/kernels/aggregate-test.cc
index cbe91a2..fd3b6d9 100644
--- a/cpp/src/arrow/compute/kernels/aggregate-test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate-test.cc
@@ -267,10 +267,10 @@ void ValidateCount(FunctionContext* ctx, const Array& input, CountPair expected)
   Datum result;
 
   ASSERT_OK(Count(ctx, all, input, &result));
-  DatumEqual<Int64Type>::EnsureEqual(result, Datum(expected.first));
+  AssertDatumsEqual(Datum(expected.first), result);
 
   ASSERT_OK(Count(ctx, nulls, input, &result));
-  DatumEqual<Int64Type>::EnsureEqual(result, Datum(expected.second));
+  AssertDatumsEqual(Datum(expected.second), result);
 }
 
 template <typename ArrowType>
diff --git a/cpp/src/arrow/scalar-test.cc b/cpp/src/arrow/scalar-test.cc
index 580f480..67af4fb 100644
--- a/cpp/src/arrow/scalar-test.cc
+++ b/cpp/src/arrow/scalar-test.cc
@@ -76,8 +76,12 @@ TYPED_TEST(TestNumericScalar, Basics) {
   ASSERT_TRUE(scalar_val->type->Equals(*expected_type));
 
   T other_value = static_cast<T>(2);
+  auto scalar_other = std::make_shared<ScalarType>(other_value);
+  ASSERT_FALSE(scalar_other->Equals(scalar_val));
+
   scalar_val->value = other_value;
   ASSERT_EQ(other_value, scalar_val->value);
+  ASSERT_TRUE(scalar_other->Equals(scalar_val));  // same type, now same value
 
   ScalarType stack_val = ScalarType(0, false);
   ASSERT_FALSE(stack_val.is_valid);
@@ -106,6 +110,13 @@ TEST(TestBinaryScalar, Basics) {
   ASSERT_TRUE(value2.is_valid);
   ASSERT_TRUE(value2.type->Equals(*utf8()));
 
+  // Same buffer, different type (value2 is utf8), so not equal.
+  ASSERT_FALSE(value2.Equals(value));
+
+  StringScalar value3(buf);
+  // Same buffer, same type, so equal.
+  ASSERT_TRUE(value2.Equals(value3));
+
   StringScalar null_value2(nullptr, false);
   ASSERT_FALSE(null_value2.is_valid);
 }
@@ -182,6 +193,10 @@ TEST(TestTimestampScalars, Basics) {
   ASSERT_TRUE(ts_val1.is_valid);
   ASSERT_FALSE(ts_null.is_valid);
   ASSERT_TRUE(ts_null.type->Equals(*type1));
+
+  ASSERT_FALSE(ts_val1.Equals(ts_val2));
+  ASSERT_FALSE(ts_val1.Equals(ts_null));  // valid vs. null
+  ASSERT_FALSE(ts_val2.Equals(ts_null));
 }
 
 // TODO test HalfFloatScalar
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index a9d7c5d..9888598 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -21,6 +21,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/compare.h"
 #include "arrow/type.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/decimal.h"
@@ -30,21 +31,24 @@ namespace arrow {
 
 using internal::checked_cast;
 
+// Delegates to ScalarEquals, declared in arrow/compare.h.
+bool Scalar::Equals(const Scalar& other) const { return ScalarEquals(*this, other); }
+
 Time32Scalar::Time32Scalar(int32_t value, const std::shared_ptr<DataType>& type,
                            bool is_valid)
-    : Scalar{type, is_valid}, value(value) {
+    : internal::PrimitiveScalar{type, is_valid}, value(value) {
   DCHECK_EQ(Type::TIME32, type->id());
 }
 
 Time64Scalar::Time64Scalar(int64_t value, const std::shared_ptr<DataType>& type,
                            bool is_valid)
-    : Scalar{type, is_valid}, value(value) {
+    : internal::PrimitiveScalar{type, is_valid}, value(value) {
   DCHECK_EQ(Type::TIME64, type->id());
 }
 
 TimestampScalar::TimestampScalar(int64_t value, const std::shared_ptr<DataType>& type,
                                  bool is_valid)
-    : Scalar{type, is_valid}, value(value) {
+    : internal::PrimitiveScalar{type, is_valid}, value(value) {
   DCHECK_EQ(Type::TIMESTAMP, type->id());
 }
 
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index b90b26e..ce714bf 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -48,6 +48,14 @@ struct ARROW_EXPORT Scalar {
   /// \brief Whether the value is valid (not null) or not
   bool is_valid;
 
+  /// \brief Returns true if this scalar equals other, as defined by ScalarEquals
+  bool Equals(const Scalar& other) const;
+  /// \brief Same as Equals(*other); returns false if other is null
+  bool Equals(const std::shared_ptr<Scalar>& other) const {
+    if (other) return Equals(*other);
+    return false;
+  }
+
  protected:
   Scalar(const std::shared_ptr<DataType>& type, bool is_valid)
       : type(type), is_valid(is_valid) {}
@@ -59,14 +67,24 @@ struct ARROW_EXPORT NullScalar : public Scalar {
   NullScalar() : Scalar{null(), false} {}
 };
 
-struct ARROW_EXPORT BooleanScalar : public Scalar {
+namespace internal {
+
+/// \brief Common base of the scalars that store a single C value in a member
+/// named `value`; ScalarEquals compares such values with ==
+struct ARROW_EXPORT PrimitiveScalar : public Scalar {
+  using Scalar::Scalar;
+};
+
+}  // namespace internal
+
+struct ARROW_EXPORT BooleanScalar : public internal::PrimitiveScalar {
   bool value;
   explicit BooleanScalar(bool value, bool is_valid = true)
-      : Scalar{boolean(), is_valid}, value(value) {}
+      : internal::PrimitiveScalar{boolean(), is_valid}, value(value) {}
 };
 
 template <typename Type>
-struct NumericScalar : public Scalar {
+struct NumericScalar : public internal::PrimitiveScalar {
   using T = typename Type::c_type;
   T value;
 
@@ -75,7 +93,7 @@ struct NumericScalar : public Scalar {
 
  protected:
   explicit NumericScalar(T value, const std::shared_ptr<DataType>& type, bool is_valid)
-      : Scalar{type, is_valid}, value(value) {}
+      : internal::PrimitiveScalar{type, is_valid}, value(value) {}
 };
 
 struct ARROW_EXPORT BinaryScalar : public Scalar {
@@ -109,21 +127,21 @@ class ARROW_EXPORT Date64Scalar : public NumericScalar<Date64Type> {
   using NumericScalar<Date64Type>::NumericScalar;
 };
 
-class ARROW_EXPORT Time32Scalar : public Scalar {
+class ARROW_EXPORT Time32Scalar : public internal::PrimitiveScalar {
  public:
   int32_t value;
   Time32Scalar(int32_t value, const std::shared_ptr<DataType>& type,
                bool is_valid = true);
 };
 
-class ARROW_EXPORT Time64Scalar : public Scalar {
+class ARROW_EXPORT Time64Scalar : public internal::PrimitiveScalar {
  public:
   int64_t value;
   Time64Scalar(int64_t value, const std::shared_ptr<DataType>& type,
                bool is_valid = true);
 };
 
-class ARROW_EXPORT TimestampScalar : public Scalar {
+class ARROW_EXPORT TimestampScalar : public internal::PrimitiveScalar {
  public:
   int64_t value;
   TimestampScalar(int64_t value, const std::shared_ptr<DataType>& type,
@@ -149,4 +167,9 @@ struct ARROW_EXPORT StructScalar : public Scalar {
   std::vector<std::shared_ptr<Scalar>> value;
 };
 
+// No members yet; equality comparison of these scalars is not implemented.
+class ARROW_EXPORT UnionScalar : public Scalar {};
+class ARROW_EXPORT DictionaryScalar : public Scalar {};
+class ARROW_EXPORT ExtensionScalar : public Scalar {};
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 1db631b..4811954 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -36,6 +36,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/compute/kernel.h"
 #include "arrow/ipc/json-simple.h"
 #include "arrow/pretty_print.h"
 #include "arrow/status.h"
@@ -104,6 +105,11 @@ void AssertSchemaEqual(const Schema& lhs, const Schema& rhs) {
   }
 }
 
+void AssertDatumsEqual(const Datum& expected, const Datum& actual) {
+  // TODO: Implement better printing of the Datums on failure.
+  ASSERT_TRUE(actual.Equals(expected));
+}
+
 std::shared_ptr<Array> ArrayFromJSON(const std::shared_ptr<DataType>& type,
                                      const std::string& json) {
   std::shared_ptr<Array> out;
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 7f46dfd..88f3d12 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -105,6 +105,12 @@ class ChunkedArray;
 class Column;
 class Table;
 
+namespace compute {
+struct Datum;
+}  // namespace compute
+
+using Datum = compute::Datum;
+
 using ArrayVector = std::vector<std::shared_ptr<Array>>;
 
 #define ASSERT_PP_EQUAL(LEFT, RIGHT)                                                   \
@@ -137,6 +143,9 @@ ARROW_EXPORT void PrintColumn(const Column& col, std::stringstream* ss);
 ARROW_EXPORT void AssertTablesEqual(const Table& expected, const Table& actual,
                                     bool same_chunk_layout = true);
 
+// Asserts (via gtest ASSERT_TRUE) that actual.Equals(expected) holds.
+ARROW_EXPORT void AssertDatumsEqual(const Datum& expected, const Datum& actual);
+
 template <typename C_TYPE>
 void AssertNumericDataEqual(const C_TYPE* raw_data,
                             const std::vector<C_TYPE>& expected_values) {
diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h
index b995b88..9a8d3ef 100644
--- a/cpp/src/arrow/type_fwd.h
+++ b/cpp/src/arrow/type_fwd.h
@@ -89,6 +89,7 @@ struct Decimal128Scalar;
 
 class UnionType;
 class UnionArray;
+class UnionScalar;  // member-less placeholder defined in scalar.h
 
 template <typename TypeClass>
 class NumericArray;
@@ -154,6 +155,7 @@ class IntervalScalar;
 
 class ExtensionType;
 class ExtensionArray;
+class ExtensionScalar;  // member-less placeholder defined in scalar.h
 
 // ----------------------------------------------------------------------
 // (parameter-free) Factory functions
diff --git a/cpp/src/arrow/util/memory.h b/cpp/src/arrow/util/memory.h
index 62641e3..2d2a105 100644
--- a/cpp/src/arrow/util/memory.h
+++ b/cpp/src/arrow/util/memory.h
@@ -19,6 +19,9 @@
 #define ARROW_UTIL_MEMORY_H
 
 #include <cstdint>
+#include <memory>
+
+#include "arrow/util/macros.h"
 
 namespace arrow {
 namespace internal {
@@ -28,6 +31,17 @@ namespace internal {
 void parallel_memcopy(uint8_t* dst, const uint8_t* src, int64_t nbytes,
                       uintptr_t block_size, int num_threads);
 
+// Compares the objects owned by two shared_ptrs using left->Equals(*right).
+//
+// Two pointers to the same object (including two null pointers) are equal;
+// a null pointer never equals a non-null one.
+template <typename T>
+bool SharedPtrEquals(const std::shared_ptr<T>& left, const std::shared_ptr<T>& right) {
+  if (left == right) return true;
+  if (left == NULLPTR || right == NULLPTR) return false;
+  return left->Equals(*right);
+}
+
 }  // namespace internal
 }  // namespace arrow
 
diff --git a/cpp/src/arrow/visitor.cc b/cpp/src/arrow/visitor.cc
index 76e5287..5a601c8 100644
--- a/cpp/src/arrow/visitor.cc
+++ b/cpp/src/arrow/visitor.cc
@@ -21,6 +21,7 @@
 
 #include "arrow/array.h"
 #include "arrow/extension_type.h"
+#include "arrow/scalar.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
 
@@ -101,4 +102,42 @@ TYPE_VISITOR_DEFAULT(ExtensionType)
 
 #undef TYPE_VISITOR_DEFAULT
 
+// ----------------------------------------------------------------------
+// Default implementations of ScalarVisitor methods (each returns NotImplemented)
+
+#define SCALAR_VISITOR_DEFAULT(SCALAR_CLASS)                                 \
+  Status ScalarVisitor::Visit(const SCALAR_CLASS& scalar) {                  \
+    return Status::NotImplemented(                                           \
+        "ScalarVisitor not implemented for " ARROW_STRINGIFY(SCALAR_CLASS)); \
+  }
+
+SCALAR_VISITOR_DEFAULT(NullScalar)
+SCALAR_VISITOR_DEFAULT(BooleanScalar)
+SCALAR_VISITOR_DEFAULT(Int8Scalar)
+SCALAR_VISITOR_DEFAULT(Int16Scalar)
+SCALAR_VISITOR_DEFAULT(Int32Scalar)
+SCALAR_VISITOR_DEFAULT(Int64Scalar)
+SCALAR_VISITOR_DEFAULT(UInt8Scalar)
+SCALAR_VISITOR_DEFAULT(UInt16Scalar)
+SCALAR_VISITOR_DEFAULT(UInt32Scalar)
+SCALAR_VISITOR_DEFAULT(UInt64Scalar)
+SCALAR_VISITOR_DEFAULT(HalfFloatScalar)
+SCALAR_VISITOR_DEFAULT(FloatScalar)
+SCALAR_VISITOR_DEFAULT(DoubleScalar)
+SCALAR_VISITOR_DEFAULT(StringScalar)
+SCALAR_VISITOR_DEFAULT(BinaryScalar)
+SCALAR_VISITOR_DEFAULT(FixedSizeBinaryScalar)
+SCALAR_VISITOR_DEFAULT(Date64Scalar)
+SCALAR_VISITOR_DEFAULT(Date32Scalar)
+SCALAR_VISITOR_DEFAULT(Time32Scalar)
+SCALAR_VISITOR_DEFAULT(Time64Scalar)
+SCALAR_VISITOR_DEFAULT(TimestampScalar)
+SCALAR_VISITOR_DEFAULT(IntervalScalar)
+SCALAR_VISITOR_DEFAULT(Decimal128Scalar)
+SCALAR_VISITOR_DEFAULT(ListScalar)
+SCALAR_VISITOR_DEFAULT(StructScalar)
+SCALAR_VISITOR_DEFAULT(DictionaryScalar)
+
+#undef SCALAR_VISITOR_DEFAULT
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/visitor.h b/cpp/src/arrow/visitor.h
index d1e3e3a..9806eff 100644
--- a/cpp/src/arrow/visitor.h
+++ b/cpp/src/arrow/visitor.h
@@ -92,6 +92,38 @@ class ARROW_EXPORT TypeVisitor {
   virtual Status Visit(const ExtensionType& type);
 };
 
+class ARROW_EXPORT ScalarVisitor {  // default Visit()s return NotImplemented
+ public:
+  virtual ~ScalarVisitor() = default;
+
+  virtual Status Visit(const NullScalar& scalar);
+  virtual Status Visit(const BooleanScalar& scalar);
+  virtual Status Visit(const Int8Scalar& scalar);
+  virtual Status Visit(const Int16Scalar& scalar);
+  virtual Status Visit(const Int32Scalar& scalar);
+  virtual Status Visit(const Int64Scalar& scalar);
+  virtual Status Visit(const UInt8Scalar& scalar);
+  virtual Status Visit(const UInt16Scalar& scalar);
+  virtual Status Visit(const UInt32Scalar& scalar);
+  virtual Status Visit(const UInt64Scalar& scalar);
+  virtual Status Visit(const HalfFloatScalar& scalar);
+  virtual Status Visit(const FloatScalar& scalar);
+  virtual Status Visit(const DoubleScalar& scalar);
+  virtual Status Visit(const StringScalar& scalar);
+  virtual Status Visit(const BinaryScalar& scalar);
+  virtual Status Visit(const FixedSizeBinaryScalar& scalar);
+  virtual Status Visit(const Date64Scalar& scalar);
+  virtual Status Visit(const Date32Scalar& scalar);
+  virtual Status Visit(const Time32Scalar& scalar);
+  virtual Status Visit(const Time64Scalar& scalar);
+  virtual Status Visit(const TimestampScalar& scalar);
+  virtual Status Visit(const IntervalScalar& scalar);
+  virtual Status Visit(const Decimal128Scalar& scalar);
+  virtual Status Visit(const ListScalar& scalar);
+  virtual Status Visit(const StructScalar& scalar);
+  virtual Status Visit(const DictionaryScalar& scalar);
+};
+
 }  // namespace arrow
 
 #endif  // ARROW_VISITOR_H
diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h
index e8b8c49..5e20e78 100644
--- a/cpp/src/arrow/visitor_inline.h
+++ b/cpp/src/arrow/visitor_inline.h
@@ -22,6 +22,7 @@
 
 #include "arrow/array.h"
 #include "arrow/extension_type.h"
+#include "arrow/scalar.h"
 #include "arrow/status.h"
 #include "arrow/tensor.h"
 #include "arrow/type.h"
@@ -232,6 +233,26 @@ struct ArrayDataVisitor<T, enable_if_fixed_size_binary<T>> {
   }
 };
 
+#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
+  case TYPE_CLASS##Type::type_id:       \
+    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar));
+
+/// \brief Calls visitor->Visit() with the scalar downcast to the concrete
+/// scalar class matching scalar.type->id(). Returns NotImplemented for type
+/// ids that ARROW_GENERATE_FOR_ALL_TYPES does not cover.
+template <typename VISITOR>
+inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) {
+  switch (scalar.type->id()) {
+    ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE);
+    default:
+      break;
+  }
+  return Status::NotImplemented("Scalar visitor for type not implemented ",
+                                scalar.type->ToString());
+}
+
+#undef SCALAR_VISIT_INLINE
+
 }  // namespace arrow
 
 #endif  // ARROW_VISITOR_INLINE_H

Reply via email to