This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0a94038dde79 [SPARK-49352][SQL][3.4] Avoid redundant array transform for identical expression
0a94038dde79 is described below

commit 0a94038dde79bb574a9376965cac8f8a4f229ccf
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Fri Aug 23 22:28:03 2024 -0700

    [SPARK-49352][SQL][3.4] Avoid redundant array transform for identical expression
    
    ### What changes were proposed in this pull request?
    
    This patch avoids `ArrayTransform` in the `resolveArrayType` function if the resolved expression is the same as the input parameter.
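
    For quick reference, the new branch boils down to the following sketch (condensed from the `TableOutputResolver` diff below; `res.head` is the resolved element expression and `param` the lambda input variable, so this is an excerpt rather than self-contained code):

    ```scala
    if (res.head == param) {
      // Resolution left the element untouched, so the lambda would be an
      // identity function: alias the input array directly, dropping only the
      // non-inheritable char/varchar metadata.
      Some(Alias(input, expectedName)(
        nonInheritableMetadataKeys =
          Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
    } else {
      // A real per-element rewrite is needed: keep the ArrayTransform.
      val func = LambdaFunction(res.head, Seq(param))
      Some(Alias(ArrayTransform(input, func), expectedName)())
    }
    ```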
    
    ### Why are the changes needed?
    
    Our customer encountered a significant performance regression when migrating from Spark 3.2 to Spark 3.4 on an `Insert Into` query which is analyzed as an `AppendData` on an Iceberg table.
    We found that the root cause is that in Spark 3.4, `TableOutputResolver` resolves the query with an additional `ArrayTransform` on an `ArrayType` field. The `ArrayTransform`'s lambda function is actually an identity function, i.e., the transformation is redundant.
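
    To see the shape of the redundancy in isolation, here is a standalone sketch using the public DataFrame API (app and column names are hypothetical; this is not the customer query): an identity lambda over an array column returns the same data, yet still pays a per-element lambda evaluation on every row.

    ```scala
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, transform}

    object IdentityTransformDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("identity-transform-demo")
          .getOrCreate()
        import spark.implicits._

        val df = Seq(Seq(1L, 2L, 3L), Seq(4L, 5L)).toDF("arr")

        // Both columns hold identical values; the first is the redundant form
        // (an identity lambda applied per element), the second is the direct
        // reference the patched resolver now produces for identical expressions.
        df.select(
          transform(col("arr"), x => x).as("with_identity_transform"),
          col("arr").as("direct")
        ).show(false)

        spark.stop()
      }
    }
    ```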
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test and manual e2e test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47862 from viirya/fix_redundant_array_transform_3.4.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../catalyst/analysis/TableOutputResolver.scala    | 12 ++++++--
 .../spark/sql/catalyst/util/CharVarcharUtils.scala |  2 +-
 .../catalyst/analysis/V2WriteAnalysisSuite.scala   | 32 +++++++++++++++++++++-
 3 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index e1ee0defa239..908711db8503 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -182,8 +182,16 @@ object TableOutputResolver {
       val fakeAttr = AttributeReference("x", expectedType.elementType, expectedType.containsNull)()
       val res = reorderColumnsByName(Seq(param), Seq(fakeAttr), conf, addError, colPath)
       if (res.length == 1) {
-        val func = LambdaFunction(res.head, Seq(param))
-        Some(Alias(ArrayTransform(input, func), expectedName)())
+        if (res.head == param) {
+          // If the element type is the same, we can reuse the input array directly.
+          Some(
+            Alias(input, expectedName)(
+              nonInheritableMetadataKeys =
+                Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
+        } else {
+          val func = LambdaFunction(res.head, Seq(param))
+          Some(Alias(ArrayTransform(input, func), expectedName)())
+        }
       } else {
         None
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index 448106343584..5324752ba3f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
 
 object CharVarcharUtils extends Logging {
 
-  private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
+  private[sql] val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING"
 
   /**
    * Replaces CharType/VarcharType with StringType recursively in the given struct type. If a
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
index 69cd838cfb24..7d55dd4e97a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala
@@ -21,7 +21,7 @@ import java.util.Locale
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, AttributeReference, Cast, LessThanOrEqual, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
@@ -237,6 +237,36 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
 
   def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan
 
+  test("SPARK-49352: Avoid redundant array transform for identical 
expression") {
+    def assertArrayField(fromType: ArrayType, toType: ArrayType, hasTransform: Boolean): Unit = {
+      val table = TestRelation(Seq($"a".int, $"arr".array(toType)))
+      val query = TestRelation(Seq($"arr".array(fromType), $"a".int))
+
+      val writePlan = byName(table, query).analyze
+
+      assertResolved(writePlan)
+      checkAnalysis(writePlan, writePlan)
+
+      val transform = writePlan.children.head.expressions.exists { e =>
+        e.find {
+          case _: ArrayTransform => true
+          case _ => false
+        }.isDefined
+      }
+      if (hasTransform) {
+        assert(transform)
+      } else {
+        assert(!transform)
+      }
+    }
+
+    assertArrayField(ArrayType(LongType), ArrayType(LongType), hasTransform = false)
+    assertArrayField(
+      ArrayType(new StructType().add("x", "int").add("y", "int")),
+      ArrayType(new StructType().add("y", "int").add("x", "byte")),
+      hasTransform = true)
+  }
+
   test("SPARK-33136: output resolved on complex types for V2 write commands") {
     def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = {
       val table = TestRelation(StructType(Seq(StructField("a", toType))).toAttributes)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
