This is an automated email from the ASF dual-hosted git repository.

zouxxyy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a9bcdaa3 [spark] dataframe.write and insert sql syntax in byName mode (#3871)
7a9bcdaa3 is described below

commit 7a9bcdaa3b556adb2695740d44fd109dc80eb87a
Author: Yann Byron <[email protected]>
AuthorDate: Thu Aug 1 23:08:41 2024 +0800

    [spark] dataframe.write and insert sql syntax in byName mode (#3871)
---
 .../spark/catalyst/analysis/PaimonAnalysis.scala   | 87 +++++++++++++++++-----
 .../paimon/spark/sql/DataFrameWriteTest.scala      | 29 ++++++++
 .../spark/sql/InsertOverwriteTableTest.scala       | 51 +++++++++++++
 3 files changed, 149 insertions(+), 18 deletions(-)

diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 3dc0e40c9..d115fe3fd 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvedTable
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
@@ -39,11 +40,17 @@ class PaimonAnalysis(session: SparkSession) extends 
Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsDown {
 
     case a @ PaimonV2WriteCommand(table, paimonTable)
-        if !schemaCompatible(
-          a.query.output.toStructType,
-          table.output.toStructType,
-          paimonTable.partitionKeys().asScala) =>
-      val newQuery = resolveQueryColumns(a.query, table.output)
+        if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, 
paimonTable) =>
+      val newQuery = resolveQueryColumnsByName(a.query, table.output)
+      if (newQuery != a.query) {
+        Compatibility.withNewQuery(a, newQuery)
+      } else {
+        a
+      }
+
+    case a @ PaimonV2WriteCommand(table, paimonTable)
+        if !a.isByName && needsSchemaAdjustmentByPosition(a.query, 
table.output, paimonTable) =>
+      val newQuery = resolveQueryColumnsByPosition(a.query, table.output)
       if (newQuery != a.query) {
         Compatibility.withNewQuery(a, newQuery)
       } else {
@@ -57,6 +64,62 @@ class PaimonAnalysis(session: SparkSession) extends 
Rule[LogicalPlan] {
       PaimonMergeIntoResolver(merge, session)
   }
 
+  /**
+   * Decides whether a by-name write needs a projection over its query.
+   *
+   * Query columns are matched to table columns by name, honoring the session's
+   * `caseSensitiveAnalysis` setting. An adjustment is needed when the matched table columns are
+   * not in the same order as the query's columns, or when their types are not compatible
+   * according to [[schemaCompatible]].
+   *
+   * @param query the query whose output is written
+   * @param targetAttrs the output attributes of the target table, in table order
+   * @param paimonTable the Paimon table being written; its partition keys are handed to
+   *                    [[schemaCompatible]]
+   * @return true if `query` should be rewritten with [[resolveQueryColumnsByName]]
+   */
+  private def needsSchemaAdjustmentByName(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute],
+      paimonTable: FileStoreTable): Boolean = {
+    val output = query.output
+    val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) {
+      output.map(a => (a.name, a)).toMap
+    } else {
+      CaseInsensitiveMap(output.map(a => (a.name, a)).toMap)
+    }
+    // Table columns that the query provides, kept in table order.
+    val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
+    // schemaCompatible pairs the two schemas positionally and does not compare their column
+    // names, so a query whose columns are ordered differently from the table must be caught
+    // here. Otherwise two reordered columns of the same type would look compatible and be
+    // written into each other's table columns.
+    val resolver = session.sessionState.conf.resolver
+    val sameOrder =
+      specifiedTargetAttrs.corresponds(output)((target, data) => resolver(data.name, target.name))
+    // schemaCompatible(dataSchema, tableSchema, ...): the query's schema is the data schema.
+    !sameOrder ||
+    !schemaCompatible(
+      output.toStructType,
+      specifiedTargetAttrs.toStructType,
+      paimonTable.partitionKeys().asScala)
+  }
+
+  /**
+   * Builds a projection over `query` that emits exactly one column per target attribute, in
+   * table order.
+   *
+   * Each target column is looked up among the query's output with the session resolver and then
+   * passed through [[addCastToColumn]] together with its target attribute.
+   *
+   * @param query the query whose output is written
+   * @param targetAttrs the output attributes of the target table, in table order
+   * @return a [[Project]] over `query` whose columns follow `targetAttrs`
+   * @throws RuntimeException if a target column has no counterpart in the query output
+   */
+  private def resolveQueryColumnsByName(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute]): LogicalPlan = {
+    val dataColumns = query.output
+    val resolver = session.sessionState.conf.resolver
+
+    // Finds the query column that the resolver considers equal to `target`'s name.
+    def lookup(target: Attribute): Attribute =
+      dataColumns.find(column => resolver(column.name, target.name)) match {
+        case Some(column) => column
+        case None =>
+          throw new RuntimeException(
+            s"Cannot find ${target.name} in data columns: ${dataColumns.map(_.name).mkString(", ")}")
+      }
+
+    val projectList = for (target <- targetAttrs) yield addCastToColumn(lookup(target), target)
+    Project(projectList, query)
+  }
+
+  /**
+   * Decides whether a by-position write needs a projection over its query.
+   *
+   * An adjustment is needed when the query's column names differ from the table's (compared
+   * in order and case-sensitively), or when the column types are not compatible according to
+   * [[schemaCompatible]].
+   *
+   * @param query the query whose output is written
+   * @param targetAttrs the output attributes of the target table, in table order
+   * @param paimonTable the Paimon table being written; its partition keys are handed to
+   *                    [[schemaCompatible]]
+   * @return true if `query` should be rewritten with [[resolveQueryColumnsByPosition]]
+   */
+  private def needsSchemaAdjustmentByPosition(
+      query: LogicalPlan,
+      targetAttrs: Seq[Attribute],
+      paimonTable: FileStoreTable): Boolean = {
+    val output = query.output
+    // schemaCompatible(dataSchema, tableSchema, ...): the query's schema is the data schema.
+    targetAttrs.map(_.name) != output.map(_.name) ||
+    !schemaCompatible(
+      output.toStructType,
+      targetAttrs.toStructType,
+      paimonTable.partitionKeys().asScala)
+  }
+
+  /**
+   * Builds a projection that pairs the i-th query column with the i-th table attribute and
+   * passes each pair through [[addCastToColumn]].
+   *
+   * NOTE(review): assumes the query has no more columns than the table; an extra query column
+   * makes `tableAttributes(idx)` throw.
+   *
+   * @param query the query whose output is written
+   * @param tableAttributes the output attributes of the target table, in table order
+   * @return a [[Project]] over `query` with one expression per query column
+   */
+  private def resolveQueryColumnsByPosition(
+      query: LogicalPlan,
+      tableAttributes: Seq[Attribute]): LogicalPlan = {
+    val projectList = for ((dataAttr, idx) <- query.output.zipWithIndex)
+      yield addCastToColumn(dataAttr, tableAttributes(idx))
+    Project(projectList, query)
+  }
+
   private def schemaCompatible(
       dataSchema: StructType,
       tableSchema: StructType,
@@ -83,22 +146,10 @@ class PaimonAnalysis(session: SparkSession) extends 
Rule[LogicalPlan] {
     }
 
     dataSchema.zip(tableSchema).forall {
-      case (f1, f2) =>
-        f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, 
f2.dataType)
+      case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
     }
   }
 
-  private def resolveQueryColumns(
-      query: LogicalPlan,
-      tableAttributes: Seq[Attribute]): LogicalPlan = {
-    val project = query.output.zipWithIndex.map {
-      case (attr, i) =>
-        val targetAttr = tableAttributes(i)
-        addCastToColumn(attr, targetAttr)
-    }
-    Project(project, query)
-  }
-
   private def addCastToColumn(attr: Attribute, targetAttr: Attribute): 
NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index a4b618318..a2509871f 100644
--- 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -56,6 +56,35 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
     
Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast"))
   }
 
+  for (fileFormat <- fileFormats) {
+    test(s"Paimon: DataFrameWrite.saveAsTable in ByName mode, file.format: $fileFormat") {
+      withTable("t1", "t2") {
+        // Creates a table with the given column list in the file format under test.
+        def createTable(name: String, columns: String) =
+          spark.sql(s"""
+                       |CREATE TABLE $name ($columns)
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+
+        // Both tables hold the same columns, but t2 declares them in a different order, so the
+        // saveAsTable below only produces the expected rows if columns are matched by name.
+        createTable("t1", "col1 STRING, col2 INT, col3 DOUBLE")
+        createTable("t2", "col2 INT, col3 DOUBLE, col1 STRING")
+
+        sql(s"""
+               |INSERT INTO TABLE t1 VALUES
+               |("Hello", 1, 1.1),
+               |("World", 2, 2.2),
+               |("Paimon", 3, 3.3);
+               |""".stripMargin)
+
+        spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
+
+        val expected = Seq(Row(1, 1.1d, "Hello"), Row(2, 2.2d, "World"), Row(3, 3.3d, "Paimon"))
+        checkAnswer(sql("SELECT * FROM t2 ORDER BY col2"), expected)
+      }
+    }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
index 9ad1f4523..528df32e6 100644
--- 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
@@ -27,6 +27,57 @@ import java.sql.{Date, Timestamp}
 
 abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
 
+  for {
+    fileFormat <- fileFormats
+    isPartitioned <- Seq(true, false)
+  } {
+    test(
+      s"Paimon: insert into/overwrite in ByName mode, file.format: $fileFormat, isPartitioned: $isPartitioned") {
+      withTable("t1", "t2") {
+        val partitionedSQL = if (isPartitioned) "PARTITIONED BY (col4)" else ""
+
+        // Creates a table with the given column list, optionally partitioned by col4, in the
+        // file format under test.
+        def createTable(name: String, columns: String) =
+          spark.sql(s"""
+                       |CREATE TABLE $name ($columns)
+                       |$partitionedSQL
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+
+        // Same columns in both tables, but t2 declares col1 after col2 and col3.
+        createTable("t1", "col1 STRING, col2 INT, col3 DOUBLE, col4 STRING")
+        createTable("t2", "col2 INT, col3 DOUBLE, col1 STRING, col4 STRING")
+
+        sql(s"""
+               |INSERT INTO TABLE t1 VALUES
+               |("Hello", 1, 1.1, "pt1"),
+               |("Paimon", 3, 3.3, "pt2");
+               |""".stripMargin)
+
+        sql("INSERT INTO t2 (col1, col2, col3, col4) SELECT * FROM t1")
+        checkAnswer(
+          sql("SELECT * FROM t2 ORDER BY col2"),
+          Seq(Row(1, 1.1d, "Hello", "pt1"), Row(3, 3.3d, "Paimon", "pt2")))
+
+        sql(s"""
+               |INSERT INTO TABLE t1 VALUES ("World", 2, 2.2, "pt1");
+               |""".stripMargin)
+
+        // After the overwrite, t2 should hold exactly the three rows now in t1.
+        sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1")
+        checkAnswer(
+          sql("SELECT * FROM t2 ORDER BY col2"),
+          Seq(
+            Row(1, 1.1d, "Hello", "pt1"),
+            Row(2, 2.2d, "World", "pt1"),
+            Row(3, 3.3d, "Paimon", "pt2")))
+      }
+    }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {

Reply via email to