lidavidm commented on a change in pull request #11022:
URL: https://github.com/apache/arrow/pull/11022#discussion_r703002719



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -624,6 +624,250 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
               ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
 }
 
+template <typename Type>
+class TestCaseWhenDict : public ::testing::Test {};
+
+struct JsonDict {
+  std::shared_ptr<DataType> type;
+  std::string value;
+};
+
+TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenDict, Simple) {
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  for (const auto& dict :
+       {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+        JsonDict{int64(), "[1, null, 2, 3]"},
+        JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+    auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+    auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", 
dict.value);
+    auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+    auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+
+    // Easy case: all arguments have the same dictionary
+    CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+                DictArrayFromJSON(type, "[0, null, null, null]", dict.value));
+    CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, 
values1},
+                DictArrayFromJSON(type, "[0, null, null, 1]", dict.value));
+    CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, 
values2, values1},
+                DictArrayFromJSON(type, "[null, null, null, 1]", dict.value));
+  }
+}
+
+TYPED_TEST(TestCaseWhenDict, Mixed) {
+  auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  auto dict = R"(["a", null, "bc", "def"])";
+  auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+  auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+  auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+  auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+  auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+
+  // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+  CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_dict, 
values2_decoded},
+              ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+  CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_decoded, 
values2_dict},
+              ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+  CheckScalar("case_when",
+              {MakeStruct({cond1, cond2}), values1_dict, values2_dict, 
values1_decoded},
+              ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+  CheckScalar("case_when",
+              {MakeStruct({cond1, cond2}), values_null, values2_dict, 
values1_decoded},
+              ArrayFromJSON(utf8(), R"([null, null, null, null])"));
+}
+
+TYPED_TEST(TestCaseWhenDict, NestedSimple) {
+  auto make_list = [](const std::shared_ptr<Array>& indices,
+                      const std::shared_ptr<Array>& backing_array) {
+    EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, 
*backing_array));
+    return result;
+  };
+  auto index_type = default_type_instance<TypeParam>();
+  auto inner_type = dictionary(index_type, utf8());
+  auto type = list(inner_type);
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  auto dict = R"(["a", "b", "bc", "def"])";
+  auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+                               DictArrayFromJSON(inner_type, "[]", dict));
+  auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", 
dict);
+  auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", 
dict);
+  auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), 
values1_backing);
+  auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), 
values2_backing);
+
+  CheckScalarNonRecursive(
+      "case_when", {MakeStruct({cond1, cond2}), values1, values2},
+      make_list(ArrayFromJSON(int32(), "[0, 2, 2, null, 2]"),
+                DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])")));
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values1,
+       make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), 
values2_backing)},
+      make_list(ArrayFromJSON(int32(), "[0, 2, null, null, 2]"),
+                DictArrayFromJSON(inner_type, "[0, null]", R"(["a"])")));
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values1,
+       make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), 
values2_backing), values1},
+      make_list(ArrayFromJSON(int32(), "[0, 2, null, 2, 3]"),
+                DictArrayFromJSON(inner_type, "[0, null, 1]", R"(["a", "b"])")));
+
+  CheckScalarNonRecursive(
+      "case_when",
+      {
+          Datum(MakeStruct({cond1, cond2})),
+          Datum(std::make_shared<ListScalar>(
+              DictArrayFromJSON(inner_type, "[0, 1]", dict))),
+          Datum(std::make_shared<ListScalar>(
+              DictArrayFromJSON(inner_type, "[2, 3]", dict))),
+      },
+      make_list(ArrayFromJSON(int32(), "[0, 2, 4, null, 6]"),
+                DictArrayFromJSON(inner_type, "[0, 1, 0, 1, 2, 3]", dict)));
+
+  CheckScalarNonRecursive(
+      "case_when", {MakeStruct({Datum(true), Datum(false)}), values1, 
values2}, values1);
+  CheckScalarNonRecursive(
+      "case_when", {MakeStruct({Datum(false), Datum(true)}), values1, 
values2}, values2);
+  CheckScalarNonRecursive("case_when", {MakeStruct({Datum(false)}), values1, 
values2},
+                          values2);
+  CheckScalarNonRecursive("case_when",
+                          {MakeStruct({Datum(false), Datum(false)}), values1, 
values2},
+                          values_null);
+}
+
+TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) {
+  auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  auto dict1 = R"(["a", null, "bc", "def"])";
+  auto dict2 = R"(["bc", "foo", null, "a"])";
+  auto dict3 = R"(["def", "a", "a", "bc"])";
+  auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", 
dict1);
+  auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", 
dict2);
+  auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict1);
+  auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict2);
+  auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3);
+
+  // For scalar conditions, we borrow the dictionary of the chosen output (or the first
+  // input when outputting null)
+  CheckScalar("case_when", {MakeStruct({Datum(true), Datum(false)}), values1, 
values2},
+              values1);
+  CheckScalar("case_when", {MakeStruct({Datum(false), Datum(true)}), values1, 
values2},
+              values2);
+  CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values1, 
values2},
+              values1_null);
+  CheckScalar("case_when", {MakeStruct({Datum(false), Datum(false)}), values2, 
values1},
+              values2_null);
+
+  // For array conditions, we always borrow the dictionary of the first input
+  CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+              DictArrayFromJSON(type, "[0, null, null, null]", dict1));
+  CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, 
values1},
+              DictArrayFromJSON(type, "[0, null, null, 1]", dict1));
+
+  // When mixing dictionaries, we try to map other dictionaries onto the first one
+  // Don't check the scalar cases since we don't remap dictionaries in that case
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), 
values1,
+       values2},
+      DictArrayFromJSON(type, "[0, null, null, 2]", dict1));
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+                   ArrayFromJSON(boolean(), "[true, false, true, false]")}),
+       values1, values2},
+      DictArrayFromJSON(type, "[0, null, null, null]", dict1));
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"),
+                   ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+       values1, values3},
+      DictArrayFromJSON(type, "[3, 0, 0, 2]", dict1));
+  CheckScalarNonRecursive(
+      "case_when",
+      {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"),
+                   ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+       values1, values3},
+      DictArrayFromJSON(type, "[3, 0, 0, 1]", dict1));
+  CheckScalarNonRecursive(
+      "case_when",
+      {
+          MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+          DictScalarFromJSON(type, "0", dict1),
+          DictScalarFromJSON(type, "0", dict2),
+      },
+      DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1));
+  CheckScalarNonRecursive(
+      "case_when",
+      {
+          MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+                      ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+          DictScalarFromJSON(type, "0", dict1),
+          DictScalarFromJSON(type, "0", dict2),
+      },
+      DictArrayFromJSON(type, "[0, 0, 2, 2]", dict1));
+
+  // If we can't map values from a dictionary, then raise an error
+  // Unmappable value is in the else clause

Review comment:
       I had mostly tried to emulate the R/dplyr behavior as closely as 
possible: https://github.com/apache/arrow/pull/10724#discussion_r682676388
   
   But unification is honestly probably easier to implement for us, so I can 
switch to that instead.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to