This is an automated email from the ASF dual-hosted git repository.

zouxxyy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 39ca57d72 [spark] support complex data type in byName mode (#3878)
39ca57d72 is described below

commit 39ca57d72ec8334a349298cc3806ed238319620f
Author: Yann Byron <[email protected]>
AuthorDate: Fri Aug 2 20:48:48 2024 +0800

    [spark] support complex data type in byName mode (#3878)
---
 .../spark/catalyst/analysis/PaimonAnalysis.scala   | 108 ++++++++++++++++++++-
 .../paimon/spark/sql/DataFrameWriteTest.scala      |  50 +++++++++-
 2 files changed, 152 insertions(+), 6 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index d115fe3fd..7ed90283d 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -26,12 +26,12 @@ import org.apache.paimon.table.FileStoreTable
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.ResolvedTable
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, 
StructField, StructType}
 
 import scala.collection.JavaConverters._
 
@@ -92,7 +92,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
             throw new RuntimeException(
               s"Cannot find ${attr.name} in data columns: 
${output.map(_.name).mkString(", ")}")
           }
-        addCastToColumn(outputAttr, attr)
+        addCastToColumn(outputAttr, attr, isByName = true)
     }
     Project(project, query)
   }
@@ -115,7 +115,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     val project = query.output.zipWithIndex.map {
       case (attr, i) =>
         val targetAttr = tableAttributes(i)
-        addCastToColumn(attr, targetAttr)
+        addCastToColumn(attr, targetAttr, isByName = false)
     }
     Project(project, query)
   }
@@ -150,16 +150,114 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
     }
   }
 
-  private def addCastToColumn(attr: Attribute, targetAttr: Attribute): 
NamedExpression = {
+  /**
+   * Projects `attr` to the data type of `targetAttr`, naming the result after `targetAttr` and
+   * carrying over `targetAttr`'s metadata.
+   *
+   *   - identical types: `attr` is used unchanged;
+   *   - struct -> struct with differing types: the struct is rebuilt field by field, matching
+   *     fields by name when `isByName` is true and by position otherwise;
+   *   - array of structs -> array of structs with differing element types: every element is
+   *     rebuilt the same way via `castToArrayStruct`;
+   *   - anything else (e.g. maps, including maps with struct values): a single `cast` to the
+   *     target type. NOTE(review): such a cast presumably pairs nested struct fields by
+   *     position even in by-name mode — confirm before relying on it.
+   *
+   * @param isByName true on the by-name write path, false on the by-position path
+   */
+  private def addCastToColumn(
+      attr: Attribute,
+      targetAttr: Attribute,
+      isByName: Boolean): NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
         attr
+      case (s: StructType, t: StructType) if s != t =>
+        if (isByName) {
+          addCastToStructByName(attr, s, t)
+        } else {
+          addCastToStructByPosition(attr, s, t)
+        }
+      case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, 
_: Boolean))
+          if s != t =>
+        // Eta-expand the chosen struct converter so castToArrayStruct can apply it per element.
+        val castToStructFunc = if (isByName) {
+          addCastToStructByName _
+        } else {
+          addCastToStructByPosition _
+        }
+        castToArrayStruct(attr, s, t, sNull, castToStructFunc)
       case _ =>
         cast(attr, targetAttr.dataType)
     }
     Alias(expr, targetAttr.name)(explicitMetadata = 
Option(targetAttr.metadata))
   }
 
+  /**
+   * Rebuilds the struct `parent` (of type `source`) as a struct shaped like `target`, matching
+   * fields by name. Each target field is read from the source field of the same name and cast
+   * to the target field's type. Struct-typed target fields are matched by name recursively.
+   *
+   * A nested source struct is handed to the recursive call as-is. Casting it to the target
+   * struct type first would be a by-position cast, after which reading fields by their *source*
+   * ordinal would pick up reordered, already-converted values.
+   *
+   * A null `parent` yields null rather than a non-null struct whose fields are all null.
+   *
+   * Field lookup uses `StructType.fieldIndex`, which fails when `source` has no field with a
+   * target field's name. NOTE(review): that lookup is presumably case-sensitive — confirm this
+   * matches how top-level columns are resolved.
+   *
+   * @throws RuntimeException if a struct-typed target field maps to a non-struct source field
+   */
+  private def addCastToStructByName(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    val fields = target.map {
+      targetField =>
+        val sourceIndex = source.fieldIndex(targetField.name)
+        val sourceField = source(sourceIndex)
+        (sourceField.dataType, targetField.dataType) match {
+          case (s: StructType, t: StructType) =>
+            // Recurse on the uncast nested struct so source ordinals stay valid.
+            val nestedSource = Alias(
+              GetStructField(parent, sourceIndex, Option(sourceField.name)),
+              targetField.name)(explicitMetadata = Option(targetField.metadata))
+            addCastToStructByName(nestedSource, s, t)
+          case (o, _: StructType) =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+          case _ =>
+            castStructField(parent, sourceIndex, sourceField.name, targetField)
+        }
+    }
+    val struct = CreateStruct(fields)
+    // Keep a null input struct null instead of turning it into a struct of null fields.
+    val nullPreserving =
+      if (parent.nullable) If(IsNull(parent), Literal(null, struct.dataType), struct) else struct
+    Alias(nullPreserving, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
+  }
+
+  /**
+   * Rebuilds the struct `parent` (of type `source`) as a struct shaped like `target`, pairing
+   * fields by position. The i-th target field is read from the i-th source field and cast to the
+   * target field's type. Struct-typed target fields are paired by position recursively.
+   *
+   * A nested source struct is handed to the recursive call as-is rather than being cast to the
+   * target struct type first, so every leaf field goes through exactly one cast.
+   *
+   * A null `parent` yields null rather than a non-null struct whose fields are all null.
+   *
+   * @throws RuntimeException if the two structs have different field counts, or a struct-typed
+   *                          target field is paired with a non-struct source field
+   */
+  private def addCastToStructByPosition(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType): NamedExpression = {
+    if (source.length != target.length) {
+      throw new RuntimeException("The number of fields in source and target is not same.")
+    }
+
+    val fields = target.zip(source).zipWithIndex.map {
+      case ((targetField, sourceField), i) =>
+        (sourceField.dataType, targetField.dataType) match {
+          case (s: StructType, t: StructType) =>
+            val nestedSource = Alias(
+              GetStructField(parent, i, Option(sourceField.name)),
+              targetField.name)(explicitMetadata = Option(targetField.metadata))
+            addCastToStructByPosition(nestedSource, s, t)
+          case (o, _: StructType) =>
+            throw new RuntimeException(s"Can not support to cast $o to StructType.")
+          case _ =>
+            castStructField(parent, i, sourceField.name, targetField)
+        }
+    }
+    val struct = CreateStruct(fields)
+    // Keep a null input struct null instead of turning it into a struct of null fields.
+    val nullPreserving =
+      if (parent.nullable) If(IsNull(parent), Literal(null, struct.dataType), struct) else struct
+    Alias(nullPreserving, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
+  }
+
+  /**
+   * Extracts field `i` of the struct `parent`, casts it to `targetField`'s data type, and names
+   * the result after `targetField` with `targetField`'s metadata.
+   *
+   * @param sourceFieldName the source field's name, passed to `GetStructField` as its name
+   */
+  private def castStructField(
+      parent: NamedExpression,
+      i: Int,
+      sourceFieldName: String,
+      targetField: StructField): NamedExpression = {
+    val extracted = GetStructField(parent, i, Option(sourceFieldName))
+    val converted = cast(extracted, targetField.dataType)
+    val metadata = Option(targetField.metadata)
+    Alias(converted, targetField.name)(explicitMetadata = metadata)
+  }
+
+  /**
+   * Converts the array-of-structs `parent` (element type `source`) into an array whose elements
+   * are shaped like `target`. It does this by applying `castToStructFunc` to every element
+   * inside an `ArrayTransform` (SQL `transform`).
+   *
+   * The converter works on the lambda's element variable directly instead of re-reading
+   * `parent[index]` through `GetArrayItem`. As a result, the lambda needs no index argument and
+   * each element is bound exactly once.
+   *
+   * @param sourceNullable nullability of the source array's elements, used for the lambda's
+   *                       element variable
+   * @param castToStructFunc struct converter (by name or by position) applied to each element
+   */
+  private def castToArrayStruct(
+      parent: NamedExpression,
+      source: StructType,
+      target: StructType,
+      sourceNullable: Boolean,
+      castToStructFunc: (NamedExpression, StructType, StructType) => NamedExpression
+  ): Expression = {
+    val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
+    // The struct converters reuse their input's exprId for the struct they build, so hand them
+    // a fresh alias of the lambda variable rather than the variable itself.
+    val element = Alias(elementVar, elementVar.name)()
+    val converted = castToStructFunc(element, source, target)
+    ArrayTransform(parent, LambdaFunction(converted, Seq(elementVar)))
+  }
+
   private def cast(expr: Expression, dataType: DataType): Expression = {
     val cast = Compatibility.cast(expr, dataType, 
Option(conf.sessionLocalTimeZone))
     cast.setTagValue(Compatibility.castByTableInsertionTag, ())
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index a2509871f..f50483d9f 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -23,7 +23,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
 import org.apache.spark.sql.Row
 import org.junit.jupiter.api.Assertions
 
-import java.sql.Date
+import java.sql.{Date, Timestamp}
 
 class DataFrameWriteTest extends PaimonSparkTestBase {
 
@@ -85,6 +85,54 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
       }
   }
 
+  // Writes t1 into t2, where t2 declares the same columns in a different order and also
+  // reorders the fields of struct column `c` and of the struct elements of array column `d`.
+  fileFormats.foreach {
+    fileFormat =>
+      test(
+        s"Paimon: DataFrameWrite.saveAsTable with complex data type in ByName 
mode, file.format: $fileFormat") {
+        withTable("t1", "t2") {
+          spark.sql(
+            s"""
+               |CREATE TABLE t1 (a STRING, b INT, c STRUCT<c1:DOUBLE, 
c2:LONG>, d ARRAY<STRUCT<d1 TIMESTAMP, d2 MAP<STRING, STRING>>>, e ARRAY<INT>)
+               |TBLPROPERTIES ('file.format' = '$fileFormat')
+               |""".stripMargin)
+
+          spark.sql(
+            s"""
+               |CREATE TABLE t2 (b INT, c STRUCT<c2:LONG, c1:DOUBLE>, d 
ARRAY<STRUCT<d2 MAP<STRING, STRING>, d1 TIMESTAMP>>, e ARRAY<INT>, a STRING)
+               |TBLPROPERTIES ('file.format' = '$fileFormat')
+               |""".stripMargin)
+
+          // The third row has a null `d`, covering a null array of structs.
+          sql(s"""
+                 |INSERT INTO TABLE t1 VALUES
+                 |("Hello", 1, struct(1.1, 1000), 
array(struct(timestamp'2024-01-01 00:00:00', map("k1", "v1")), 
struct(timestamp'2024-08-01 00:00:00', map("k1", "v11"))), array(123, 345)),
+                 |("World", 2, struct(2.2, 2000), 
array(struct(timestamp'2024-02-01 00:00:00', map("k2", "v2"))), array(234, 
456)),
+                 |("Paimon", 3, struct(3.3, 3000), null, array(345, 567));
+                 |""".stripMargin)
+
+          // Append t1 into the existing t2; columns and struct fields must be matched by name.
+          
spark.table("t1").write.format("paimon").mode("append").saveAsTable("t2")
+          // Expected rows follow t2's layout: (b, c, d, e, a) with c = (c2, c1), d = (d2, d1).
+          checkAnswer(
+            sql("SELECT * FROM t2 ORDER BY b"),
+            Row(
+              1,
+              Row(1000L, 1.1d),
+              Array(
+                Row(Map("k1" -> "v1"), Timestamp.valueOf("2024-01-01 
00:00:00")),
+                Row(Map("k1" -> "v11"), Timestamp.valueOf("2024-08-01 
00:00:00"))),
+              Array(123, 345),
+              "Hello"
+            )
+              :: Row(
+                2,
+                Row(2000L, 2.2d),
+                Array(Row(Map("k2" -> "v2"), Timestamp.valueOf("2024-02-01 
00:00:00"))),
+                Array(234, 456),
+                "World")
+              :: Row(3, Row(3000L, 3.3d), null, Array(345, 567), "Paimon") :: 
Nil
+          )
+        }
+      }
+  }
+
   withPk.foreach {
     hasPk =>
       bucketModes.foreach {

Reply via email to