lidavidm commented on a change in pull request #12162:
URL: https://github.com/apache/arrow/pull/12162#discussion_r797691872
##########
File path: cpp/src/arrow/compute/kernels/scalar_nested.cc
##########
@@ -428,6 +429,264 @@ const FunctionDoc make_struct_doc{"Wrap Arrays into a
StructArray",
"specified through MakeStructOptions."),
{"*args"},
"MakeStructOptions"};
+template <typename KeyType>
+struct MapLookupFunctor {
+ static Result<int64_t> GetOneMatchingIndex(const Array& keys,
+ const Scalar& query_key_scalar,
+ const bool* from_back) {
+ int64_t match_index = -1;
+ RETURN_NOT_OK(
+ FindMatchingIndices(keys, query_key_scalar, [&](int64_t index) ->
Status {
+ match_index = index;
+ if (*from_back) {
+ return Status::OK();
+ } else {
+ return Status::Cancelled("Found key match for FIRST");
+ }
+ }));
+
+ return match_index;
+ }
+
+ template <typename FoundItem>
+ static Status FindMatchingIndices(const Array& keys, const Scalar&
query_key_scalar,
+ FoundItem callback) {
+ const auto query_key = UnboxScalar<KeyType>::Unbox(query_key_scalar);
+ int64_t index = 0;
+ Status status = VisitArrayValuesInline<KeyType>(
+ *keys.data(),
+ [&](decltype(query_key) key) -> Status {
+ if (key == query_key) {
+ return callback(index++);
+ }
+ ++index;
+ return Status::OK();
+ },
+ [&]() -> Status {
+ ++index;
+ return Status::OK();
+ });
+ if (!status.ok() && !status.IsCancelled()) {
+ return status;
+ }
+ return Status::OK();
+ }
+
+ static Status ExecMapArray(KernelContext* ctx, const ExecBatch& batch,
Datum* out) {
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ const auto& query_key = options.query_key;
+ const auto& occurrence = options.occurrence;
+ const MapArray map_array(batch[0].array());
+
+ std::unique_ptr<ArrayBuilder> builder;
+ if (occurrence == MapLookupOptions::Occurrence::ALL) {
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(),
+ list(map_array.map_type()->item_type()),
&builder));
+ auto list_builder = checked_cast<ListBuilder*>(builder.get());
+ auto value_builder = list_builder->value_builder();
+
+ for (int64_t map_array_idx = 0; map_array_idx < map_array.length();
+ ++map_array_idx) {
+ if (!map_array.IsValid(map_array_idx)) {
+ RETURN_NOT_OK(list_builder->AppendNull());
+ continue;
+ }
+
+ auto map = map_array.value_slice(map_array_idx);
+ auto keys = checked_cast<const StructArray&>(*map).field(0);
+ auto items = checked_cast<const StructArray&>(*map).field(1);
+ bool found_at_least_one_key = false;
+ RETURN_NOT_OK(
+ FindMatchingIndices(*keys, *query_key, [&](int64_t index) ->
Status {
+ if (!found_at_least_one_key)
RETURN_NOT_OK(list_builder->Append(true));
+ found_at_least_one_key = true;
+ RETURN_NOT_OK(value_builder->AppendArraySlice(*items->data(),
index, 1));
+ return Status::OK();
+ }));
+ if (!found_at_least_one_key) {
+ RETURN_NOT_OK(list_builder->AppendNull());
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto result, list_builder->Finish());
+ out->value = result->data();
+ } else { /* occurrence == FIRST || LAST */
+ RETURN_NOT_OK(
+ MakeBuilder(ctx->memory_pool(), map_array.map_type()->item_type(),
&builder));
+ RETURN_NOT_OK(builder->Reserve(batch.length));
+ for (int64_t map_array_idx = 0; map_array_idx < map_array.length();
+ ++map_array_idx) {
+ if (!map_array.IsValid(map_array_idx)) {
+ RETURN_NOT_OK(builder->AppendNull());
+ continue;
+ }
+
+ auto map = map_array.value_slice(map_array_idx);
+ auto keys = checked_cast<const StructArray&>(*map).field(0);
+ auto items = checked_cast<const StructArray&>(*map).field(1);
+ bool from_back = (occurrence == MapLookupOptions::LAST);
+ ARROW_ASSIGN_OR_RAISE(int64_t key_match_idx,
+ GetOneMatchingIndex(*keys, *query_key,
&from_back));
+
+ if (key_match_idx != -1) {
+ RETURN_NOT_OK(builder->AppendArraySlice(*items->data(),
key_match_idx, 1));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+ out->value = result->data();
+ }
+
+ return Status::OK();
+ }
+
+ static Status ExecMapScalar(KernelContext* ctx, const ExecBatch& batch,
Datum* out) {
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ const auto& query_key = options.query_key;
+ const auto& occurrence = options.occurrence;
+
+ std::shared_ptr<DataType> item_type =
+ checked_cast<const MapType&>(*batch[0].type()).item_type();
+ const auto& map_scalar = batch[0].scalar_as<MapScalar>();
+
+ if (ARROW_PREDICT_FALSE(!map_scalar.is_valid)) {
+ if (options.occurrence == MapLookupOptions::Occurrence::ALL) {
+ out->value = MakeNullScalar(list(item_type));
+ } else {
+ out->value = MakeNullScalar(item_type);
+ }
+ return Status::OK();
+ }
+
+ const auto& struct_array = checked_cast<const
StructArray&>(*map_scalar.value);
+ const std::shared_ptr<Array> keys = struct_array.field(0);
+ const std::shared_ptr<Array> items = struct_array.field(1);
+
+ if (occurrence == MapLookupOptions::Occurrence::ALL) {
+ bool found_at_least_one_key = false;
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), items->type(), &builder));
+
+ RETURN_NOT_OK(FindMatchingIndices(*keys, *query_key, [&](int64_t index)
-> Status {
+ found_at_least_one_key = true;
+ RETURN_NOT_OK(builder->AppendArraySlice(*items->data(), index, 1));
+ return Status::OK();
+ }));
+ if (!found_at_least_one_key) {
+ out->value = MakeNullScalar(list(items->type()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+ ARROW_ASSIGN_OR_RAISE(out->value, MakeScalar(list(items->type()),
result));
+ }
+ } else { /* occurrence == FIRST || LAST */
+ bool from_back = (occurrence == MapLookupOptions::LAST);
+
+ ARROW_ASSIGN_OR_RAISE(int64_t key_match_idx,
+ GetOneMatchingIndex(*keys, *query_key,
&from_back));
+ if (key_match_idx != -1) {
+ ARROW_ASSIGN_OR_RAISE(out->value, items->GetScalar(key_match_idx));
+ } else {
+ out->value = MakeNullScalar(items->type());
+ }
+ }
+ return Status::OK();
+ }
+};
+
+Result<ValueDescr> ResolveMapLookupType(KernelContext* ctx,
+ const std::vector<ValueDescr>& descrs)
{
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ std::shared_ptr<DataType> type = descrs.front().type;
+ std::shared_ptr<DataType> item_type = checked_cast<const
MapType&>(*type).item_type();
+ std::shared_ptr<DataType> key_type = checked_cast<const
MapType&>(*type).key_type();
+
+ if (!options.query_key) {
+ return Status::Invalid("map_lookup: query_key can't be empty.");
+ } else if (!options.query_key->is_valid) {
+ return Status::Invalid("map_lookup: query_key can't be null.");
+ } else if (!options.query_key->type ||
!options.query_key->type->Equals(key_type)) {
+ return Status::TypeError(
+ "map_lookup: query_key type and Map key_type don't match. Expected "
+ "type: ",
+ *key_type, ", but got type: ", *options.query_key->type);
+ }
+
+ if (options.occurrence == MapLookupOptions::Occurrence::ALL) {
+ return ValueDescr(list(item_type), descrs.front().shape);
+ } else { /* occurrence == FIRST || LAST */
+ return ValueDescr(item_type, descrs.front().shape);
+ }
+}
+
+struct ResolveMapLookup {
+ KernelContext* ctx;
+ const ExecBatch& batch;
+ Datum* out;
+
+ template <typename KeyType>
+ Status Execute() {
+ if (batch[0].kind() == Datum::SCALAR) {
+ return MapLookupFunctor<KeyType>::ExecMapScalar(ctx, batch, out);
+ }
+ return MapLookupFunctor<KeyType>::ExecMapArray(ctx, batch, out);
+ }
+
+ template <typename KeyType>
+ enable_if_physical_integer<KeyType, Status> Visit(const KeyType& type) {
+ return Execute<KeyType>();
+ }
+
+ template <typename KeyType>
+ enable_if_decimal<KeyType, Status> Visit(const KeyType& type) {
+ return Execute<KeyType>();
+ }
+
+ template <typename KeyType>
+ enable_if_base_binary<KeyType, Status> Visit(const KeyType& type) {
+ return Execute<KeyType>();
+ }
+
+ template <typename KeyType>
+ enable_if_boolean<KeyType, Status> Visit(const KeyType& type) {
+ return Execute<KeyType>();
+ }
+
+ template <typename KeyType>
+ enable_if_same<KeyType, FixedSizeBinaryType, Status> Visit(const KeyType&
key) {
+ return Execute<KeyType>();
+ }
Review comment:
Does this work if it's just a normal (non-template) overload, i.e.
`Status Visit(const FixedSizeBinaryType& key)`?
##########
File path: cpp/src/arrow/compute/kernels/scalar_nested.cc
##########
@@ -428,6 +429,264 @@ const FunctionDoc make_struct_doc{"Wrap Arrays into a
StructArray",
"specified through MakeStructOptions."),
{"*args"},
"MakeStructOptions"};
+template <typename KeyType>
+struct MapLookupFunctor {
+ static Result<int64_t> GetOneMatchingIndex(const Array& keys,
+ const Scalar& query_key_scalar,
+ const bool* from_back) {
+ int64_t match_index = -1;
+ RETURN_NOT_OK(
+ FindMatchingIndices(keys, query_key_scalar, [&](int64_t index) ->
Status {
+ match_index = index;
+ if (*from_back) {
+ return Status::OK();
+ } else {
+ return Status::Cancelled("Found key match for FIRST");
+ }
+ }));
+
+ return match_index;
+ }
+
+ template <typename FoundItem>
+ static Status FindMatchingIndices(const Array& keys, const Scalar&
query_key_scalar,
+ FoundItem callback) {
+ const auto query_key = UnboxScalar<KeyType>::Unbox(query_key_scalar);
+ int64_t index = 0;
+ Status status = VisitArrayValuesInline<KeyType>(
+ *keys.data(),
+ [&](decltype(query_key) key) -> Status {
+ if (key == query_key) {
+ return callback(index++);
+ }
+ ++index;
+ return Status::OK();
+ },
+ [&]() -> Status {
+ ++index;
+ return Status::OK();
+ });
+ if (!status.ok() && !status.IsCancelled()) {
+ return status;
+ }
+ return Status::OK();
+ }
+
+ static Status ExecMapArray(KernelContext* ctx, const ExecBatch& batch,
Datum* out) {
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ const auto& query_key = options.query_key;
+ const auto& occurrence = options.occurrence;
+ const MapArray map_array(batch[0].array());
+
+ std::unique_ptr<ArrayBuilder> builder;
+ if (occurrence == MapLookupOptions::Occurrence::ALL) {
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(),
+ list(map_array.map_type()->item_type()),
&builder));
+ auto list_builder = checked_cast<ListBuilder*>(builder.get());
+ auto value_builder = list_builder->value_builder();
+
+ for (int64_t map_array_idx = 0; map_array_idx < map_array.length();
+ ++map_array_idx) {
+ if (!map_array.IsValid(map_array_idx)) {
+ RETURN_NOT_OK(list_builder->AppendNull());
+ continue;
+ }
+
+ auto map = map_array.value_slice(map_array_idx);
+ auto keys = checked_cast<const StructArray&>(*map).field(0);
+ auto items = checked_cast<const StructArray&>(*map).field(1);
+ bool found_at_least_one_key = false;
+ RETURN_NOT_OK(
+ FindMatchingIndices(*keys, *query_key, [&](int64_t index) ->
Status {
+ if (!found_at_least_one_key)
RETURN_NOT_OK(list_builder->Append(true));
+ found_at_least_one_key = true;
+ RETURN_NOT_OK(value_builder->AppendArraySlice(*items->data(),
index, 1));
+ return Status::OK();
+ }));
+ if (!found_at_least_one_key) {
+ RETURN_NOT_OK(list_builder->AppendNull());
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto result, list_builder->Finish());
+ out->value = result->data();
+ } else { /* occurrence == FIRST || LAST */
+ RETURN_NOT_OK(
+ MakeBuilder(ctx->memory_pool(), map_array.map_type()->item_type(),
&builder));
+ RETURN_NOT_OK(builder->Reserve(batch.length));
+ for (int64_t map_array_idx = 0; map_array_idx < map_array.length();
+ ++map_array_idx) {
+ if (!map_array.IsValid(map_array_idx)) {
+ RETURN_NOT_OK(builder->AppendNull());
+ continue;
+ }
+
+ auto map = map_array.value_slice(map_array_idx);
+ auto keys = checked_cast<const StructArray&>(*map).field(0);
+ auto items = checked_cast<const StructArray&>(*map).field(1);
+ bool from_back = (occurrence == MapLookupOptions::LAST);
+ ARROW_ASSIGN_OR_RAISE(int64_t key_match_idx,
+ GetOneMatchingIndex(*keys, *query_key,
&from_back));
+
+ if (key_match_idx != -1) {
+ RETURN_NOT_OK(builder->AppendArraySlice(*items->data(),
key_match_idx, 1));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+ out->value = result->data();
+ }
+
+ return Status::OK();
+ }
+
+ static Status ExecMapScalar(KernelContext* ctx, const ExecBatch& batch,
Datum* out) {
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ const auto& query_key = options.query_key;
+ const auto& occurrence = options.occurrence;
+
+ std::shared_ptr<DataType> item_type =
+ checked_cast<const MapType&>(*batch[0].type()).item_type();
+ const auto& map_scalar = batch[0].scalar_as<MapScalar>();
+
+ if (ARROW_PREDICT_FALSE(!map_scalar.is_valid)) {
+ if (options.occurrence == MapLookupOptions::Occurrence::ALL) {
+ out->value = MakeNullScalar(list(item_type));
+ } else {
+ out->value = MakeNullScalar(item_type);
+ }
+ return Status::OK();
+ }
+
+ const auto& struct_array = checked_cast<const
StructArray&>(*map_scalar.value);
+ const std::shared_ptr<Array> keys = struct_array.field(0);
+ const std::shared_ptr<Array> items = struct_array.field(1);
+
+ if (occurrence == MapLookupOptions::Occurrence::ALL) {
+ bool found_at_least_one_key = false;
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), items->type(), &builder));
+
+ RETURN_NOT_OK(FindMatchingIndices(*keys, *query_key, [&](int64_t index)
-> Status {
+ found_at_least_one_key = true;
+ RETURN_NOT_OK(builder->AppendArraySlice(*items->data(), index, 1));
+ return Status::OK();
+ }));
+ if (!found_at_least_one_key) {
+ out->value = MakeNullScalar(list(items->type()));
+ } else {
+ ARROW_ASSIGN_OR_RAISE(auto result, builder->Finish());
+ ARROW_ASSIGN_OR_RAISE(out->value, MakeScalar(list(items->type()),
result));
+ }
+ } else { /* occurrence == FIRST || LAST */
+ bool from_back = (occurrence == MapLookupOptions::LAST);
+
+ ARROW_ASSIGN_OR_RAISE(int64_t key_match_idx,
+ GetOneMatchingIndex(*keys, *query_key,
&from_back));
+ if (key_match_idx != -1) {
+ ARROW_ASSIGN_OR_RAISE(out->value, items->GetScalar(key_match_idx));
+ } else {
+ out->value = MakeNullScalar(items->type());
+ }
+ }
+ return Status::OK();
+ }
+};
+
+Result<ValueDescr> ResolveMapLookupType(KernelContext* ctx,
+ const std::vector<ValueDescr>& descrs)
{
+ const auto& options = OptionsWrapper<MapLookupOptions>::Get(ctx);
+ std::shared_ptr<DataType> type = descrs.front().type;
+ std::shared_ptr<DataType> item_type = checked_cast<const
MapType&>(*type).item_type();
+ std::shared_ptr<DataType> key_type = checked_cast<const
MapType&>(*type).key_type();
+
+ if (!options.query_key) {
+ return Status::Invalid("map_lookup: query_key can't be empty.");
+ } else if (!options.query_key->is_valid) {
+ return Status::Invalid("map_lookup: query_key can't be null.");
+ } else if (!options.query_key->type ||
!options.query_key->type->Equals(key_type)) {
Review comment:
Hmm, `query_key->type` should never be nullptr, so the check is
redundant. (And if the check weren't redundant, we'd have undefined behavior
below on line 611, where the error message dereferences the pointer via
`*options.query_key->type`.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]