This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new ab664b840 [spark] Merge into support update nested column (#2365)
ab664b840 is described below

commit ab664b8406c0972ebd5ffd551f292df3bcd07408
Author: Zouxxyy <[email protected]>
AuthorDate: Thu Nov 23 09:21:06 2023 +0800

    [spark] Merge into support update nested column (#2365)
---
 .../spark/commands/UpdatePaimonTableCommand.scala  |  7 +--
 .../analysis/AssignmentAlignmentHelper.scala       | 33 +++++-----
 .../sql/catalyst/analysis/PaimonMergeInto.scala    | 72 +++-------------------
 .../analysis/expressions/ExpressionHelper.scala    | 12 +++-
 .../paimon/spark/sql/MergeIntoTableTest.scala      | 27 ++++----
 .../apache/paimon/spark/sql/UpdateTableTest.scala  |  2 +-
 6 files changed, 53 insertions(+), 100 deletions(-)

diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
index a44dd1a36..bbf156cf2 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/UpdatePaimonTableCommand.scala
@@ -25,12 +25,11 @@ import org.apache.paimon.types.RowKind
 
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.Utils.createDataset
-import org.apache.spark.sql.catalyst.analysis.{AssignmentAlignmentHelper, 
EliminateSubqueryAliases}
+import org.apache.spark.sql.catalyst.analysis.{AssignmentAlignmentHelper, 
PaimonRelation}
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, 
UpdateTable}
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.functions.lit
 
 case class UpdatePaimonTableCommand(u: UpdateTable)
@@ -39,10 +38,10 @@ case class UpdatePaimonTableCommand(u: UpdateTable)
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
 
-    val relation = 
EliminateSubqueryAliases(u.table).asInstanceOf[DataSourceV2Relation]
+    val relation = PaimonRelation.getPaimonRelation(u.table)
 
     val updatedExprs: Seq[Alias] =
-      alignUpdateAssignments(relation.output, 
u.assignments).zip(relation.output).map {
+      generateAlignedExpressions(relation.output, 
u.assignments).zip(relation.output).map {
         case (expr, attr) => Alias(expr, attr.name)()
       }
 
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala
index 33bec58da..334362302 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentAlignmentHelper.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, CreateNamedStruct, Expression, GetStructField, Literal, 
NamedExpression}
+import org.apache.spark.sql.catalyst.analysis.expressions.ExpressionHelper
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
CreateNamedStruct, Expression, GetStructField, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical.Assignment
 import org.apache.spark.sql.types.StructType
 
-trait AssignmentAlignmentHelper extends SQLConfHelper {
+trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
 
   private lazy val resolver = conf.resolver
 
@@ -35,31 +36,29 @@ trait AssignmentAlignmentHelper extends SQLConfHelper {
   private case class AttrUpdate(ref: Seq[String], expr: Expression)
 
   /**
-   * Align update assignments to the given attrs, only supports PrimitiveType 
and StructType. For
-   * example, if attrs are [a int, b int, s struct(c1 int, c2 int)] and update 
assignments are [a =
-   * 1, s.c1 = 2], will return [1, b, struct(2, c2)].
+   * Generate aligned expressions, only supports PrimitiveType and StructType. 
For example, if attrs
+   * are [a int, b int, s struct(c1 int, c2 int)] and update assignments are 
[a = 1, s.c1 = 2], will
+   * return [1, b, struct(2, c2)].
    * @param attrs
    *   target attrs
    * @param assignments
    *   update assignments
    * @return
-   *   aligned update expressions
+   *   aligned expressions
    */
-  protected def alignUpdateAssignments(
+  protected def generateAlignedExpressions(
       attrs: Seq[Attribute],
       assignments: Seq[Assignment]): Seq[Expression] = {
     val attrUpdates = assignments.map(a => AttrUpdate(toRefSeq(a.key), 
a.value))
     recursiveAlignUpdates(attrs, attrUpdates)
   }
 
-  def toRefSeq(expr: Expression): Seq[String] = expr match {
-    case attr: Attribute =>
-      Seq(attr.name)
-    case GetStructField(child, _, Some(name)) =>
-      toRefSeq(child) :+ name
-    case other =>
-      throw new UnsupportedOperationException(
-        s"Unsupported update expression: $other, only support update with 
PrimitiveType and StructType.")
+  protected def alignAssignments(
+      attrs: Seq[Attribute],
+      assignments: Seq[Assignment]): Seq[Assignment] = {
+    generateAlignedExpressions(attrs, assignments).zip(attrs).map {
+      case (expression, field) => Assignment(field, expression)
+    }
   }
 
   private def recursiveAlignUpdates(
@@ -79,7 +78,7 @@ trait AssignmentAlignmentHelper extends SQLConfHelper {
           if (exactMatchedUpdate.isDefined) {
             if (headMatchedUpdates.size == 1) {
            // when an exact match (no nested fields) occurs, it must be the 
only match, then return its expr
-              exactMatchedUpdate.get.expr
+              castIfNeeded(exactMatchedUpdate.get.expr, targetAttr.dataType)
             } else {
               // otherwise, there must be conflicting updates, for example:
               // - update the same attr multiple times
@@ -87,7 +86,7 @@ trait AssignmentAlignmentHelper extends SQLConfHelper {
               val conflictingAttrNames =
                 headMatchedUpdates.map(u => (namePrefix ++ 
u.ref).mkString(".")).distinct
               throw new UnsupportedOperationException(
-                s"Conflicting updates on attrs: 
${conflictingAttrNames.mkString(", ")}"
+                s"Conflicting update/insert on attrs: 
${conflictingAttrNames.mkString(", ")}"
               )
             }
           } else {
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
index 4908fa7ec..b1c945663 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/PaimonMergeInto.scala
@@ -22,24 +22,20 @@ import org.apache.paimon.spark.SparkTable
 import org.apache.paimon.spark.commands.MergeIntoPaimonTable
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis.PaimonRelation
 import org.apache.spark.sql.catalyst.analysis.expressions.ExpressionHelper
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
AttributeSet, Expression, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, 
InsertAction, InsertStarAction, LogicalPlan, MergeAction, MergeIntoTable, 
UpdateAction, UpdateStarAction}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 
-import scala.collection.mutable
-
 /** A post-hoc resolution rule for MergeInto. */
 case class PaimonMergeInto(spark: SparkSession)
   extends Rule[LogicalPlan]
   with RowLevelHelper
-  with ExpressionHelper {
+  with ExpressionHelper
+  with AssignmentAlignmentHelper {
 
   override val operation: RowLevelOp = MergeInto
 
-  private val resolver: Resolver = spark.sessionState.conf.resolver
-
   def apply(plan: LogicalPlan): LogicalPlan = {
     plan.resolveOperators {
       case merge: MergeIntoTable
@@ -62,9 +58,9 @@ case class PaimonMergeInto(spark: SparkSession)
           primaryKeys)
 
         val alignedMatchedActions =
-          merge.matchedActions.map(checkAndAlignActionAssigment(_, 
targetOutput))
+          merge.matchedActions.map(checkAndAlignActionAssignment(_, 
targetOutput))
         val alignedNotMatchedActions =
-          merge.notMatchedActions.map(checkAndAlignActionAssigment(_, 
targetOutput))
+          merge.notMatchedActions.map(checkAndAlignActionAssignment(_, 
targetOutput))
 
         MergeIntoPaimonTable(
           v2Table,
@@ -76,41 +72,19 @@ case class PaimonMergeInto(spark: SparkSession)
     }
   }
 
-  private def checkAndAlignActionAssigment(
+  private def checkAndAlignActionAssignment(
       action: MergeAction,
       targetOutput: Seq[AttributeReference]): MergeAction = {
     action match {
       case d @ DeleteAction(_) => d
       case u @ UpdateAction(_, assignments) =>
-        val attrNameAndUpdateExpr = checkAndConvertAssignments(assignments, 
targetOutput)
-
-        val newAssignments = targetOutput.map {
-          field =>
-            val fieldAndExpr = attrNameAndUpdateExpr.find(a => 
resolver(field.name, a._1))
-            if (fieldAndExpr.isEmpty) {
-              Assignment(field, field)
-            } else {
-              Assignment(field, castIfNeeded(fieldAndExpr.get._2, 
field.dataType))
-            }
-        }
-        u.copy(assignments = newAssignments)
+        u.copy(assignments = alignAssignments(targetOutput, assignments))
 
       case i @ InsertAction(_, assignments) =>
-        val attrNameAndUpdateExpr = checkAndConvertAssignments(assignments, 
targetOutput)
         if (assignments.length != targetOutput.length) {
           throw new RuntimeException("Can't align the table's columns in 
insert clause.")
         }
-
-        val newAssignments = targetOutput.map {
-          field =>
-            val fieldAndExpr = attrNameAndUpdateExpr.find(a => 
resolver(field.name, a._1))
-            if (fieldAndExpr.isEmpty) {
-              throw new RuntimeException(s"Can't find the expression for 
${field.name}.")
-            } else {
-              Assignment(field, castIfNeeded(fieldAndExpr.get._2, 
field.dataType))
-            }
-        }
-        i.copy(assignments = newAssignments)
+        i.copy(assignments = alignAssignments(targetOutput, assignments))
 
       case _: UpdateStarAction =>
         throw new RuntimeException(s"UpdateStarAction should not be here.")
@@ -132,36 +106,6 @@ case class PaimonMergeInto(spark: SparkSession)
     }
   }
 
-  private def checkAndConvertAssignments(
-      assignments: Seq[Assignment],
-      targetOutput: Seq[AttributeReference]): Seq[(String, Expression)] = {
-    val columnToAssign = mutable.HashMap.empty[String, Int]
-    val pairs = assignments.map {
-      assignment =>
-        assignment.key match {
-          case a: AttributeReference =>
-            if (!targetOutput.exists(attr => resolver(attr.name, a.name))) {
-              throw new RuntimeException(
-                s"Ths key of assignment doesn't belong to the target table, 
$assignment")
-            }
-            columnToAssign.put(a.name, columnToAssign.getOrElse(a.name, 0) + 1)
-          case _ =>
-            throw new RuntimeException(
-              s"Only primitive type is supported in update/insert clause, 
$assignment")
-        }
-        (assignment.key.asInstanceOf[AttributeReference].name, 
assignment.value)
-    }
-
-    val duplicatedColumns = columnToAssign.filter(_._2 > 1).keys
-    if (duplicatedColumns.nonEmpty) {
-      val partOfMsg = duplicatedColumns.mkString(",")
-      throw new RuntimeException(
-        s"Can't update/insert the same column ($partOfMsg) multiple times.")
-    }
-
-    pairs
-  }
-
   /** This check will avoid to update the primary key columns */
   private def checkUpdateActionValidity(
       targetOutput: AttributeSet,
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
index a4e2b6755..ad02ffc03 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/catalyst/analysis/expressions/ExpressionHelper.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis.expressions
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, 
Expression, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, 
Expression, GetStructField, Literal, PredicateHelper}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, NullType}
@@ -82,6 +82,16 @@ trait ExpressionHelper extends PredicateHelper {
         }
     }
   }
+
+  protected def toRefSeq(expr: Expression): Seq[String] = expr match {
+    case attr: Attribute =>
+      Seq(attr.name)
+    case GetStructField(child, _, Some(name)) =>
+      toRefSeq(child) :+ name
+    case other =>
+      throw new UnsupportedOperationException(
+        s"Unsupported update expression: $other, only support update with 
PrimitiveType and StructType.")
+  }
 }
 
 object ExpressionHelper {
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
index c07131ec3..63d73a09a 100644
--- 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/MergeIntoTableTest.scala
@@ -511,7 +511,7 @@ class MergeIntoTableTest extends PaimonSparkTestBase {
                      |THEN INSERT *
                      |""".stripMargin)
       }.getMessage
-      assert(error1.contains("Can't update/insert the same column (b) multiple 
times."))
+      assert(error1.contains("Conflicting update/insert on attrs: b"))
 
       val error2 = intercept[RuntimeException] {
         spark.sql(s"""
@@ -524,11 +524,11 @@ class MergeIntoTableTest extends PaimonSparkTestBase {
                      |THEN INSERT (a, a, c) VALUES (a, b, c)
                      |""".stripMargin)
       }.getMessage
-      assert(error2.contains("Can't update/insert the same column (a) multiple 
times."))
+      assert(error2.contains("Conflicting update/insert on attrs: a"))
     }
   }
 
-  test("Paimon MergeInto: fail in case that update nested column") {
+  test("Paimon MergeInto: update nested column") {
     withTable("source", "target") {
 
       Seq((1, 100, "x1", "y1"), (3, 300, "x3", "y3"))
@@ -541,16 +541,17 @@ class MergeIntoTableTest extends PaimonSparkTestBase {
                    |""".stripMargin)
       spark.sql("INSERT INTO target values (1, 10, struct('x', 'y')), (2, 20, 
struct('x', 'y'))")
 
-      val error = intercept[RuntimeException] {
-        spark.sql(s"""
-                     |MERGE INTO target
-                     |USING source
-                     |ON target.a = source.a
-                     |WHEN MATCHED THEN
-                     |UPDATE SET c.c1 = source.c1
-                     |""".stripMargin)
-      }.getMessage
-      assert(error.contains("Only primitive type is supported"))
+      spark.sql(s"""
+                   |MERGE INTO target
+                   |USING source
+                   |ON target.a = source.a
+                   |WHEN MATCHED THEN
+                   |UPDATE SET c.c1 = source.c1
+                   |""".stripMargin)
+
+      checkAnswer(
+        spark.sql("SELECT * FROM target ORDER BY a"),
+        Row(1, 10, Row("x1", "y")) :: Row(2, 20, Row("x", "y")) :: Nil)
     }
   }
 
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala
index 603262d58..4ced313ff 100644
--- 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/UpdateTableTest.scala
@@ -225,6 +225,6 @@ class UpdateTableTest extends PaimonSparkTestBase {
 
     assertThatThrownBy(
       () => spark.sql("UPDATE T SET s.c2 = 'a_new', s = struct(11, 'a_new') 
WHERE s.c1 = 1"))
-      .hasMessageContaining("Conflicting updates on attrs: s.c2, s")
+      .hasMessageContaining("Conflicting update/insert on attrs: s.c2, s")
   }
 }

Reply via email to