This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2519230121 ARROW-16989: [C++] Substrait ProjectRel is interpreted
incorrectly (#13528)
2519230121 is described below
commit 2519230121b9be3ecac01ac3ed2b610382dbca48
Author: Jeroen van Straten <[email protected]>
AuthorDate: Tue Jul 12 09:37:42 2022 +0200
ARROW-16989: [C++] Substrait ProjectRel is interpreted incorrectly (#13528)
A Substrait ProjectRel *appends* columns to the dataset, while Acero's
project node replaces them (emit clauses are instead used to remove or swizzle
columns). This PR prepends field references for all of the input columns to the
project node's expression list to make the two compatible.
I don't think a declaration includes information about the number of
columns it generates, so I had to refactor a little bit to have relation
FromProto return a struct of the declaration and the number of columns, in order
to know how many columns to replicate. I expect this struct to grow in
importance and features when ARROW-16986 is addressed.
Authored-by: Jeroen van Straten <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
.../arrow/engine/substrait/relation_internal.cc | 45 ++++++++++++++--------
cpp/src/arrow/engine/substrait/relation_internal.h | 11 +++++-
cpp/src/arrow/engine/substrait/serde.cc | 10 +++--
3 files changed, 46 insertions(+), 20 deletions(-)
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index dce66eccf8..09ecb2f069 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -52,8 +52,8 @@ Status CheckRelCommon(const RelMessage& rel) {
return Status::OK();
}
-Result<compute::Declaration> FromProto(const substrait::Rel& rel,
- const ExtensionSet& ext_set) {
+Result<DeclarationInfo> FromProto(const substrait::Rel& rel,
+ const ExtensionSet& ext_set) {
static bool dataset_init = false;
if (!dataset_init) {
dataset_init = true;
@@ -180,10 +180,13 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
std::move(filesystem),
std::move(files),
std::move(format), {}));
+ auto num_columns = static_cast<int>(base_schema->fields().size());
ARROW_ASSIGN_OR_RAISE(auto ds,
ds_factory->Finish(std::move(base_schema)));
- return compute::Declaration{
- "scan", dataset::ScanNodeOptions{std::move(ds),
std::move(scan_options)}};
+ return DeclarationInfo{
+ compute::Declaration{
+ "scan", dataset::ScanNodeOptions{std::move(ds),
std::move(scan_options)}},
+ num_columns};
}
case substrait::Rel::RelTypeCase::kFilter: {
@@ -200,10 +203,12 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
}
ARROW_ASSIGN_OR_RAISE(auto condition, FromProto(filter.condition(),
ext_set));
- return compute::Declaration::Sequence({
- std::move(input),
- {"filter", compute::FilterNodeOptions{std::move(condition)}},
- });
+ return DeclarationInfo{
+ compute::Declaration::Sequence({
+ std::move(input.declaration),
+ {"filter", compute::FilterNodeOptions{std::move(condition)}},
+ }),
+ input.num_columns};
}
case substrait::Rel::RelTypeCase::kProject: {
@@ -215,16 +220,25 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
}
ARROW_ASSIGN_OR_RAISE(auto input, FromProto(project.input(), ext_set));
+ // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
+ // them. Therefore, we need to prefix all the current columns for compatibility.
std::vector<compute::Expression> expressions;
+ expressions.reserve(input.num_columns + project.expressions().size());
+ for (int i = 0; i < input.num_columns; i++) {
+ expressions.emplace_back(compute::field_ref(FieldRef(i)));
+ }
for (const auto& expr : project.expressions()) {
expressions.emplace_back();
ARROW_ASSIGN_OR_RAISE(expressions.back(), FromProto(expr, ext_set));
}
- return compute::Declaration::Sequence({
- std::move(input),
- {"project", compute::ProjectNodeOptions{std::move(expressions)}},
- });
+ auto num_columns = static_cast<int>(expressions.size());
+ return DeclarationInfo{
+ compute::Declaration::Sequence({
+ std::move(input.declaration),
+ {"project", compute::ProjectNodeOptions{std::move(expressions)}},
+ }),
+ num_columns};
}
case substrait::Rel::RelTypeCase::kJoin: {
@@ -304,9 +318,10 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
join_options.join_type = join_type;
join_options.key_cmp = {join_key_cmp};
compute::Declaration join_dec{"hashjoin", std::move(join_options)};
- join_dec.inputs.emplace_back(std::move(left));
- join_dec.inputs.emplace_back(std::move(right));
- return std::move(join_dec);
+ auto num_columns = left.num_columns + right.num_columns;
+ join_dec.inputs.emplace_back(std::move(left.declaration));
+ join_dec.inputs.emplace_back(std::move(right.declaration));
+ return DeclarationInfo{std::move(join_dec), num_columns};
}
default:
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h
b/cpp/src/arrow/engine/substrait/relation_internal.h
index ec56a2d359..4a8b6c209c 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.h
+++ b/cpp/src/arrow/engine/substrait/relation_internal.h
@@ -30,8 +30,17 @@
namespace arrow {
namespace engine {
+/// Information resulting from converting a Substrait relation.
+struct DeclarationInfo {
+ /// The compute declaration produced thus far.
+ compute::Declaration declaration;
+
+ /// The number of columns returned by the declaration.
+ int num_columns;
+};
+
ARROW_ENGINE_EXPORT
-Result<compute::Declaration> FromProto(const substrait::Rel&, const
ExtensionSet&);
+Result<DeclarationInfo> FromProto(const substrait::Rel&, const ExtensionSet&);
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/serde.cc
b/cpp/src/arrow/engine/substrait/serde.cc
index dda41c282a..af189da1bb 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -55,7 +55,8 @@ Result<Message> ParseFromBuffer(const Buffer& buf) {
Result<compute::Declaration> DeserializeRelation(const Buffer& buf,
const ExtensionSet& ext_set) {
ARROW_ASSIGN_OR_RAISE(auto rel, ParseFromBuffer<substrait::Rel>(buf));
- return FromProto(rel, ext_set);
+ ARROW_ASSIGN_OR_RAISE(auto decl_info, FromProto(rel, ext_set));
+ return std::move(decl_info.declaration);
}
using DeclarationFactory = std::function<Result<compute::Declaration>(
@@ -121,7 +122,7 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
std::vector<compute::Declaration> sink_decls;
for (const substrait::PlanRel& plan_rel : plan.relations()) {
ARROW_ASSIGN_OR_RAISE(
- auto decl,
+ auto decl_info,
FromProto(plan_rel.has_root() ? plan_rel.root().input() :
plan_rel.rel(),
ext_set));
std::vector<std::string> names;
@@ -130,8 +131,9 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
}
// pipe each relation
- ARROW_ASSIGN_OR_RAISE(auto sink_decl,
- declaration_factory(std::move(decl),
std::move(names)));
+ ARROW_ASSIGN_OR_RAISE(
+ auto sink_decl,
+ declaration_factory(std::move(decl_info.declaration),
std::move(names)));
sink_decls.push_back(std::move(sink_decl));
}