This is an automated email from the ASF dual-hosted git repository.

yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 049a47750 [GLUTEN-4652][VL] Fix min_by/max_by result mismatch (#5544)
049a47750 is described below

commit 049a47750183ddb88f39036181cd0eb77918d5d9
Author: Yan Ma <[email protected]>
AuthorDate: Mon Apr 29 16:09:55 2024 +0800

    [GLUTEN-4652][VL] Fix min_by/max_by result mismatch (#5544)
    
    Fix min_by/max_by result mismatch. Taking max_by as an example, an intermediate
    result row such as <null, 11> must be kept so that, when it is compared with
    another intermediate result such as <5, 8>, the final result is <null, 11>.
---
 .../gluten/utils/VeloxIntermediateData.scala       |  8 ++++-
 .../execution/VeloxAggregateFunctionsSuite.scala   | 18 +++++++++++
 .../functions/RegistrationAllFunctions.cc          | 13 ++++++--
 .../functions/RowConstructorWithAllNull.h          | 37 ++++++++++++++++++++++
 .../operators/functions/RowFunctionWithNull.h      | 21 ++++++++++--
 5 files changed, 91 insertions(+), 6 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
index e6a8bf2c8..a00bcae1c 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
@@ -159,7 +159,13 @@ object VeloxIntermediateData {
    * row_constructor_with_null.
    */
   def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match {
-    case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor"
+    case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] =>
+      "row_constructor"
+    // min_by/max_by must keep intermediate rows whose value is null but whose ordering key is
+    // not, such as <null, 5>, so their struct is set to null only when all of its fields are
+    // null.
+    case _: MaxMinBy =>
+      "row_constructor_with_all_null"
     case _ => "row_constructor_with_null"
   }
 
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 394c4e016..70fff52b8 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
   override protected val resourcePath: String = "/tpch-data-parquet-velox"
   override protected val fileFormat: String = "parquet"
 
+  import testImplicits._
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     createTPCHNotNullTables()
@@ -188,6 +190,28 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
     }
   }
 
+  test("min_by/max_by") {
+    withTempPath {
+      path =>
+        // <null, 11> carries the largest b, so max_by(a, b) must be null: intermediate structs
+        // whose value field is null but whose ordering field is not must survive aggregation.
+        Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null: Integer, 5: Integer))
+          .toDF("a", "b")
+          .write
+          .parquet(path.getCanonicalPath)
+        spark.read
+          .parquet(path.getCanonicalPath)
+          .createOrReplaceTempView("test")
+        runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") {
+          checkGlutenOperatorMatch[HashAggregateExecTransformer]
+        }
+        // Also check a non-null winner (max_by is 5 here) so an all-null result cannot pass.
+        runQueryAndCompare("select min_by(a, b), max_by(a, b) from test where b < 11") {
+          checkGlutenOperatorMatch[HashAggregateExecTransformer]
+        }
+    }
+  }
+
   test("groupby") {
     val df = runQueryAndCompare(
       "select l_orderkey, sum(l_partkey) as sum from lineitem " +
diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc 
b/cpp/velox/operators/functions/RegistrationAllFunctions.cc
index 2d2e820f1..c77fa47e5 100644
--- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc
+++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc
@@ -16,6 +16,7 @@
  */
 #include "operators/functions/RegistrationAllFunctions.h"
 #include "operators/functions/Arithmetic.h"
+#include "operators/functions/RowConstructorWithAllNull.h"
 #include "operators/functions/RowConstructorWithNull.h"
 #include "operators/functions/RowFunctionWithNull.h"
 
@@ -47,11 +48,22 @@ void registerFunctionOverwrite() {
   velox::exec::registerVectorFunction(
       "row_constructor_with_null",
       std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
-      std::make_unique<RowFunctionWithNull>(),
-      RowFunctionWithNull::metadata());
+      std::make_unique<RowFunctionWithNull</*allNull=*/false>>(),
+      RowFunctionWithNull</*allNull=*/false>::metadata());
   velox::exec::registerFunctionCallToSpecialForm(
       RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull,
       std::make_unique<RowConstructorWithNullCallToSpecialForm>());
+  // Like row_constructor_with_null, but the struct is null only when all of its fields are null
+  // (needed by min_by/max_by intermediate data). The vector function and the special form are
+  // registered under the same constant so the two names cannot drift apart.
+  velox::exec::registerVectorFunction(
+      RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
+      std::vector<std::shared_ptr<velox::exec::FunctionSignature>>{},
+      std::make_unique<RowFunctionWithNull</*allNull=*/true>>(),
+      RowFunctionWithNull</*allNull=*/true>::metadata());
+  velox::exec::registerFunctionCallToSpecialForm(
+      RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull,
+      std::make_unique<RowConstructorWithAllNullCallToSpecialForm>());
   velox::functions::sparksql::registerBitwiseFunctions("spark_");
 }
 } // namespace
diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.h 
b/cpp/velox/operators/functions/RowConstructorWithAllNull.h
new file mode 100644
index 000000000..dfc79e1a9
--- /dev/null
+++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "RowConstructorWithNull.h"
+
+namespace gluten {
+/**
+ * Special form for row_constructor_with_all_null.
+ *
+ * Builds the same struct expression as RowConstructorWithNullCallToSpecialForm, but binds it to
+ * the vector function registered as row_constructor_with_all_null (RowFunctionWithNull<true>),
+ * which sets the struct to null only when all of its arguments are null. min_by/max_by rely on
+ * this to keep intermediate rows such as <null, 11>.
+ */
+class RowConstructorWithAllNullCallToSpecialForm : public RowConstructorWithNullCallToSpecialForm {
+ public:
+  static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null";
+
+  // Overrides the 4-argument virtual of velox::exec::FunctionCallToSpecialForm (the entry point
+  // that compiles this special form) so the expression really binds kRowConstructorWithAllNull
+  // instead of the base class's function name.
+  facebook::velox::exec::ExprPtr constructSpecialForm(
+      const facebook::velox::TypePtr& type,
+      std::vector<facebook::velox::exec::ExprPtr>&& compiledChildren,
+      bool trackCpuUsage,
+      const facebook::velox::core::QueryConfig& config) override {
+    // Call the base class's protected name-taking helper explicitly (qualified), so the call
+    // cannot resolve to an overload declared in this class and recurse.
+    // NOTE(review): assumes RowConstructorWithNullCallToSpecialForm declares that helper.
+    return RowConstructorWithNullCallToSpecialForm::constructSpecialForm(
+        kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config);
+  }
+};
+} // namespace gluten
diff --git a/cpp/velox/operators/functions/RowFunctionWithNull.h 
b/cpp/velox/operators/functions/RowFunctionWithNull.h
index 9ed6bc277..4131fb472 100644
--- a/cpp/velox/operators/functions/RowFunctionWithNull.h
+++ b/cpp/velox/operators/functions/RowFunctionWithNull.h
@@ -23,8 +23,12 @@
 namespace gluten {
 
 /**
- * A customized RowFunction to set struct as null when one of its argument is null.
+ * A customized RowFunction that can set the constructed struct to null based on its arguments.
+ *
+ * @tparam allNull If true, set the struct to null only when all of its arguments are null;
+ * otherwise set it to null when any of its arguments is null.
  */
+template <bool allNull>
 class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
  public:
   void apply(
@@ -42,13 +46,26 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction {
     rows.applyToSelected([&](facebook::velox::vector_size_t i) {
       facebook::velox::bits::clearNull(nullsPtr, i);
       if (!facebook::velox::bits::isBitNull(nullsPtr, i)) {
+        size_t argsNullCnt = 0;
         for (size_t c = 0; c < argsCopy.size(); c++) {
           auto arg = argsCopy[c].get();
           if (arg->mayHaveNulls() && arg->isNullAt(i)) {
-            // If any argument of the struct is null, set the struct as null.
+            // For row_constructor_with_null, if any argument of the struct is null,
+            // set the struct to null.
+            if constexpr (!allNull) {
+              facebook::velox::bits::setNull(nullsPtr, i, true);
+              cntNull++;
+              break;
+            } else {
+              argsNullCnt++;
+            }
+          }
+        }
+        // For row_constructor_with_all_null, set the struct to null only when all arguments are null.
+        if constexpr (allNull) {
+          if (argsNullCnt == argsCopy.size()) {
             facebook::velox::bits::setNull(nullsPtr, i, true);
             cntNull++;
-            break;
           }
         }
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to