This is an automated email from the ASF dual-hosted git repository.

biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new cbda788a7 Support dataframe v2 write with overwrite (#4082)
cbda788a7 is described below

commit cbda788a73412ea004095993979599519797e210
Author: Xiduo You <[email protected]>
AuthorDate: Wed Sep 4 17:16:22 2024 +0800

    Support dataframe v2 write with overwrite (#4082)
---
 .../apache/paimon/spark/SparkFilterConverter.java  |   5 +
 .../apache/paimon/spark/SparkWriteBuilder.scala    |  56 ++++++++++-
 .../spark/catalyst/analysis/PaimonAnalysis.scala   |   3 +-
 .../paimon/spark/commands/PaimonCommand.scala      |  30 +++---
 .../spark/commands/WriteIntoPaimonTable.scala      |   2 +-
 .../paimon/spark/sql/DataFrameWriteTest.scala      | 103 +++++++++++++++++++--
 6 files changed, 167 insertions(+), 32 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index ea61f6f7d..2ea2e3c45 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -166,6 +166,11 @@ public class SparkFilterConverter {
         return convertLiteral(fieldIndex(field), value);
     }
 
+    public String convertString(String field, Object value) {
+        Object literal = convertLiteral(field, value);
+        return literal == null ? null : literal.toString();
+    }
+
     private int fieldIndex(String field) {
         int index = rowType.getFieldIndex(field);
         // TODO: support nested field
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
index 96f75ab10..74a474b8c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
@@ -21,18 +21,70 @@ package org.apache.paimon.spark
 import org.apache.paimon.options.Options
 import org.apache.paimon.table.FileStoreTable
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.connector.write.{SupportsOverwrite, WriteBuilder}
-import org.apache.spark.sql.sources.{And, Filter}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, 
EqualNullSafe, EqualTo, Filter, Not, Or}
+
+import scala.collection.JavaConverters._
 
 private class SparkWriteBuilder(table: FileStoreTable, options: Options)
   extends WriteBuilder
-  with SupportsOverwrite {
+  with SupportsOverwrite
+  with SQLConfHelper {
 
   private var saveMode: SaveMode = InsertInto
 
   override def build = new SparkWrite(table, saveMode, options)
 
+  private def failWithReason(filter: Filter): Unit = {
+    throw new RuntimeException(
+      s"Only support Overwrite filters with Equal and EqualNullSafe, but got: 
$filter")
+  }
+
+  private def validateFilter(filter: Filter): Unit = filter match {
+    case And(left, right) =>
+      validateFilter(left)
+      validateFilter(right)
+    case _: Or => failWithReason(filter)
+    case _: Not => failWithReason(filter)
+    case e: EqualTo if e.references.length == 1 && 
!e.value.isInstanceOf[Filter] =>
+    case e: EqualNullSafe if e.references.length == 1 && 
!e.value.isInstanceOf[Filter] =>
+    case _: AlwaysTrue | _: AlwaysFalse =>
+    case _ => failWithReason(filter)
+  }
+
+  // `SupportsOverwrite#canOverwrite` is added since Spark 3.4.0.
+  // We do this checking by self to work with previous Spark version.
+  private def failIfCanNotOverwrite(filters: Array[Filter]): Unit = {
+    // For now, we only support overwrite with two cases:
+    // - overwrite with partition columns to be compatible with v1 insert 
overwrite
+    //   See 
[[org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveInsertInto#staticDeleteExpression]].
+    // - truncate-like overwrite and the filter is always true.
+    //
+    // Fast fail for other custom filters which through v2 write interface, 
e.g.,
+    // `dataframe.writeTo(T).overwrite(...)`
+    val partitionRowType = table.schema.logicalPartitionType()
+    val partitionNames = partitionRowType.getFieldNames.asScala
+    val allReferences = filters.flatMap(_.references)
+    val containsDataColumn = allReferences.exists {
+      reference => !partitionNames.exists(conf.resolver.apply(reference, _))
+    }
+    if (containsDataColumn) {
+      throw new RuntimeException(
+        s"Only support Overwrite filters on partition column 
${partitionNames.mkString(
+            ", ")}, but got ${filters.mkString(", ")}.")
+    }
+    if (allReferences.distinct.length < allReferences.length) {
+      // fail with `part = 1 and part = 2`
+      throw new RuntimeException(
+        s"Only support Overwrite with one filter for each partition column, 
but got ${filters.mkString(", ")}.")
+    }
+    filters.foreach(validateFilter)
+  }
+
   override def overwrite(filters: Array[Filter]): WriteBuilder = {
+    failIfCanNotOverwrite(filters)
+
     val conjunctiveFilters = if (filters.nonEmpty) {
       Some(filters.reduce((l, r) => And(l, r)))
     } else {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
index 307cab734..98d3c03aa 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/PaimonAnalysis.scala
@@ -119,8 +119,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
         .mkString(", ")
       // There are seme unknown column names
       throw new RuntimeException(
-        s"Cannot write incompatible data for the table `${table.name}`, due to 
unknown column names: ${extraCols
-            .mkString(", ")}.")
+        s"Cannot write incompatible data for the table `${table.name}`, due to 
unknown column names: $extraCols.")
     }
     Project(reorderedCols, query)
   }
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index aad4b82bd..e8caea3cd 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -18,7 +18,6 @@
 
 package org.apache.paimon.spark.commands
 
-import org.apache.paimon.data.BinaryRow
 import org.apache.paimon.deletionvectors.BitmapDeletionVector
 import org.apache.paimon.fs.Path
 import org.apache.paimon.index.IndexFileMeta
@@ -38,9 +37,9 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils.createDataset
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
-import org.apache.spark.sql.catalyst.plans.logical.{Filter => 
FilterLogicalNode, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter => 
FilterLogicalNode, LogicalPlan}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
DataSourceV2ScanRelation}
-import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter}
+import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, EqualTo, 
Filter}
 
 import java.net.URI
 import java.util.Collections
@@ -59,23 +58,20 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
   }
 
-  /**
-   * For the 'INSERT OVERWRITE T PARTITION (partitionVal, ...)' semantics of 
SQL, Spark will
-   * transform `partitionVal`s to EqualNullSafe Filters.
-   */
-  def convertFilterToMap(filter: Filter, partitionRowType: RowType): 
Map[String, String] = {
+  /** See [[ org.apache.paimon.spark.SparkWriteBuilder#failIfCanNotOverwrite]] 
*/
+  def convertPartitionFilterToMap(
+      filter: Filter,
+      partitionRowType: RowType): Map[String, String] = {
     val converter = new SparkFilterConverter(partitionRowType)
     splitConjunctiveFilters(filter).map {
       case EqualNullSafe(attribute, value) =>
-        if (isNestedFilterInValue(value)) {
-          throw new RuntimeException(
-            s"Not support the complex partition value in EqualNullSafe when 
run `INSERT OVERWRITE`.")
-        } else {
-          (attribute, converter.convertLiteral(attribute, value).toString)
-        }
+        (attribute, converter.convertString(attribute, value))
+      case EqualTo(attribute, value) =>
+        (attribute, converter.convertString(attribute, value))
       case _ =>
+        // Should not happen
         throw new RuntimeException(
-          s"Only EqualNullSafe should be used when run `INSERT OVERWRITE`.")
+          s"Only support Overwrite filters with Equal and EqualNullSafe, but 
got: $filter")
     }.toMap
   }
 
@@ -87,10 +83,6 @@ trait PaimonCommand extends WithFileStoreTable with ExpressionHelper {
     }
   }
 
-  private def isNestedFilterInValue(value: Any): Boolean = {
-    value.isInstanceOf[Filter]
-  }
-
   /** Gets a relative path against the table path. */
   protected def relativePath(absolutePath: String): String = {
     val location = table.location().toUri
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 905c9cdfb..fe740ea8c 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -76,7 +76,7 @@ case class WriteIntoPaimonTable(
         } else if (isTruncate(filter.get)) {
           Map.empty[String, String]
         } else {
-          convertFilterToMap(filter.get, table.schema.logicalPartitionType())
+          convertPartitionFilterToMap(filter.get, 
table.schema.logicalPartitionType())
         }
       case DynamicOverWrite =>
         dynamicPartitionOverwriteMode = true
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
index f50483d9f..ca3ba8797 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala
@@ -26,11 +26,9 @@ import org.junit.jupiter.api.Assertions
 import java.sql.{Date, Timestamp}
 
 class DataFrameWriteTest extends PaimonSparkTestBase {
+  import testImplicits._
 
   test("Paimon: DataFrameWrite.saveAsTable") {
-
-    import testImplicits._
-
     Seq((1L, "x1"), (2L, "x2"))
       .toDF("a", "b")
       .write
@@ -139,9 +137,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
         bucket =>
           test(s"Write data into Paimon directly: has-pk: $hasPk, bucket: 
$bucket") {
 
-            val _spark = spark
-            import _spark.implicits._
-
             val prop = if (hasPk) {
               s"'primary-key'='a', 'bucket' = '$bucket' "
             } else if (bucket != -1) {
@@ -278,8 +273,6 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
         bucket =>
           test(
             s"Schema evolution: write data into Paimon with allowExplicitCast 
= true: $hasPk, bucket: $bucket") {
-            val _spark = spark
-            import _spark.implicits._
 
             val prop = if (hasPk) {
               s"'primary-key'='a', 'bucket' = '$bucket' "
@@ -380,4 +373,98 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
       }
   }
 
+  withPk.foreach {
+    hasPk =>
+      test(s"Support v2 write with overwrite, hasPk: $hasPk") {
+        withTable("t") {
+          val prop = if (hasPk) {
+            "'primary-key'='c1'"
+          } else {
+            "'write-only'='true'"
+          }
+          spark.sql(s"""
+                       |CREATE TABLE t (c1 INT, c2 STRING) PARTITIONED BY(p1 
String, p2 string)
+                       |TBLPROPERTIES ($prop)
+                       |""".stripMargin)
+
+          spark
+            .range(3)
+            .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+            .writeTo("t")
+            .overwrite($"p1" === "a")
+          checkAnswer(
+            spark.sql("SELECT * FROM t ORDER BY c1"),
+            Row(0, "0", "a", "0") :: Row(1, "1", "a", "1") :: Row(2, "2", "a", 
"2") :: Nil
+          )
+
+          spark
+            .range(7, 10)
+            .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+            .writeTo("t")
+            .overwrite($"p1" === "a")
+          checkAnswer(
+            spark.sql("SELECT * FROM t ORDER BY c1"),
+            Row(7, "7", "a", "7") :: Row(8, "8", "a", "8") :: Row(9, "9", "a", 
"9") :: Nil
+          )
+
+          spark
+            .range(2)
+            .selectExpr("id as c1", "id as c2", "'a' as p1", "9 as p2")
+            .writeTo("t")
+            .overwrite(($"p1" <=> "a").and($"p2" === "9"))
+          checkAnswer(
+            spark.sql("SELECT * FROM t ORDER BY c1"),
+            Row(0, "0", "a", "9") :: Row(1, "1", "a", "9") :: Row(7, "7", "a", 
"7") ::
+              Row(8, "8", "a", "8") :: Nil
+          )
+
+          // bad case
+          val msg1 = intercept[Exception] {
+            spark
+              .range(2)
+              .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+              .writeTo("t")
+              .overwrite($"p1" =!= "a")
+          }.getMessage
+          assert(msg1.contains("Only support Overwrite filters with Equal and 
EqualNullSafe"))
+
+          val msg2 = intercept[Exception] {
+            spark
+              .range(2)
+              .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+              .writeTo("t")
+              .overwrite($"p1" === $"c2")
+          }.getMessage
+          assert(msg2.contains("Table does not support overwrite by 
expression"))
+
+          val msg3 = intercept[Exception] {
+            spark
+              .range(2)
+              .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+              .writeTo("t")
+              .overwrite($"c1" === ($"c2" + 1))
+          }.getMessage
+          assert(msg3.contains("cannot translate expression to source filter"))
+
+          val msg4 = intercept[Exception] {
+            spark
+              .range(2)
+              .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+              .writeTo("t")
+              .overwrite(($"p1" === "a").and($"p1" === "b"))
+          }.getMessage
+          assert(msg4.contains("Only support Overwrite with one filter for 
each partition column"))
+
+          // Overwrite a partition which is not the specified
+          val msg5 = intercept[Exception] {
+            spark
+              .range(2)
+              .selectExpr("id as c1", "id as c2", "'a' as p1", "id as p2")
+              .writeTo("t")
+              .overwrite($"p1" === "b")
+          }.getMessage
+          assert(msg5.contains("does not belong to this partition"))
+        }
+      }
+  }
 }

Reply via email to