aucahuasi commented on a change in pull request #11159:
URL: https://github.com/apache/arrow/pull/11159#discussion_r711609783
##########
File path: cpp/src/arrow/compute/kernels/scalar_nested_test.cc
##########
@@ -43,6 +43,74 @@ TEST(TestScalarNested, ListValueLength) {
"[3, null, 3, 3]");
}
+TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
+  // list_element on variable-size lists (list and large_list) of every numeric
+  // value type, with every integer index type: element 1 of each row is
+  // selected, and a null element or a null row yields null in the output.
+  auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
+  auto types = NumericTypes();
+  auto index_types = IntTypes();
+  for (const auto& ty : types) {
+    // The expected outputs depend only on the value type (not on the list
+    // type or the index type), so build them once per value type.
+    auto expected = ArrayFromJSON(ty, "[5, null, 12, 9, null]");
+    auto expected_null = ArrayFromJSON(ty, "[null]");
+    for (const auto& list_type : {list(ty), large_list(ty)}) {
+      auto input = ArrayFromJSON(list_type, sample);
+      auto null_input = ArrayFromJSON(list_type, "[null]");
+      for (const auto& index_type : index_types) {
+        auto index = ScalarFromJSON(index_type, "1");
+        CheckScalar("list_element", {input, index}, expected);
+        CheckScalar("list_element", {null_input, index}, expected_null);
+      }
+    }
+  }
+}
+
+TEST(TestScalarNested, ListElementFixedList) {
+  // list_element on fixed_size_list(ty, 3) for every numeric value type and
+  // every integer index type: index 0 selects the first element of each row.
+  auto sample = "[[7, 5, 81], [6, 4, 8], [3, 12, 2], [1, 43, 87]]";
+  auto types = NumericTypes();
+  auto index_types = IntTypes();
+  for (const auto& ty : types) {
+    auto input = ArrayFromJSON(fixed_size_list(ty, 3), sample);
+    // Depends only on the value type, so build it once per value type.
+    auto expected = ArrayFromJSON(ty, "[7, 6, 3, 1]");
+    for (const auto& index_type : index_types) {
+      auto index = ScalarFromJSON(index_type, "0");
+      CheckScalar("list_element", {input, index}, expected);
+    }
+  }
+}
+
+TEST(TestScalarNested, ListElementInvalid) {
+  // Each case calls list_element with both an array input and a scalar input
+  // and expects StatusCode::Invalid. Only the status code is asserted, not the
+  // error message text.
+  auto input_array = ArrayFromJSON(list(float32()), "[[0.1, 1.1], [0.2, 1.2]]");
+  auto input_scalar = ScalarFromJSON(list(float32()), "[0.1, 0.2]");
+
+  // invalid index: null
+  auto index = ScalarFromJSON(int32(), "null");
+  EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+              Raises(StatusCode::Invalid));
+  EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+              Raises(StatusCode::Invalid));
+
+  // invalid index: < 0
+  index = ScalarFromJSON(int32(), "-1");
+  EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+              Raises(StatusCode::Invalid));
+  EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+              Raises(StatusCode::Invalid));
+
+  // invalid index: >= list.length (every list in the inputs above has length 2)
+  index = ScalarFromJSON(int32(), "2");
+  EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+              Raises(StatusCode::Invalid));
+  EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+              Raises(StatusCode::Invalid));
+
+  // invalid input: an empty list has no element at index 0
+  input_array = ArrayFromJSON(list(float32()), "[[41, 6, 93], [], [2]]");
+  input_scalar = ScalarFromJSON(list(float32()), "[]");
+  index = ScalarFromJSON(int32(), "0");
+  EXPECT_THAT(CallFunction("list_element", {input_array, index}),
+              Raises(StatusCode::Invalid));
+  EXPECT_THAT(CallFunction("list_element", {input_scalar, index}),
+              Raises(StatusCode::Invalid));
+}
Review comment:
Given that the function is new, the error messages may change in the
near future. I would suggest not being so specific for now and making these
improvements later. Also, some of the existing tests for invalid cases already
use the same approach, so I think we will want a full review of this one and
the others in the future as well. Thanks anyway!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]