This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 049a47750 [GLUTEN-4652][VL] Fix min_by/max_by result mismatch (#5544)
049a47750 is described below
commit 049a47750183ddb88f39036181cd0eb77918d5d9
Author: Yan Ma <[email protected]>
AuthorDate: Mon Apr 29 16:09:55 2024 +0800
[GLUTEN-4652][VL] Fix min_by/max_by result mismatch (#5544)
Fix min_by/max_by result mismatch. Taking max_by as an example, we need to keep
an intermediate result row like <null, 11> so that it can be compared with another
result like <5, 8>, ensuring the final result is <null, 11>.
---
.../gluten/utils/VeloxIntermediateData.scala | 8 ++++-
.../execution/VeloxAggregateFunctionsSuite.scala | 18 +++++++++++
.../functions/RegistrationAllFunctions.cc | 13 ++++++--
.../functions/RowConstructorWithAllNull.h | 37 ++++++++++++++++++++++
.../operators/functions/RowFunctionWithNull.h | 21 ++++++++++--
5 files changed, 91 insertions(+), 6 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
index e6a8bf2c8..a00bcae1c 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
@@ -159,7 +159,13 @@ object VeloxIntermediateData {
* row_constructor_with_null.
*/
def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc
match {
- case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] =>
"row_constructor"
+ case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] =>
+ "row_constructor"
+ // For agg functions min_by/max_by, we need to keep rows with a null value
but a non-null
+ // comparison value, such as <null, 5>. So we set the struct to null only when all of
the arguments
+ // are null.
+ case _: MaxMinBy =>
+ "row_constructor_with_all_null"
case _ => "row_constructor_with_null"
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 394c4e016..70fff52b8 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSu
override protected val resourcePath: String = "/tpch-data-parquet-velox"
override protected val fileFormat: String = "parquet"
+ import testImplicits._
+
override def beforeAll(): Unit = {
super.beforeAll()
createTPCHNotNullTables()
@@ -188,6 +190,22 @@ abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSu
}
}
+ test("min_by/max_by") {
+ withTempPath {
+ path =>
+ Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null:
Integer, 5: Integer))
+ .toDF("a", "b")
+ .write
+ .parquet(path.getCanonicalPath)
+ spark.read
+ .parquet(path.getCanonicalPath)
+ .createOrReplaceTempView("test")
+ runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") {
+ checkGlutenOperatorMatch[HashAggregateExecTransformer]
+ }
+ }
+ }
+
test("groupby") {
val df = runQueryAndCompare(
"select l_orderkey, sum(l_partkey) as sum from lineitem " +
diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc
b/cpp/velox/operators/functions/RegistrationAllFunctions.cc
index 2d2e820f1..c77fa47e5 100644
--- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc
+++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc
@@ -16,6 +16,7 @@
*/
#include "operators/functions/RegistrationAllFunctions.h"
#include "operators/functions/Arithmetic.h"
+#include "operators/functions/RowConstructorWithAllNull.h"
#include "operators/functions/RowConstructorWithNull.h"
#include "operators/functions/RowFunctionWithNull.h"
@@ -47,11 +48,19 @@ void registerFunctionOverwrite() {
velox::exec::registerVectorFunction(
"row_constructor_with_null",
std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
- std::make_unique<RowFunctionWithNull>(),
- RowFunctionWithNull::metadata());
+ std::make_unique<RowFunctionWithNull</*allNull=*/false>>(),
+ RowFunctionWithNull</*allNull=*/false>::metadata());
velox::exec::registerFunctionCallToSpecialForm(
RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull,
std::make_unique<RowConstructorWithNullCallToSpecialForm>());
+ velox::exec::registerVectorFunction(
+ "row_constructor_with_all_null",
+ std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
+ std::make_unique<RowFunctionWithNull</*allNull=*/true>>(),
+ RowFunctionWithNull</*allNull=*/true>::metadata());
+ velox::exec::registerFunctionCallToSpecialForm(
+ RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
+ std::make_unique<RowConstructorWithAllNullCallToSpecialForm>());
velox::functions::sparksql::registerBitwiseFunctions("spark_");
}
} // namespace
diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.h
b/cpp/velox/operators/functions/RowConstructorWithAllNull.h
new file mode 100644
index 000000000..dfc79e1a9
--- /dev/null
+++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.h
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "RowConstructorWithNull.h"
+
+namespace gluten {
+class RowConstructorWithAllNullCallToSpecialForm : public
RowConstructorWithNullCallToSpecialForm {
+ public:
+ static constexpr const char* kRowConstructorWithAllNull =
"row_constructor_with_all_null";
+
+ protected:
+ facebook::velox::exec::ExprPtr constructSpecialForm(
+ const std::string& name,
+ const facebook::velox::TypePtr& type,
+ std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
+ bool trackCpuUsage,
+ const facebook::velox::core::QueryConfig& config) {
+ return constructSpecialForm(kRowConstructorWithAllNull, type,
std::move(compiledChildren), trackCpuUsage, config);
+ }
+};
+} // namespace gluten
diff --git a/cpp/velox/operators/functions/RowFunctionWithNull.h
b/cpp/velox/operators/functions/RowFunctionWithNull.h
index 9ed6bc277..4131fb472 100644
--- a/cpp/velox/operators/functions/RowFunctionWithNull.h
+++ b/cpp/velox/operators/functions/RowFunctionWithNull.h
@@ -23,8 +23,10 @@
namespace gluten {
/**
- * A customized RowFunction to set struct as null when one of its argument is
null.
+ * @tparam allNull If true, set struct as null when all of its arguments are null,
else
+ * set it null when any one of its arguments is null.
*/
+template <bool allNull>
class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction
{
public:
void apply(
@@ -42,13 +44,26 @@ class RowFunctionWithNull final : public
facebook::velox::exec::VectorFunction {
rows.applyToSelected([&](facebook::velox::vector_size_t i) {
facebook::velox::bits::clearNull(nullsPtr, i);
if (!facebook::velox::bits::isBitNull(nullsPtr, i)) {
+ int argsNullCnt = 0;
for (size_t c = 0; c < argsCopy.size(); c++) {
auto arg = argsCopy[c].get();
if (arg->mayHaveNulls() && arg->isNullAt(i)) {
- // If any argument of the struct is null, set the struct as null.
+ // For row_constructor_with_null, if any argument of the struct is
null,
+ // set the struct as null.
+ if constexpr (!allNull) {
+ facebook::velox::bits::setNull(nullsPtr, i, true);
+ cntNull++;
+ break;
+ } else {
+ argsNullCnt++;
+ }
+ }
+ }
+ // For row_constructor_with_all_null, set the struct to be null when
all arguments are null
+ if constexpr (allNull) {
+ if (argsNullCnt == argsCopy.size()) {
facebook::velox::bits::setNull(nullsPtr, i, true);
cntNull++;
- break;
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]