This is an automated email from the ASF dual-hosted git repository.
zouxxyy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 7a9bcdaa3 [spark] dataframe.write and insert sql syntax in byName mode (#3871)
7a9bcdaa3 is described below
commit 7a9bcdaa3b556adb2695740d44fd109dc80eb87a
Author: Yann Byron <[email protected]>
AuthorDate: Thu Aug 1 23:08:41 2024 +0800
[spark] dataframe.write and insert sql syntax in byName mode (#3871)
---
.../spark/catalyst/analysis/PaimonAnalysis.scala | 87 +++++++++++++++++-----
.../paimon/spark/sql/DataFrameWriteTest.scala | 29 ++++++++
.../spark/sql/InsertOverwriteTableTest.scala | 51 +++++++++++++
3 files changed, 149 insertions(+), 18 deletions(-)
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 3dc0e40c9..d115fe3fd 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.ResolvedTable
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
@@ -39,11 +40,17 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsDown {
case a @ PaimonV2WriteCommand(table, paimonTable)
- if !schemaCompatible(
- a.query.output.toStructType,
- table.output.toStructType,
- paimonTable.partitionKeys().asScala) =>
- val newQuery = resolveQueryColumns(a.query, table.output)
+ if a.isByName && needsSchemaAdjustmentByName(a.query, table.output, paimonTable) =>
+ val newQuery = resolveQueryColumnsByName(a.query, table.output)
+ if (newQuery != a.query) {
+ Compatibility.withNewQuery(a, newQuery)
+ } else {
+ a
+ }
+
+ case a @ PaimonV2WriteCommand(table, paimonTable)
+ if !a.isByName && needsSchemaAdjustmentByPosition(a.query, table.output, paimonTable) =>
+ val newQuery = resolveQueryColumnsByPosition(a.query, table.output)
if (newQuery != a.query) {
Compatibility.withNewQuery(a, newQuery)
} else {
@@ -57,6 +64,62 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
PaimonMergeIntoResolver(merge, session)
}
+ private def needsSchemaAdjustmentByName(
+ query: LogicalPlan,
+ targetAttrs: Seq[Attribute],
+ paimonTable: FileStoreTable): Boolean = {
+ val userSpecifiedNames = if (session.sessionState.conf.caseSensitiveAnalysis) {
+ query.output.map(a => (a.name, a)).toMap
+ } else {
+ CaseInsensitiveMap(query.output.map(a => (a.name, a)).toMap)
+ }
+ val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
+ !schemaCompatible(
+ specifiedTargetAttrs.toStructType,
+ query.output.toStructType,
+ paimonTable.partitionKeys().asScala)
+ }
+
+ private def resolveQueryColumnsByName(
+ query: LogicalPlan,
+ targetAttrs: Seq[Attribute]): LogicalPlan = {
+ val output = query.output
+ val project = targetAttrs.map {
+ attr =>
+ val outputAttr = output
+ .find(t => session.sessionState.conf.resolver(t.name, attr.name))
+ .getOrElse {
+ throw new RuntimeException(
+ s"Cannot find ${attr.name} in data columns: ${output.map(_.name).mkString(", ")}")
+ }
+ addCastToColumn(outputAttr, attr)
+ }
+ Project(project, query)
+ }
+
+ private def needsSchemaAdjustmentByPosition(
+ query: LogicalPlan,
+ targetAttrs: Seq[Attribute],
+ paimonTable: FileStoreTable): Boolean = {
+ val output = query.output
+ targetAttrs.map(_.name) != output.map(_.name) ||
+ !schemaCompatible(
+ targetAttrs.toStructType,
+ output.toStructType,
+ paimonTable.partitionKeys().asScala)
+ }
+
+ private def resolveQueryColumnsByPosition(
+ query: LogicalPlan,
+ tableAttributes: Seq[Attribute]): LogicalPlan = {
+ val project = query.output.zipWithIndex.map {
+ case (attr, i) =>
+ val targetAttr = tableAttributes(i)
+ addCastToColumn(attr, targetAttr)
+ }
+ Project(project, query)
+ }
+
private def schemaCompatible(
dataSchema: StructType,
tableSchema: StructType,
@@ -83,22 +146,10 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}
dataSchema.zip(tableSchema).forall {
- case (f1, f2) =>
- f1.name == f2.name && dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
+ case (f1, f2) => dataTypeCompatible(f1.name, f1.dataType, f2.dataType)
}
}
- private def resolveQueryColumns(
- query: LogicalPlan,
- tableAttributes: Seq[Attribute]): LogicalPlan = {
- val project = query.output.zipWithIndex.map {
- case (attr, i) =>
- val targetAttr = tableAttributes(i)
- addCastToColumn(attr, targetAttr)
- }
- Project(project, query)
- }
-
private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = {
val expr = (attr.dataType, targetAttr.dataType) match {
case (s, t) if s == t =>
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index a4b618318..a2509871f 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -56,6 +56,35 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
Assertions.assertFalse(paimonTable.options().containsKey("write.merge-schema.explicit-cast"))
}
+ fileFormats.foreach {
+ fileFormat =>
+ test(s"Paimon: DataFrameWrite.saveAsTable in ByName mode, file.format: $fileFormat") {
+ withTable("t1", "t2") {
+ spark.sql(s"""
+ |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE)
+ |TBLPROPERTIES ('file.format' = '$fileFormat')
+ |""".stripMargin)
+
+ spark.sql(s"""
+ |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING)
+ |TBLPROPERTIES ('file.format' = '$fileFormat')
+ |""".stripMargin)
+
+ sql(s"""
+ |INSERT INTO TABLE t1 VALUES
+ |("Hello", 1, 1.1),
+ |("World", 2, 2.2),
+ |("Paimon", 3, 3.3);
+ |""".stripMargin)
+
+ spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
+ checkAnswer(
+ sql("SELECT * FROM t2 ORDER BY col2"),
+ Row(1, 1.1d, "Hello") :: Row(2, 2.2d, "World") :: Row(3, 3.3d, "Paimon") :: Nil)
+ }
+ }
+ }
+
withPk.foreach {
hasPk =>
bucketModes.foreach {
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
index 9ad1f4523..528df32e6 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTableTest.scala
@@ -27,6 +27,57 @@ import java.sql.{Date, Timestamp}
abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
+ fileFormats.foreach {
+ fileFormat =>
+ Seq(true, false).foreach {
+ isPartitioned =>
+ test(
+ s"Paimon: insert into/overwrite in ByName mode, file.format: $fileFormat, isPartitioned: $isPartitioned") {
+ withTable("t1", "t2") {
+ val partitionedSQL = if (isPartitioned) {
+ "PARTITIONED BY (col4)"
+ } else {
+ ""
+ }
+ spark.sql(s"""
+ |CREATE TABLE t1 (col1 STRING, col2 INT, col3 DOUBLE, col4 STRING)
+ |$partitionedSQL
+ |TBLPROPERTIES ('file.format' = '$fileFormat')
+ |""".stripMargin)
+
+ spark.sql(s"""
+ |CREATE TABLE t2 (col2 INT, col3 DOUBLE, col1 STRING, col4 STRING)
+ |$partitionedSQL
+ |TBLPROPERTIES ('file.format' = '$fileFormat')
+ |""".stripMargin)
+
+ sql(s"""
+ |INSERT INTO TABLE t1 VALUES
+ |("Hello", 1, 1.1, "pt1"),
+ |("Paimon", 3, 3.3, "pt2");
+ |""".stripMargin)
+
+ sql("INSERT INTO t2 (col1, col2, col3, col4) SELECT * FROM t1")
+ checkAnswer(
+ sql("SELECT * FROM t2 ORDER BY col2"),
+ Row(1, 1.1d, "Hello", "pt1") :: Row(3, 3.3d, "Paimon", "pt2") :: Nil)
+
+ sql(s"""
+ |INSERT INTO TABLE t1 VALUES ("World", 2, 2.2, "pt1");
+ |""".stripMargin)
+ sql("INSERT OVERWRITE t2 (col1, col2, col3, col4) SELECT * FROM t1")
+ checkAnswer(
+ sql("SELECT * FROM t2 ORDER BY col2"),
+ Row(1, 1.1d, "Hello", "pt1") :: Row(2, 2.2d, "World", "pt1") :: Row(
+ 3,
+ 3.3d,
+ "Paimon",
+ "pt2") :: Nil)
+ }
+ }
+ }
+ }
+
withPk.foreach {
hasPk =>
bucketModes.foreach {