This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f9c074  ARROW-12554: [C++] Allow duplicates in 
`SetLookupOptions::value_set`
2f9c074 is described below

commit 2f9c074d97e6105779fe27f52c3172338970d02e
Author: Antoine Pitrou <[email protected]>
AuthorDate: Wed Apr 28 11:29:52 2021 +0200

    ARROW-12554: [C++] Allow duplicates in `SetLookupOptions::value_set`
    
    For the `index_in` function, we need to map the memo table indices to
    indices in the value_set (they are different if there are duplicates).
    
    This fixes the current benchmark failures for `is_in` and `index_in`.
    
    Closes #10174 from pitrou/ARROW-12554-index-in-duplicates
    
    Authored-by: Antoine Pitrou <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/compute/kernels/scalar_set_lookup.cc |  45 +++--
 .../compute/kernels/scalar_set_lookup_test.cc      | 185 ++++++++++++++++++++-
 2 files changed, 217 insertions(+), 13 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc 
b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
index 2868b0c..3e2e95e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -41,32 +41,52 @@ struct SetLookupState : public KernelState {
 
   Status Init(const SetLookupOptions& options) {
     if (options.value_set.kind() == Datum::ARRAY) {
-      RETURN_NOT_OK(AddArrayValueSet(*options.value_set.array()));
+      const ArrayData& value_set = *options.value_set.array();
+      memo_index_to_value_index.reserve(value_set.length);
+      RETURN_NOT_OK(AddArrayValueSet(options, *options.value_set.array()));
     } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) {
       const ChunkedArray& value_set = *options.value_set.chunked_array();
+      memo_index_to_value_index.reserve(value_set.length());
+      int64_t offset = 0;
       for (const std::shared_ptr<Array>& chunk : value_set.chunks()) {
-        RETURN_NOT_OK(AddArrayValueSet(*chunk->data()));
+        RETURN_NOT_OK(AddArrayValueSet(options, *chunk->data(), offset));
+        offset += chunk->length();
       }
     } else {
       return Status::Invalid("value_set should be an array or chunked array");
     }
-    if (lookup_table.size() != options.value_set.length()) {
-      return Status::NotImplemented("duplicate values in value_set");
-    }
-    if (!options.skip_nulls) {
-      null_index = lookup_table.GetNull();
+    if (!options.skip_nulls && lookup_table.GetNull() >= 0) {
+      null_index = memo_index_to_value_index[lookup_table.GetNull()];
     }
     return Status::OK();
   }
 
-  Status AddArrayValueSet(const ArrayData& data) {
+  Status AddArrayValueSet(const SetLookupOptions& options, const ArrayData& 
data,
+                          int64_t start_index = 0) {
     using T = typename GetViewType<Type>::T;
+    int32_t index = static_cast<int32_t>(start_index);
     auto visit_valid = [&](T v) {
+      const auto memo_size = 
static_cast<int32_t>(memo_index_to_value_index.size());
       int32_t unused_memo_index;
-      return lookup_table.GetOrInsert(v, &unused_memo_index);
+      auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, 
memo_size); };
+      auto on_not_found = [&](int32_t memo_index) {
+        DCHECK_EQ(memo_index, memo_size);
+        memo_index_to_value_index.push_back(index);
+      };
+      RETURN_NOT_OK(lookup_table.GetOrInsert(
+          v, std::move(on_found), std::move(on_not_found), 
&unused_memo_index));
+      ++index;
+      return Status::OK();
     };
     auto visit_null = [&]() {
-      lookup_table.GetOrInsertNull();
+      const auto memo_size = 
static_cast<int32_t>(memo_index_to_value_index.size());
+      auto on_found = [&](int32_t memo_index) { DCHECK_LT(memo_index, 
memo_size); };
+      auto on_not_found = [&](int32_t memo_index) {
+        DCHECK_EQ(memo_index, memo_size);
+        memo_index_to_value_index.push_back(index);
+      };
+      lookup_table.GetOrInsertNull(std::move(on_found), 
std::move(on_not_found));
+      ++index;
       return Status::OK();
     };
 
@@ -75,6 +95,9 @@ struct SetLookupState : public KernelState {
 
   using MemoTable = typename HashTraits<Type>::MemoTableType;
   MemoTable lookup_table;
+  // When there are duplicates in value_set, the MemoTable indices must
+  // be mapped back to indices in the value_set.
+  std::vector<int32_t> memo_index_to_value_index;
   int32_t null_index = -1;
 };
 
@@ -215,7 +238,7 @@ struct IndexInVisitor {
           int32_t index = state.lookup_table.Get(v);
           if (index != -1) {
             // matching needle; output index from value_set
-            this->builder.UnsafeAppend(index);
+            this->builder.UnsafeAppend(state.memo_index_to_value_index[index]);
           } else {
             // no matching needle; output null
             this->builder.UnsafeAppendNull();
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc 
b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
index 272502c..5c8bf98 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -157,6 +157,13 @@ TYPED_TEST(TestIsInKernelPrimitive, IsIn) {
   CheckIsIn(type, "[null, 1, 2, 3, 2]", "[2, null, 1]",
             "[false, true, true, false, true]", /*skip_nulls=*/true);
 
+  // Duplicates in right array
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[true, true, true, false, true]",
+            /*skip_nulls=*/false);
+  CheckIsIn(type, "[null, 1, 2, 3, 2]", "[null, 2, 2, null, 1, 1]",
+            "[false, true, true, false, true]", /*skip_nulls=*/true);
+
   // Empty Arrays
   CheckIsIn(type, "[]", "[]", "[]");
 }
@@ -170,6 +177,10 @@ TEST_F(TestIsInKernel, NullType) {
 
   CheckIsIn(type, "[null, null]", "[null]", "[false, false]", 
/*skip_nulls=*/true);
   CheckIsIn(type, "[null, null]", "[]", "[false, false]", /*skip_nulls=*/true);
+
+  // Duplicates in right array
+  CheckIsIn(type, "[null, null, null]", "[null, null]", "[true, true, true]");
+  CheckIsIn(type, "[null, null]", "[null, null]", "[false, false]", 
/*skip_nulls=*/true);
 }
 
 TEST_F(TestIsInKernel, TimeTimestamp) {
@@ -179,6 +190,12 @@ TEST_F(TestIsInKernel, TimeTimestamp) {
               "[true, true, false, true, true]", /*skip_nulls=*/false);
     CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, null]",
               "[true, false, false, true, true]", /*skip_nulls=*/true);
+
+    // Duplicates in right array
+    CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+              "[true, true, false, true, true]", /*skip_nulls=*/false);
+    CheckIsIn(type, "[1, null, 5, 1, 2]", "[2, 1, 1, null, 2]",
+              "[true, false, false, true, true]", /*skip_nulls=*/true);
   }
 }
 
@@ -194,6 +211,12 @@ TEST_F(TestIsInKernel, Boolean) {
             "[false, true, true, false, true]", /*skip_nulls=*/false);
   CheckIsIn(type, "[true, false, null, true, false]", "[false, null]",
             "[false, true, false, false, true]", /*skip_nulls=*/true);
+
+  // Duplicates in right array
+  CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, 
null]",
+            "[false, true, true, false, true]", /*skip_nulls=*/false);
+  CheckIsIn(type, "[true, false, null, true, false]", "[null, false, false, 
null]",
+            "[false, true, false, false, true]", /*skip_nulls=*/true);
 }
 
 TYPED_TEST_SUITE(TestIsInKernelBinary, BinaryTypes);
@@ -214,6 +237,14 @@ TYPED_TEST(TestIsInKernelBinary, Binary) {
   CheckIsIn(type, R"(["aaa", "", "cc", null, ""])", R"(["aaa", "", null])",
             "[true, true, false, false, true]",
             /*skip_nulls=*/true);
+
+  // Duplicates in right array
+  CheckIsIn(type, R"(["aaa", "", "cc", null, ""])",
+            R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, 
true, true]",
+            /*skip_nulls=*/false);
+  CheckIsIn(type, R"(["aaa", "", "cc", null, ""])",
+            R"([null, "aaa", "aaa", "", "", null])", "[true, true, false, 
false, true]",
+            /*skip_nulls=*/true);
 }
 
 TEST_F(TestIsInKernel, FixedSizeBinary) {
@@ -232,6 +263,16 @@ TEST_F(TestIsInKernel, FixedSizeBinary) {
   CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])", R"(["aaa", "bbb", 
null])",
             "[true, true, false, false, true]",
             /*skip_nulls=*/true);
+
+  // Duplicates in right array
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, false, true, true]",
+            /*skip_nulls=*/false);
+  CheckIsIn(type, R"(["aaa", "bbb", "ccc", null, "bbb"])",
+            R"(["aaa", null, "aaa", "bbb", "bbb", null])",
+            "[true, true, false, false, true]",
+            /*skip_nulls=*/true);
 }
 
 TEST_F(TestIsInKernel, Decimal) {
@@ -250,6 +291,16 @@ TEST_F(TestIsInKernel, Decimal) {
   CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
             R"(["12.3", "78.9", null])", "[true, false, true, false, true]",
             /*skip_nulls=*/true);
+
+  // Duplicates in right array
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, false, true, true, true]",
+            /*skip_nulls=*/false);
+  CheckIsIn(type, R"(["12.3", "45.6", "78.9", null, "12.3"])",
+            R"([null, "12.3", "12.3", "78.9", "78.9", null])",
+            "[true, false, true, false, true]",
+            /*skip_nulls=*/true);
 }
 
 TEST_F(TestIsInKernel, DictionaryArray) {
@@ -314,6 +365,29 @@ TEST_F(TestIsInKernel, DictionaryArray) {
                         /*value_set_json=*/R"(["C", "B", "A"])",
                         /*expected_json=*/"[false, false, false, true, false]",
                         /*skip_nulls=*/true);
+
+    // With duplicates in value_set
+    CheckIsInDictionary(/*type=*/utf8(),
+                        /*index_type=*/index_ty,
+                        /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+                        /*input_index_json=*/"[1, 2, null, 0]",
+                        /*value_set_json=*/R"(["A", "A", "B", "A", "B", "C"])",
+                        /*expected_json=*/"[true, true, false, true]",
+                        /*skip_nulls=*/false);
+    CheckIsInDictionary(/*type=*/utf8(),
+                        /*index_type=*/index_ty,
+                        /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+                        /*input_index_json=*/"[1, 3, null, 0, 1]",
+                        /*value_set_json=*/R"(["C", "C", "B", "A", null, null, 
"B"])",
+                        /*expected_json=*/"[true, false, true, true, true]",
+                        /*skip_nulls=*/false);
+    CheckIsInDictionary(/*type=*/utf8(),
+                        /*index_type=*/index_ty,
+                        /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+                        /*input_index_json=*/"[1, 3, null, 0, 1]",
+                        /*value_set_json=*/R"(["C", "C", "B", "A", null, null, 
"B"])",
+                        /*expected_json=*/"[true, false, false, true, true]",
+                        /*skip_nulls=*/true);
   }
 }
 
@@ -335,6 +409,16 @@ TEST_F(TestIsInKernel, ChunkedArrayInvoke) {
   expected = ChunkedArrayFromJSON(
       boolean(), {"[false, true, true, false, false]", "[true, false, false, 
false]"});
   CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+  // Duplicates in value_set
+  value_set =
+      ChunkedArrayFromJSON(utf8(), {R"(["", null, "", "def"])", R"(["def", 
null])"});
+  expected = ChunkedArrayFromJSON(
+      boolean(), {"[false, true, true, false, false]", "[true, true, false, 
false]"});
+  CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/false);
+  expected = ChunkedArrayFromJSON(
+      boolean(), {"[false, true, true, false, false]", "[true, false, false, 
false]"});
+  CheckIsInChunked(input, value_set, expected, /*skip_nulls=*/true);
 }
 
 // ----------------------------------------------------------------------
@@ -439,6 +523,18 @@ TYPED_TEST(TestIndexInKernelPrimitive, IndexIn) {
                      /* value_set= */ "[null]",
                      /* expected= */ "[0, 0, 0, 0]");
 
+  // Duplicates in value_set
+  this->CheckIndexIn(type,
+                     /* input= */ "[2, 1, 2, 1, 2, 3]",
+                     /* value_set= */ "[2, 2, 1, 1, 1, 3, 3]",
+                     /* expected= */ "[0, 2, 0, 2, 0, 5]");
+
+  // Duplicates and nulls in value_set
+  this->CheckIndexIn(type,
+                     /* input= */ "[2, 1, 2, 1, 2, 3]",
+                     /* value_set= */ "[2, 2, null, null, 1, 1, 1, 3, 3]",
+                     /* expected= */ "[0, 4, 0, 4, 0, 7]");
+
   // No Match
   this->CheckIndexIn(type,
                      /* input= */ "[2, null, 7, 3, 8]",
@@ -463,6 +559,17 @@ TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) {
                      /*value_set=*/"[1, 3]",
                      /*expected=*/"[null, 0, null, 1, null]",
                      /*skip_nulls=*/true);
+  // Same with duplicates in value_set
+  this->CheckIndexIn(type,
+                     /*input=*/"[0, 1, 2, 3, null]",
+                     /*value_set=*/"[1, 1, 3, 3]",
+                     /*expected=*/"[null, 0, null, 2, null]",
+                     /*skip_nulls=*/false);
+  this->CheckIndexIn(type,
+                     /*input=*/"[0, 1, 2, 3, null]",
+                     /*value_set=*/"[1, 1, 3, 3]",
+                     /*expected=*/"[null, 0, null, 2, null]",
+                     /*skip_nulls=*/true);
 
   // Nulls in value_set
   this->CheckIndexIn(type,
@@ -472,9 +579,15 @@ TYPED_TEST(TestIndexInKernelPrimitive, SkipNulls) {
                      /*skip_nulls=*/false);
   this->CheckIndexIn(type,
                      /*input=*/"[0, 1, 2, 3, null]",
-                     /*value_set=*/"[1, null, 3]",
-                     /*expected=*/"[null, 0, null, 2, null]",
+                     /*value_set=*/"[1, 1, null, null, 3, 3]",
+                     /*expected=*/"[null, 0, null, 4, null]",
                      /*skip_nulls=*/true);
+  // Same with duplicates in value_set
+  this->CheckIndexIn(type,
+                     /*input=*/"[0, 1, 2, 3, null]",
+                     /*value_set=*/"[1, 1, null, null, 3, 3]",
+                     /*expected=*/"[null, 0, null, 4, 2]",
+                     /*skip_nulls=*/false);
 }
 
 TEST_F(TestIndexInKernel, NullType) {
@@ -493,6 +606,12 @@ TEST_F(TestIndexInKernel, TimeTimestamp) {
                /* value_set= */ "[2, 1, null]",
                /* expected= */ "[1, 2, null, 1, 0]");
 
+  // Duplicates in value_set
+  CheckIndexIn(time32(TimeUnit::SECOND),
+               /* input= */ "[1, null, 5, 1, 2]",
+               /* value_set= */ "[2, 2, 1, 1, null, null]",
+               /* expected= */ "[2, 4, null, 2, 0]");
+
   // Needles array has no nulls
   CheckIndexIn(time32(TimeUnit::SECOND),
                /* input= */ "[2, null, 5, 1]",
@@ -531,6 +650,10 @@ TEST_F(TestIndexInKernel, Boolean) {
   CheckIndexIn(boolean(), "[false, null, false, true]", "[false, true, null]",
                "[0, 2, 0, 1]");
 
+  // Duplicates in value_set
+  CheckIndexIn(boolean(), "[false, null, false, true]",
+               "[false, false, true, true, null, null]", "[0, 4, 0, 2]");
+
   // No Nulls
   CheckIndexIn(boolean(), "[true, true, false, true]", "[false, true]", "[1, 
1, 0, 1]");
 
@@ -562,6 +685,10 @@ TYPED_TEST(TestIndexInKernelBinary, Binary) {
   this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])", R"(["foo", null, 
"bar"])",
                      R"([0, 1, 2, 0])");
 
+  // Duplicates in value_set
+  this->CheckIndexIn(type, R"(["foo", null, "bar", "foo"])",
+                     R"(["foo", "foo", null, null, "bar", "bar"])", R"([0, 2, 
4, 0])");
+
   // No match
   this->CheckIndexIn(type,
                      /* input= */ R"(["foo", null, "bar", "foo"])",
@@ -653,6 +780,17 @@ TEST_F(TestIndexInKernel, FixedSizeBinary) {
                /*expected=*/R"([1, null, null, 0, 2, 0])",
                /*skip_nulls=*/true);
 
+  // Duplicates in value_set
+  CheckIndexIn(fixed_size_binary(3),
+               /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+               /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", 
"ccc"])",
+               /*expected=*/R"([4, 2, null, 0, 6, 0])");
+  CheckIndexIn(fixed_size_binary(3),
+               /*input=*/R"(["bbb", null, "ddd", "aaa", "ccc", "aaa"])",
+               /*value_set=*/R"(["aaa", "aaa", null, null, "bbb", "bbb", 
"ccc"])",
+               /*expected=*/R"([4, null, null, 0, 6, 0])",
+               /*skip_nulls=*/true);
+
   // Empty input array
   CheckIndexIn(fixed_size_binary(5), R"([])", R"(["bbbbb", null, "aaaaa", 
"ccccc"])",
                R"([])");
@@ -689,6 +827,18 @@ TEST_F(TestIndexInKernel, Decimal) {
                /*value_set=*/R"(["11", "12"])",
                /*expected=*/R"([1, null, 0, 1, null])",
                /*skip_nulls=*/true);
+
+  // Duplicates in value_set
+  CheckIndexIn(type,
+               /*input=*/R"(["12", null, "11", "12", "13"])",
+               /*value_set=*/R"([null, null, "11", "11", "12", "12"])",
+               /*expected=*/R"([4, 0, 2, 4, null])",
+               /*skip_nulls=*/false);
+  CheckIndexIn(type,
+               /*input=*/R"(["12", null, "11", "12", "13"])",
+               /*value_set=*/R"([null, null, "11", "11", "12", "12"])",
+               /*expected=*/R"([4, null, 2, 4, null])",
+               /*skip_nulls=*/true);
 }
 
 TEST_F(TestIndexInKernel, DictionaryArray) {
@@ -753,6 +903,29 @@ TEST_F(TestIndexInKernel, DictionaryArray) {
                            /*value_set_json=*/R"(["C", "B", "A"])",
                            /*expected_json=*/"[null, null, null, 2, null]",
                            /*skip_nulls=*/true);
+
+    // With duplicates in value_set
+    CheckIndexInDictionary(/*type=*/utf8(),
+                           /*index_type=*/index_ty,
+                           /*input_dictionary_json=*/R"(["A", "B", "C", "D"])",
+                           /*input_index_json=*/"[1, 2, null, 0]",
+                           /*value_set_json=*/R"(["A", "A", "B", "B", "C", 
"C"])",
+                           /*expected_json=*/"[2, 4, null, 0]",
+                           /*skip_nulls=*/false);
+    CheckIndexInDictionary(/*type=*/utf8(),
+                           /*index_type=*/index_ty,
+                           /*input_dictionary_json=*/R"(["A", null, "C", 
"D"])",
+                           /*input_index_json=*/"[1, 3, null, 0, 1]",
+                           /*value_set_json=*/R"(["C", "C", "B", "B", "A", 
"A", null])",
+                           /*expected_json=*/"[6, null, 6, 4, 6]",
+                           /*skip_nulls=*/false);
+    CheckIndexInDictionary(/*type=*/utf8(),
+                           /*index_type=*/index_ty,
+                           /*input_dictionary_json=*/R"(["A", null, "C", 
"D"])",
+                           /*input_index_json=*/"[1, 3, null, 0, 1]",
+                           /*value_set_json=*/R"(["C", "C", "B", "B", "A", 
"A", null])",
+                           /*expected_json=*/"[null, null, null, 4, null]",
+                           /*skip_nulls=*/true);
   }
 }
 
@@ -773,6 +946,14 @@ TEST_F(TestIndexInKernel, ChunkedArrayInvoke) {
   CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false);
   expected = ChunkedArrayFromJSON(int32(), {"[3, 1, 0, 3, null]", "[1, null, 
3, null]"});
   CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true);
+
+  // Duplicates in value_set
+  value_set = ChunkedArrayFromJSON(
+      utf8(), {R"(["ghi", "ghi", "def"])", R"(["def", null, null, "abc"])"});
+  expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, 4, 6, 
null]"});
+  CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/false);
+  expected = ChunkedArrayFromJSON(int32(), {"[6, 2, 0, 6, null]", "[2, null, 
6, null]"});
+  CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true);
 }
 
 TEST(TestSetLookup, DispatchBest) {

Reply via email to