This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ff66adda3ca [SPARK-40107][SQL][FOLLOW-UP] Update `empty2null` check
ff66adda3ca is described below

commit ff66adda3ca3762b8c71b14acbd6da00c5508a2e
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Tue Sep 13 12:08:17 2022 +0800

    [SPARK-40107][SQL][FOLLOW-UP] Update `empty2null` check
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up to SPARK-40107. It updates the way we check for the
    `empty2null` expression in a V1 write query plan. Previously, we only
    searched for this expression in Project nodes, but the optimizer can change
    its position, for example by collapsing projects into aggregates. As a
    result, we need to search the entire plan to see if `empty2null` has been
    added by `V1Writes`.
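    
    As an illustration, here is a minimal, self-contained sketch of the failure
    mode. The `Node`, `Project`, and `Aggregate` classes below are toy stand-ins
    for Spark's real `LogicalPlan` API, and `hasEmptyToNull` is simplified to a
    string match:
    
    ```scala
    // Toy plan tree (hypothetical; not Spark's TreeNode/LogicalPlan API).
    sealed trait Node {
      def children: Seq[Node]
      def expressions: Seq[String]
      def exists(f: Node => Boolean): Boolean = f(this) || children.exists(_.exists(f))
    }
    case class Project(expressions: Seq[String], child: Node) extends Node {
      def children: Seq[Node] = Seq(child)
    }
    case class Aggregate(expressions: Seq[String], child: Node) extends Node {
      def children: Seq[Node] = Seq(child)
    }
    case object Relation extends Node {
      def children: Seq[Node] = Nil
      def expressions: Seq[String] = Nil
    }
    
    // Simplified stand-in for V1WritesUtils.hasEmptyToNull.
    def hasEmptyToNull(exprs: Seq[String]): Boolean = exprs.exists(_.contains("empty2null"))
    
    // V1Writes wraps the query in Project([..., empty2null(k)], Aggregate(...)).
    // CollapseProject can then fold the projection into the aggregate:
    val optimized: Node = Aggregate(Seq("sum(i) AS i", "empty2null(k) AS k"), Relation)
    
    // Old check: only inspects Project nodes, so it misses the expression.
    val oldCheck = optimized.exists {
      case p: Project => hasEmptyToNull(p.expressions)
      case _ => false
    }
    // New check: inspects every node's expressions, so it still finds it.
    val newCheck = optimized.exists(n => hasEmptyToNull(n.expressions))
    assert(!oldCheck && newCheck)
    ```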
    
    ### Why are the changes needed?
    
    To prevent unnecessary `empty2null` projections from being added in
    `FileFormatWriter`.
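    
    Schematically (hand-written plan shapes, not real `explain` output; the
    physical aggregate node name is assumed here), missing the existing
    expression would make `FileFormatWriter` stack a second conversion on top
    of the one `V1Writes` already planned:
    
    ```
    == check misses empty2null ==
    ProjectExec [i, j, empty2null(k) AS k]            <- added again by FileFormatWriter
    +- HashAggregateExec [sum(i), sum(j), empty2null(k) AS k]
    
    == check finds empty2null (this PR) ==
    HashAggregateExec [sum(i), sum(j), empty2null(k) AS k]   <- no extra projection
    ```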
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests.
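    
    For example, the new test added to `V1WriteCommandSuite` (see the diff
    below) covers a partitioned CTAS where the `empty2null` expression ends up
    inside an aggregate after optimization:
    
    ```sql
    CREATE TABLE t USING PARQUET
    PARTITIONED BY (k) AS
    SELECT SUM(i) AS i, SUM(j) AS j, k
    FROM t0 WHERE i > 0 GROUP BY k
    ```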
    
    Closes #37856 from allisonwang-db/spark-40107-followup.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../execution/datasources/FileFormatWriter.scala   | 14 ++------
 .../spark/sql/execution/datasources/V1Writes.scala | 35 +++++++++++---------
 .../datasources/V1WriteCommandSuite.scala          | 38 +++++++++++++++-------
 .../command/V1WriteHiveCommandSuite.scala          |  4 ++-
 4 files changed, 50 insertions(+), 41 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 794d90b242c..12562014c39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -103,10 +103,7 @@ object FileFormatWriter extends Logging {
       .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation))
     val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains)
 
-    val hasEmpty2Null = plan.find {
-      case p: ProjectExec => V1WritesUtils.hasEmptyToNull(p.projectList)
-      case _ => false
-    }.isDefined
+    val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
     val empty2NullPlan = if (hasEmpty2Null) {
       plan
     } else {
@@ -150,14 +147,7 @@ object FileFormatWriter extends Logging {
     // the sort order doesn't matter
     // Use the output ordering from the original plan before adding the empty2null projection.
     val actualOrdering = plan.outputOrdering.map(_.child)
-    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
-      false
-    } else {
-      requiredOrdering.zip(actualOrdering).forall {
-        case (requiredOrder, childOutputOrder) =>
-          requiredOrder.semanticEquals(childOutputOrder)
-      }
-    }
+    val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)
 
     SQLExecution.checkSQLExecutionId(sparkSession)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index d3cac32ae66..d082b95739c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -47,6 +47,9 @@ trait V1WriteCommand extends DataWritingCommand {
  * A rule that adds logical sorts to V1 data writing commands.
  */
 object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
+
+  import V1WritesUtils._
+
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (conf.plannedWriteEnabled) {
       plan.transformDown {
@@ -65,10 +68,11 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
   }
 
   private def prepareQuery(write: V1WriteCommand, query: LogicalPlan): LogicalPlan = {
-    val empty2NullPlan = if (hasEmptyToNull(query)) {
+    val hasEmpty2Null = query.exists(p => hasEmptyToNull(p.expressions))
+    val empty2NullPlan = if (hasEmpty2Null) {
       query
     } else {
-      val projectList = V1WritesUtils.convertEmptyToNull(query.output, write.partitionColumns)
+      val projectList = convertEmptyToNull(query.output, write.partitionColumns)
       if (projectList.isEmpty) query else Project(projectList, query)
     }
     assert(empty2NullPlan.output.length == query.output.length)
@@ -80,26 +84,13 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
     }.asInstanceOf[SortOrder])
     val outputOrdering = query.outputOrdering
     // Check if the ordering is already matched to ensure the idempotency of the rule.
-    val orderingMatched = if (requiredOrdering.length > outputOrdering.length) {
-      false
-    } else {
-      requiredOrdering.zip(outputOrdering).forall {
-        case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
-      }
-    }
+    val orderingMatched = isOrderingMatched(requiredOrdering, outputOrdering)
     if (orderingMatched) {
       empty2NullPlan
     } else {
       Sort(requiredOrdering, global = false, empty2NullPlan)
     }
   }
-
-  private def hasEmptyToNull(plan: LogicalPlan): Boolean = {
-    plan.find {
-      case p: Project => V1WritesUtils.hasEmptyToNull(p.projectList)
-      case _ => false
-    }.isDefined
-  }
 }
 
 object V1WritesUtils {
@@ -209,4 +200,16 @@ object V1WritesUtils {
   def hasEmptyToNull(expressions: Seq[Expression]): Boolean = {
     expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
   }
+
+  def isOrderingMatched(
+      requiredOrdering: Seq[Expression],
+      outputOrdering: Seq[Expression]): Boolean = {
+    if (requiredOrdering.length > outputOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(outputOrdering).forall {
+        case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index c18396b554d..d66f2bd0cc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -18,19 +18,19 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.util.QueryExecutionListener
 
-abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
+trait V1WriteCommandSuiteBase extends SQLTestUtils {
 
   import testImplicits._
 
   setupTestData()
 
-  protected override def beforeAll(): Unit = {
+  override def beforeAll(): Unit = {
     super.beforeAll()
     (0 to 20).map(i => (i, i % 5, (i % 10).toString))
       .toDF("i", "j", "k")
@@ -38,12 +38,12 @@ abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
       .saveAsTable("t0")
   }
 
-  protected override def afterAll(): Unit = {
+  override def afterAll(): Unit = {
     sql("drop table if exists t0")
     super.afterAll()
   }
 
-  protected def withPlannedWrite(testFunc: Boolean => Any): Unit = {
+  def withPlannedWrite(testFunc: Boolean => Any): Unit = {
     Seq(true, false).foreach { enabled =>
       withSQLConf(SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) {
         testFunc(enabled)
@@ -87,19 +87,16 @@ abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
         s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
 
       // Check empty2null conversion.
-      val projection = optimizedPlan.collectFirst {
-        case p: Project
-          if p.projectList.exists(_.exists(_.isInstanceOf[V1WritesUtils.Empty2Null])) => p
-      }
-      assert(projection.isDefined == hasEmpty2Null,
-        s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: ${projection.isDefined}")
+      val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
+      assert(empty2nullExpr == hasEmpty2Null,
+        s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
     }
 
     spark.listenerManager.unregister(listener)
   }
 }
 
-class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSession {
+class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1WriteCommandSuiteBase {
 
   import testImplicits._
 
@@ -277,4 +274,21 @@ class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSessio
       }
     }
   }
+
+  test("v1 write with empty2null in aggregate") {
+    withPlannedWrite { enabled =>
+      withTable("t") {
+        executeAndCheckOrdering(
+          hasLogicalSort = enabled, orderingMatched = enabled, hasEmpty2Null = enabled) {
+          sql(
+            """
+              |CREATE TABLE t USING PARQUET
+              |PARTITIONED BY (k) AS
+              |SELECT SUM(i) AS i, SUM(j) AS j, k
+              |FROM t0 WHERE i > 0 GROUP BY k
+              |""".stripMargin)
+        }
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
index 364b7971730..0f219032fc0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.hive.execution.command
 
+import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 
-class V1WriteHiveCommandSuite extends V1WriteCommandSuiteBase with TestHiveSingleton {
+class V1WriteHiveCommandSuite
+    extends QueryTest with TestHiveSingleton with V1WriteCommandSuiteBase  {
 
   test("create hive table as select - no partition column") {
     withPlannedWrite { enabled =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
