This is an automated email from the ASF dual-hosted git repository.
zouxxyy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 39ca57d72 [spark] support complex data type in byName mode (#3878)
39ca57d72 is described below
commit 39ca57d72ec8334a349298cc3806ed238319620f
Author: Yann Byron <[email protected]>
AuthorDate: Fri Aug 2 20:48:48 2024 +0800
[spark] support complex data type in byName mode (#3878)
---
.../spark/catalyst/analysis/PaimonAnalysis.scala | 108 ++++++++++++++++++++-
.../paimon/spark/sql/DataFrameWriteTest.scala | 50 +++++++++-
2 files changed, 152 insertions(+), 6 deletions(-)
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index d115fe3fd..7ed90283d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -26,12 +26,12 @@ import org.apache.paimon.table.FileStoreTable
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.ResolvedTable
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
import scala.collection.JavaConverters._
@@ -92,7 +92,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
throw new RuntimeException(
s"Cannot find ${attr.name} in data columns:
${output.map(_.name).mkString(", ")}")
}
- addCastToColumn(outputAttr, attr)
+ addCastToColumn(outputAttr, attr, isByName = true)
}
Project(project, query)
}
@@ -115,7 +115,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
val project = query.output.zipWithIndex.map {
case (attr, i) =>
val targetAttr = tableAttributes(i)
- addCastToColumn(attr, targetAttr)
+ addCastToColumn(attr, targetAttr, isByName = false)
}
Project(project, query)
}
@@ -150,16 +150,114 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
}
}
-  private def addCastToColumn(attr: Attribute, targetAttr: Attribute): NamedExpression = {
+  /**
+   * Builds the projection entry that writes query column `attr` into table column `targetAttr`.
+   *
+   * Identical types pass through unchanged. Struct columns, and arrays whose elements are
+   * structs, are rebuilt field by field (fields matched by name when `isByName`, otherwise by
+   * position). Every other type mismatch is handled by a single `cast` to the target type.
+   *
+   * @param attr column produced by the incoming query
+   * @param targetAttr table column being written
+   * @param isByName true to match struct fields by name, false to match them by position
+   * @return the converted value aliased as `targetAttr.name`, carrying `targetAttr.metadata`
+   */
+  private def addCastToColumn(
+      attr: Attribute,
+      targetAttr: Attribute,
+      isByName: Boolean): NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
         attr
+      // Struct column: rebuild it field by field. The `s != t` guard is always true here,
+      // because identical types were matched by the case above.
+      case (s: StructType, t: StructType) if s != t =>
+        if (isByName) {
+          addCastToStructByName(attr, s, t)
+        } else {
+          addCastToStructByPosition(attr, s, t)
+        }
+      // Array of structs: only the element struct types are compared, and the target's element
+      // nullability (`_`) is not consulted. If the element structs are equal and only
+      // `containsNull` differs, the guard fails and the column falls through to `cast`.
+      case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, _: Boolean))
+          if s != t =>
+        val castToStructFunc = if (isByName) {
+          addCastToStructByName _
+        } else {
+          addCastToStructByPosition _
+        }
+        castToArrayStruct(attr, s, t, sNull, castToStructFunc)
       case _ =>
         cast(attr, targetAttr.dataType)
     }
     Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
   }
+  /**
+   * Rebuilds a struct of type `source` into the layout of `target`, matching fields by name.
+   *
+   * Each target field is located in `source` with `StructType.fieldIndex` (exact,
+   * case-sensitive name match; a missing name makes `fieldIndex` throw). Nested struct fields
+   * recurse on the original, uncast nested value so that their own fields are matched by name
+   * too; all other fields are cast to the target field type. A null input struct stays null
+   * instead of turning into a struct whose fields are all null.
+   *
+   * @param parent expression producing the source struct
+   * @param source data type of `parent`
+   * @param target struct layout to produce
+   * @return the rebuilt struct aliased with `parent`'s name, exprId, qualifier and metadata
+   * @throws RuntimeException if a target struct field corresponds to a non-struct source field
+   */
+  private def addCastToStructByName(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    val fields = target.map {
+      targetField =>
+        val sourceIndex = source.fieldIndex(targetField.name)
+        val sourceField = source(sourceIndex)
+        (sourceField.dataType, targetField.dataType) match {
+          case (s: StructType, nested: StructType) =>
+            // Recurse on the raw nested value. Casting it to `nested` first would reorder its
+            // fields by position, so the by-name lookup below would read the wrong ordinals.
+            val subField = nestedStructField(parent, sourceIndex, sourceField, targetField)
+            addCastToStructByName(subField, s, nested)
+          case (o, _: StructType) =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+          case _ =>
+            castStructField(parent, sourceIndex, sourceField.name, targetField)
+        }
+    }
+    nullSafeStruct(parent, fields)
+  }
+
+  /**
+   * Rebuilds a struct of type `source` into the layout of `target`, matching fields by position.
+   *
+   * Field `i` of `target` is filled from field `i` of `source`. Nested struct fields recurse on
+   * the original, uncast nested value; all other fields are cast to the target field type. A
+   * null input struct stays null instead of turning into a struct whose fields are all null.
+   *
+   * @param parent expression producing the source struct
+   * @param source data type of `parent`
+   * @param target struct layout to produce
+   * @return the rebuilt struct aliased with `parent`'s name, exprId, qualifier and metadata
+   * @throws RuntimeException if the field counts differ, or a target struct field corresponds
+   *   to a non-struct source field
+   */
+  private def addCastToStructByPosition(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    if (source.length != target.length) {
+      throw new RuntimeException("The number of fields in source and target is not same.")
+    }
+
+    val fields = target.zip(source).zipWithIndex.map {
+      case ((targetField, sourceField), i) =>
+        (sourceField.dataType, targetField.dataType) match {
+          case (s: StructType, nested: StructType) =>
+            // Recurse on the raw nested value; the recursive call performs the per-field casts.
+            val subField = nestedStructField(parent, i, sourceField, targetField)
+            addCastToStructByPosition(subField, s, nested)
+          case (o, _: StructType) =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+          case _ =>
+            castStructField(parent, i, sourceField.name, targetField)
+        }
+    }
+    nullSafeStruct(parent, fields)
+  }
+
+  /** Extracts field `i` of `parent` without casting, named and tagged like `targetField`. */
+  private def nestedStructField(
+      parent: NamedExpression,
+      i: Int,
+      sourceField: StructField,
+      targetField: StructField): NamedExpression = {
+    Alias(GetStructField(parent, i, Option(sourceField.name)), targetField.name)(
+      explicitMetadata = Option(targetField.metadata))
+  }
+
+  /**
+   * Packs `fields` into a struct aliased like `parent`. When `parent` is nullable, a null
+   * `parent` yields a null struct rather than a struct of null fields (`CreateStruct` itself
+   * never produces null).
+   */
+  private def nullSafeStruct(
+      parent: NamedExpression,
+      fields: Seq[NamedExpression]): NamedExpression = {
+    val struct = CreateStruct(fields)
+    val value =
+      if (parent.nullable) If(IsNull(parent), Literal(null, struct.dataType), struct) else struct
+    Alias(value, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
+  }
+
+  /**
+   * Extracts field `i` (named `sourceFieldName`) of `parent`, casts it to `targetField`'s type
+   * and aliases it with the target field's name and metadata.
+   */
+  private def castStructField(
+      parent: NamedExpression,
+      i: Int,
+      sourceFieldName: String,
+      targetField: StructField): NamedExpression = {
+    Alias(
+      cast(GetStructField(parent, i, Option(sourceFieldName)), targetField.dataType),
+      targetField.name
+    )(explicitMetadata = Option(targetField.metadata))
+  }
+  /**
+   * Rebuilds every struct element of the array `parent` from layout `source` into `target`.
+   *
+   * Emits an `ArrayTransform` whose lambda passes the current element straight to
+   * `castToStructFunc` (the by-name or by-position struct rebuild), instead of re-indexing
+   * `parent` for every element.
+   *
+   * @param parent array-of-struct column being converted
+   * @param source element struct type of `parent`
+   * @param target element struct type to produce
+   * @param sourceNullable whether `parent` may contain null elements
+   * @param castToStructFunc rebuilds one struct element from `source` to `target`
+   * @return the `ArrayTransform` expression
+   */
+  private def castToArrayStruct(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType,
+      sourceNullable: Boolean,
+      castToStructFunc: (NamedExpression, StructType, StructType) => NamedExpression
+  ): Expression = {
+    val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
+    // Wrap the lambda variable in a fresh Alias: the struct builders alias their result with
+    // their parent's exprId, which must not be the lambda variable's own id.
+    val element = Alias(elementVar, elementVar.name)()
+    val body = castToStructFunc(element, source, target)
+    ArrayTransform(parent, LambdaFunction(body, Seq(elementVar)))
+  }
+
private def cast(expr: Expression, dataType: DataType): Expression = {
val cast = Compatibility.cast(expr, dataType,
Option(conf.sessionLocalTimeZone))
cast.setTagValue(Compatibility.castByTableInsertionTag, ())
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index a2509871f..f50483d9f 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.spark.sql.Row
import org.junit.jupiter.api.Assertions
-import java.sql.Date
+import java.sql.{Date, Timestamp}
class DataFrameWriteTest extends PaimonSparkTestBase {
@@ -85,6 +85,54 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
}
}
+  fileFormats.foreach {
+    fileFormat =>
+      val testName =
+        s"Paimon: DataFrameWrite.saveAsTable with complex data type in ByName mode, file.format: $fileFormat"
+
+      // t2 holds the same columns as t1, but the top-level columns, the fields of struct `c`
+      // and the fields of the structs inside array `d` are all declared in a different order,
+      // so a by-name DataFrame write must realign every level.
+      test(testName) {
+        withTable("t1", "t2") {
+          spark.sql(s"""
+                       |CREATE TABLE t1 (a STRING, b INT, c STRUCT<c1:DOUBLE, c2:LONG>, d ARRAY<STRUCT<d1 TIMESTAMP, d2 MAP<STRING, STRING>>>, e ARRAY<INT>)
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+          spark.sql(s"""
+                       |CREATE TABLE t2 (b INT, c STRUCT<c2:LONG, c1:DOUBLE>, d ARRAY<STRUCT<d2 MAP<STRING, STRING>, d1 TIMESTAMP>>, e ARRAY<INT>, a STRING)
+                       |TBLPROPERTIES ('file.format' = '$fileFormat')
+                       |""".stripMargin)
+
+          sql(s"""
+                 |INSERT INTO TABLE t1 VALUES
+                 |("Hello", 1, struct(1.1, 1000), array(struct(timestamp'2024-01-01 00:00:00', map("k1", "v1")), struct(timestamp'2024-08-01 00:00:00', map("k1", "v11"))), array(123, 345)),
+                 |("World", 2, struct(2.2, 2000), array(struct(timestamp'2024-02-01 00:00:00', map("k2", "v2"))), array(234, 456)),
+                 |("Paimon", 3, struct(3.3, 3000), null, array(345, 567));
+                 |""".stripMargin)
+
+          val source = spark.table("t1")
+          source.write.format("paimon").mode("append").saveAsTable("t2")
+
+          // Rows follow t2's column order (b, c, d, e, a) and t2's nested field order.
+          val helloRow = Row(
+            1,
+            Row(1000L, 1.1d),
+            Array(
+              Row(Map("k1" -> "v1"), Timestamp.valueOf("2024-01-01 00:00:00")),
+              Row(Map("k1" -> "v11"), Timestamp.valueOf("2024-08-01 00:00:00"))),
+            Array(123, 345),
+            "Hello")
+          val worldRow = Row(
+            2,
+            Row(2000L, 2.2d),
+            Array(Row(Map("k2" -> "v2"), Timestamp.valueOf("2024-02-01 00:00:00"))),
+            Array(234, 456),
+            "World")
+          val paimonRow = Row(3, Row(3000L, 3.3d), null, Array(345, 567), "Paimon")
+
+          checkAnswer(sql("SELECT * FROM t2 ORDER BY b"), Seq(helloRow, worldRow, paimonRow))
+        }
+      }
+  }
+
withPk.foreach {
hasPk =>
bucketModes.foreach {