dhruv9vats commented on a change in pull request #12162: URL: https://github.com/apache/arrow/pull/12162#discussion_r792729511
########## File path: cpp/src/arrow/compute/kernels/scalar_nested.cc ########## @@ -428,6 +428,290 @@ const FunctionDoc make_struct_doc{"Wrap Arrays into a StructArray", "specified through MakeStructOptions."), {"*args"}, "MakeStructOptions"}; +template <typename KeyType> +struct MapArrayLookupFunctor { + static Result<int64_t> FindOneMapValueIndex(const Array& keys, + const Scalar& query_key_scalar, + const int64_t start, const int64_t end, + const bool from_back = false) { + const auto query_key = UnboxScalar<KeyType>::Unbox(query_key_scalar); + int64_t index = 0; + int64_t match_idx = -1; + ARROW_UNUSED(VisitArrayValuesInline<KeyType>( + *keys.data(), + [&](decltype(query_key) key) -> Status { + if (index < start) { + ++index; + return Status::OK(); + } else if (index < end) { + if (key == query_key) { Review comment: This index check makes sure the index stays within [start, end), because the `map_array.offsets()` approach is used. Instead, we could use something like what you previously mentioned (inside `ExecMapArray`): ```cpp auto map = map_array.value_slice(map_array_idx); auto keys = checked_cast<const StructArray&>(*map).field(0); auto items = checked_cast<const StructArray&>(*map).field(1); ``` Does the latter incur some performance penalty? Which approach should be used? 
(Ditto below) ########## File path: cpp/src/arrow/compute/kernels/scalar_nested.cc ########## @@ -428,6 +428,290 @@ const FunctionDoc make_struct_doc{"Wrap Arrays into a StructArray", "specified through MakeStructOptions."), {"*args"}, "MakeStructOptions"}; +template <typename KeyType> +struct MapArrayLookupFunctor { + static Result<int64_t> FindOneMapValueIndex(const Array& keys, + const Scalar& query_key_scalar, + const int64_t start, const int64_t end, + const bool from_back = false) { + const auto query_key = UnboxScalar<KeyType>::Unbox(query_key_scalar); + int64_t index = 0; + int64_t match_idx = -1; + ARROW_UNUSED(VisitArrayValuesInline<KeyType>( + *keys.data(), + [&](decltype(query_key) key) -> Status { + if (index < start) { + ++index; + return Status::OK(); + } else if (index < end) { + if (key == query_key) { + if (!from_back) { + match_idx = index; + return Status::Cancelled("Found first matching key"); + } else { + match_idx = index; + } + } + ++index; + return Status::OK(); + } else { + return Status::Cancelled("End reached"); + } + }, + [&]() -> Status { + if (index < end) { + ++index; + return Status::OK(); + } else { + return Status::Cancelled("End reached"); + } + })); + + return match_idx; + } + + static Result<std::unique_ptr<ArrayBuilder>> GetBuiltArray( + const Array& keys, const Array& items, const Scalar& query_key_scalar, + bool& found_at_least_one_key, const int64_t& start, const int64_t& end, + KernelContext* ctx) { + std::unique_ptr<ArrayBuilder> builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), items.type(), &builder)); + const auto query_key = UnboxScalar<KeyType>::Unbox(query_key_scalar); + int64_t index = 0; + ARROW_UNUSED(VisitArrayValuesInline<KeyType>( + *keys.data(), + [&](decltype(query_key) key) -> Status { + if (index < start) { + ++index; + return Status::OK(); + } else if (index < end) { + if (key == query_key) { + found_at_least_one_key = true; + RETURN_NOT_OK(builder->AppendArraySlice(*items.data(), index, 1)); + } + 
++index; + return Status::OK(); + } else { + return Status::Cancelled("End reached"); + } + }, + [&]() -> Status { + if (index < end) { + ++index; + return Status::OK(); + } else { + return Status::Cancelled("End reached"); + } + })); + + return std::move(builder); + } + + static Status ExecMapArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = OptionsWrapper<MapArrayLookupOptions>::Get(ctx); + const auto& query_key = options.query_key; + const auto& occurrence = options.occurrence; + const MapArray map_array(batch[0].array()); + + std::shared_ptr<arrow::Array> keys = map_array.keys(); + std::shared_ptr<arrow::Array> items = map_array.items(); + auto offsets = std::dynamic_pointer_cast<Int32Array>(map_array.offsets()); + + std::unique_ptr<ArrayBuilder> builder; + if (occurrence == MapArrayLookupOptions::Occurrence::ALL) { + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), + list(map_array.map_type()->item_type()), &builder)); + + for (int64_t map_array_idx = 0; map_array_idx < map_array.length(); + ++map_array_idx) { + if (!map_array.IsValid(map_array_idx)) { + RETURN_NOT_OK(builder->AppendNull()); + continue; + } + + int64_t start = offsets->Value(map_array_idx); + int64_t end = offsets->Value(map_array_idx + 1); + std::unique_ptr<ArrayBuilder> list_builder; + bool found_at_least_one_key = false; + ARROW_ASSIGN_OR_RAISE( + list_builder, GetBuiltArray(*keys, *items, *query_key, found_at_least_one_key, + start, end, ctx)); + if (!found_at_least_one_key) { + RETURN_NOT_OK(builder->AppendNull()); + } else { + ARROW_ASSIGN_OR_RAISE(auto list_result, list_builder->Finish()); + RETURN_NOT_OK(builder->AppendScalar(ListScalar(list_result))); + } + list_builder->Reset(); + } + ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); + out->value = result->data(); + } else { + RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), map_array.map_type()->item_type(), &builder)); + + for (int64_t map_array_idx = 0; map_array_idx < map_array.length(); + 
++map_array_idx) { + if (!map_array.IsValid(map_array_idx)) { + RETURN_NOT_OK(builder->AppendNull()); + continue; + } + int64_t start = offsets->Value(map_array_idx); + int64_t end = offsets->Value(map_array_idx + 1); + bool from_back = (occurrence == MapArrayLookupOptions::LAST); + + ARROW_ASSIGN_OR_RAISE( + int64_t key_match_idx, + FindOneMapValueIndex(*keys, *query_key, start, end, from_back)); + if (key_match_idx != -1) { + RETURN_NOT_OK(builder->AppendArraySlice(*items->data(), key_match_idx, 1)); + } else { + RETURN_NOT_OK(builder->AppendNull()); + } + } + ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish()); + out->value = result->data(); + } + + return Status::OK(); + } + + static Status ExecMapScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = OptionsWrapper<MapArrayLookupOptions>::Get(ctx); + const auto& query_key = options.query_key; + const auto& occurrence = options.occurrence; + + std::shared_ptr<DataType> item_type = + checked_cast<const MapType&>(*batch[0].type()).item_type(); + const auto& map_scalar = batch[0].scalar_as<MapScalar>(); + + if (ARROW_PREDICT_FALSE(!map_scalar.is_valid)) { + if (options.occurrence == MapArrayLookupOptions::Occurrence::ALL) { + out->value = MakeNullScalar(list(item_type)); + } else { + out->value = MakeNullScalar(item_type); + } + return Status::OK(); + } + + const auto& struct_array = checked_cast<const StructArray&>(*map_scalar.value); + const std::shared_ptr<Array> keys = struct_array.field(0); + const std::shared_ptr<Array> items = struct_array.field(1); + + if (occurrence == MapArrayLookupOptions::Occurrence::ALL) { + bool found_at_least_one_key = false; + std::unique_ptr<ArrayBuilder> builder; + ARROW_ASSIGN_OR_RAISE( + builder, GetBuiltArray(*keys, *items, *query_key, found_at_least_one_key, 0, + struct_array.length(), ctx)); + + if (!found_at_least_one_key) { + out->value = MakeNullScalar(list(items->type())); + } else { + ARROW_ASSIGN_OR_RAISE(auto result, 
builder->Finish()); + ARROW_ASSIGN_OR_RAISE(out->value, MakeScalar(list(items->type()), result)); + } + } else { /* occurrence == FIRST || LAST */ + bool from_back = (occurrence == MapArrayLookupOptions::LAST); + + ARROW_ASSIGN_OR_RAISE( + int64_t key_match_idx, + FindOneMapValueIndex(*keys, *query_key, 0, struct_array.length(), from_back)); + if (key_match_idx != -1) { + ARROW_ASSIGN_OR_RAISE(out->value, items->GetScalar(key_match_idx)); + } else { + out->value = MakeNullScalar(items->type()); + } + } + return Status::OK(); + } +}; + +Result<ValueDescr> ResolveMapArrayLookupType(KernelContext* ctx, + const std::vector<ValueDescr>& descrs) { + const auto& options = OptionsWrapper<MapArrayLookupOptions>::Get(ctx); + std::shared_ptr<DataType> type = descrs.front().type; + std::shared_ptr<DataType> item_type = checked_cast<const MapType&>(*type).item_type(); + std::shared_ptr<DataType> key_type = checked_cast<const MapType&>(*type).key_type(); + + if (!options.query_key || !options.query_key->type || + !options.query_key->type->Equals(key_type)) { + return Status::TypeError( + "map_array_lookup: query_key type and MapArray key_type don't match. 
Expected " + "type: ", + *item_type, ", but got type: ", *options.query_key->type); + } + + if (options.occurrence == MapArrayLookupOptions::Occurrence::ALL) { + return ValueDescr(list(item_type), descrs.front().shape); + } else { /* occurrence == FIRST || LAST */ + return ValueDescr(item_type, descrs.front().shape); + } +} + +struct ResolveMapArrayLookup { + KernelContext* ctx; + const ExecBatch& batch; + Datum* out; + + template <typename KeyType> + Status Execute() { + if (batch[0].kind() == Datum::SCALAR) { + return MapArrayLookupFunctor<KeyType>::ExecMapScalar(ctx, batch, out); + } + return MapArrayLookupFunctor<KeyType>::ExecMapArray(ctx, batch, out); + } + + template <typename KeyType> + enable_if_physical_integer<KeyType, Status> Visit(const KeyType& type) { + return Execute<KeyType>(); + } + + template <typename KeyType> + enable_if_decimal<KeyType, Status> Visit(const KeyType& type) { + return Execute<KeyType>(); + } + + template <typename KeyType> + enable_if_base_binary<KeyType, Status> Visit(const KeyType& type) { + return Execute<KeyType>(); + } + + template <typename KeyType> + enable_if_boolean<KeyType, Status> Visit(const KeyType& type) { + return Execute<KeyType>(); + } + Review comment: Current problems: - `enable_if_fixed_size_binary` conflicts with `enable_if_decimal` - MonthDayNanoInterval ########## File path: cpp/src/arrow/compute/kernels/scalar_nested_test.cc ########## @@ -225,6 +225,294 @@ TEST(TestScalarNested, StructField) { } } +void CheckMapArrayLookupWithDifferentOptions( + const std::shared_ptr<Array>& map, const std::shared_ptr<Scalar>& query_key, + const std::shared_ptr<Array>& expected_all, + const std::shared_ptr<Array>& expected_first, + const std::shared_ptr<Array>& expected_last) { + MapArrayLookupOptions all_matches(query_key, MapArrayLookupOptions::ALL); + MapArrayLookupOptions first_matches(query_key, MapArrayLookupOptions::FIRST); + MapArrayLookupOptions last_matches(query_key, MapArrayLookupOptions::LAST); + + 
CheckScalar("map_array_lookup", {map}, expected_all, &all_matches); + CheckScalar("map_array_lookup", {map}, expected_first, &first_matches); + CheckScalar("map_array_lookup", {map}, expected_last, &last_matches); +} + +class TestMapArrayLookupKernel : public ::testing::Test {}; + +TEST_F(TestMapArrayLookupKernel, Basic) { + auto type = map(utf8(), int32()); + const char* input = R"( + [ + [["foo", 99], ["bar", 1], ["hello", 2], ["foo", 3], ["lets go", 5], ["what now?", 8]], + null, + [["nothing", null], ["hat", null], ["foo", 101], ["sorry", 1], ["dip", null], + ["foo", 22]], + [] + ])"; + auto map_array = ArrayFromJSON(type, input); + CheckMapArrayLookupWithDifferentOptions( + map_array, MakeScalar("foo"), + ArrayFromJSON(list(int32()), R"([[99, 3], null, [101, 22], null])"), + ArrayFromJSON(int32(), R"([99, null, 101, null])"), + ArrayFromJSON(int32(), R"([3, null, 22, null])")); +} + +TEST_F(TestMapArrayLookupKernel, NestedItems) { Review comment: Will remove these two at the end? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org