lidavidm commented on a change in pull request #11022:
URL: https://github.com/apache/arrow/pull/11022#discussion_r703002476
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -624,6 +624,250 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
}
+template <typename Type>
+class TestCaseWhenDict : public ::testing::Test {};
+
+struct JsonDict {
+ std::shared_ptr<DataType> type;
+ std::string value;
+};
+
+TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenDict, Simple) {
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ for (const auto& dict :
+ {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+ JsonDict{int64(), "[1, null, 2, 3]"},
+ JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+ auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]",
dict.value);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+
+ // Easy case: all arguments have the same dictionary
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ DictArrayFromJSON(type, "[0, null, null, null]", dict.value));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2,
values1},
+ DictArrayFromJSON(type, "[0, null, null, 1]", dict.value));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null,
values2, values1},
+ DictArrayFromJSON(type, "[null, null, null, 1]", dict.value));
+ }
+}
+
+TYPED_TEST(TestCaseWhenDict, Mixed) {
+ auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+ auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+ auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+
+ // If we have mixed dictionary/non-dictionary arguments, we decode
dictionaries
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_dict,
values2_decoded},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1_decoded,
values2_dict},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when",
+ {MakeStruct({cond1, cond2}), values1_dict, values2_dict,
values1_decoded},
+ ArrayFromJSON(utf8(), R"(["a", null, null, null])"));
+ CheckScalar("case_when",
+ {MakeStruct({cond1, cond2}), values_null, values2_dict,
values1_decoded},
+ ArrayFromJSON(utf8(), R"([null, null, null, null])"));
+}
+
+TYPED_TEST(TestCaseWhenDict, NestedSimple) {
+ auto make_list = [](const std::shared_ptr<Array>& indices,
+ const std::shared_ptr<Array>& backing_array) {
+ EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices,
*backing_array));
+ return result;
+ };
+ auto index_type = default_type_instance<TypeParam>();
+ auto inner_type = dictionary(index_type, utf8());
+ auto type = list(inner_type);
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", "b", "bc", "def"])";
+ auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null,
null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
+ auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]",
dict);
+ auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]",
dict);
+ auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"),
values1_backing);
+ auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"),
values2_backing);
+
+ CheckScalarNonRecursive(
Review comment:
The scalar variant of the kernel will not produce the same dictionary
indices, so the values do not compare equal (hence the use of
CheckScalarNonRecursive here). I'll add a comment to that effect.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]