This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b7332cab [spark] SQL: supports INSERT OVERWRITE syntax (#1459)
0b7332cab is described below

commit 0b7332cab22aa4a4fb39fa0992973f1b7db8c762
Author: Yann Byron <[email protected]>
AuthorDate: Mon Jul 3 19:20:55 2023 +0800

    [spark] SQL: supports INSERT OVERWRITE syntax (#1459)
---
 paimon-spark/paimon-spark-common/pom.xml           |  74 +++++++
 .../apache/paimon/spark/SparkFilterConverter.java  |   6 +
 .../java/org/apache/paimon/spark/SparkTable.java   |   1 +
 .../org/apache/paimon/spark/SaveMode.scala}        |  29 +--
 .../scala/org/apache/paimon/spark/SparkWrite.scala |   5 +-
 .../{SparkWrite.scala => SparkWriteBuilder.scala}  |  28 +--
 .../paimon/spark/commands/PaimonCommand.scala      |  56 +++++-
 .../spark/commands/WriteIntoPaimonTable.scala      |  46 ++++-
 .../paimon/spark/sql/InsertOverwriteTest.scala     | 221 +++++++++++++++++++++
 .../paimon/spark/sql/PaimonSparkTestBase.scala     |  66 ++++++
 .../scala/org/apache/spark/paimon/Utils.scala}     |  26 +--
 11 files changed, 490 insertions(+), 68 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/pom.xml 
b/paimon-spark/paimon-spark-common/pom.xml
index f7d8a5223..c3c647086 100644
--- a/paimon-spark/paimon-spark-common/pom.xml
+++ b/paimon-spark/paimon-spark-common/pom.xml
@@ -99,6 +99,80 @@ under the License.
                 </exclusion>
             </exclusions>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.12</artifactId>
+            <version>${spark.version}</version>
+            <classifier>tests</classifier>
+            <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>log4j</groupId>
+                    <artifactId>log4j</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-log4j12</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.orc</groupId>
+                    <artifactId>orc-core</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.12</artifactId>
+            <version>${spark.version}</version>
+            <classifier>tests</classifier>
+            <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>log4j</groupId>
+                    <artifactId>log4j</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-log4j12</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.logging.log4j</groupId>
+                    <artifactId>log4j-slf4j2-impl</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.orc</groupId>
+                    <artifactId>orc-core</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_2.12</artifactId>
+            <version>${spark.version}</version>
+            <classifier>tests</classifier>
+            <scope>test</scope>
+            <exclusions>
+                <exclusion>
+                    <groupId>log4j</groupId>
+                    <artifactId>log4j</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-log4j12</artifactId>
+                </exclusion>
+                <exclusion>
+                    <groupId>org.apache.orc</groupId>
+                    <artifactId>orc-core</artifactId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest_${scala.binary.version}</artifactId>
+            <version>3.1.0</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <build>
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index c6bab357f..4f7cee52c 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -129,6 +129,12 @@ public class SparkFilterConverter {
                 filter + " is unsupported. Support Filters: " + 
SUPPORT_FILTERS);
     }
 
+    public Object convertLiteral(String field, Object value) {
+        int index = fieldIndex(field);
+        DataType type = rowType.getTypeAt(index);
+        return convertJavaObject(type, value);
+    }
+
     private int fieldIndex(String field) {
         int index = rowType.getFieldIndex(field);
         // TODO: support nested field
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
index 3b7939994..180e1fcec 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
+++ 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
@@ -79,6 +79,7 @@ public class SparkTable
         Set<TableCapability> capabilities = new HashSet<>();
         capabilities.add(TableCapability.BATCH_READ);
         capabilities.add(TableCapability.V1_BATCH_WRITE);
+        capabilities.add(TableCapability.OVERWRITE_BY_FILTER);
         return capabilities;
     }
 
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
similarity index 56%
copy from 
paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
copy to 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
index 876a6cf61..b4230d00f 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- * http://www.apache.org/licenses/LICENSE-2.0
+ *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,29 +15,14 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.paimon.spark
 
-package org.apache.paimon.spark;
+import org.apache.spark.sql.sources.Filter
 
-import org.apache.paimon.table.FileStoreTable;
+sealed trait SaveMode extends Serializable
 
-import org.apache.spark.sql.connector.write.Write;
-import org.apache.spark.sql.connector.write.WriteBuilder;
+object InsertInto extends SaveMode
 
-/**
- * Spark {@link WriteBuilder}.
- *
- * <p>TODO: Support overwrite.
- */
-public class SparkWriteBuilder implements WriteBuilder {
-
-    private final FileStoreTable table;
-
-    public SparkWriteBuilder(FileStoreTable table) {
-        this.table = table;
-    }
+case class Overwrite(filters: Option[Filter]) extends SaveMode
 
-    @Override
-    public Write build() {
-        return new SparkWrite(table);
-    }
-}
+object DynamicOverWrite extends SaveMode
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
index 61962bb95..44b017414 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
@@ -17,7 +17,6 @@
  */
 package org.apache.paimon.spark
 
-import org.apache.paimon.operation.Lock.Factory
 import org.apache.paimon.spark.commands.WriteIntoPaimonTable
 import org.apache.paimon.table.FileStoreTable
 
@@ -26,12 +25,12 @@ import org.apache.spark.sql.connector.write.V1Write
 import org.apache.spark.sql.sources.InsertableRelation
 
 /** Spark {@link V1Write}, it is required to use v1 write for grouping by 
bucket. */
-class SparkWrite(val table: FileStoreTable) extends V1Write {
+class SparkWrite(val table: FileStoreTable, saveMode: SaveMode) extends 
V1Write {
 
   override def toInsertableRelation: InsertableRelation = {
     (data: DataFrame, overwrite: Boolean) =>
       {
-        WriteIntoPaimonTable(table, overwrite, data).run(data.sparkSession)
+        WriteIntoPaimonTable(table, saveMode, data).run(data.sparkSession)
       }
   }
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
similarity index 59%
copy from 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
copy to 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
index 61962bb95..5cf429af1 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
@@ -17,21 +17,25 @@
  */
 package org.apache.paimon.spark
 
-import org.apache.paimon.operation.Lock.Factory
-import org.apache.paimon.spark.commands.WriteIntoPaimonTable
 import org.apache.paimon.table.FileStoreTable
 
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.connector.write.V1Write
-import org.apache.spark.sql.sources.InsertableRelation
+import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, 
SupportsOverwrite, WriteBuilder}
+import org.apache.spark.sql.sources.{And, Filter}
 
-/** Spark {@link V1Write}, it is required to use v1 write for grouping by 
bucket. */
-class SparkWrite(val table: FileStoreTable) extends V1Write {
+private class SparkWriteBuilder(table: FileStoreTable) extends WriteBuilder 
with SupportsOverwrite {
 
-  override def toInsertableRelation: InsertableRelation = {
-    (data: DataFrame, overwrite: Boolean) =>
-      {
-        WriteIntoPaimonTable(table, overwrite, data).run(data.sparkSession)
-      }
+  private var saveMode: SaveMode = InsertInto
+
+  override def build = new SparkWrite(table, saveMode)
+
+  override def overwrite(filters: Array[Filter]): WriteBuilder = {
+    val conjunctiveFilters = if (filters.nonEmpty) {
+      Some(filters.reduce((l, r) => And(l, r)))
+    } else {
+      None
+    }
+    this.saveMode = Overwrite(conjunctiveFilters)
+    this
   }
+
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index a22e3d559..ea951bb29 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -17,21 +17,27 @@
  */
 package org.apache.paimon.spark.commands
 
+import org.apache.paimon.predicate.PredicateBuilder
+import org.apache.paimon.spark.SparkFilterConverter
 import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
 import org.apache.paimon.table.sink.{CommitMessage, CommitMessageSerializer}
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter, 
Not, Or}
 
 import java.io.IOException
 
 /** Helper trait for all paimon commands. */
 trait PaimonCommand {
 
-  val table: Table
-
   val BUCKET_COL = "_bucket_"
 
+  def getTable: Table
+
   def isDynamicBucketTable: Boolean = {
-    table.isInstanceOf[FileStoreTable] &&
-    table.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
+    getTable.isInstanceOf[FileStoreTable] &&
+    getTable.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
   }
 
   def deserializeCommitMessage(
@@ -44,4 +50,46 @@ trait PaimonCommand {
         throw new RuntimeException("Failed to deserialize CommitMessage's 
object", e)
     }
   }
+
+  /**
+   * For SQL 'INSERT OVERWRITE' without a PARTITION clause, Spark DataSourceV2 calls `truncate`,
+   * which reaches `overwrite` as a single `AlwaysTrue` filter.
+   */
+  def isTruncate(filter: Filter): Boolean = {
+    val filters = splitConjunctiveFilters(filter)
+    filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
+  }
+
+  /**
+   * For SQL 'INSERT OVERWRITE T PARTITION (partitionVal, ...)', Spark transforms each
+   * `partitionVal` into an EqualNullSafe filter.
+  def convertFilterToMap(filter: Filter, partitionRowType: RowType): 
Map[String, String] = {
+    val converter = new SparkFilterConverter(partitionRowType)
+    splitConjunctiveFilters(filter).map {
+      case EqualNullSafe(attribute, value) =>
+        if (isNestedFilterInValue(value)) {
+          throw new RuntimeException(
+            s"Not support the complex partition value in EqualNullSafe when 
run `INSERT OVERWRITE`.")
+        } else {
+          (attribute, converter.convertLiteral(attribute, value).toString)
+        }
+      case _ =>
+        throw new RuntimeException(
+          s"Only EqualNullSafe should be used when run `INSERT OVERWRITE`.")
+    }.toMap
+  }
+
+  def splitConjunctiveFilters(filter: Filter): Seq[Filter] = {
+    filter match {
+      case And(filter1, filter2) =>
+        splitConjunctiveFilters(filter1) ++ splitConjunctiveFilters(filter2)
+      case other => other :: Nil
+    }
+  }
+
+  def isNestedFilterInValue(value: Any): Boolean = {
+    value.isInstanceOf[Filter]
+  }
+
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 06c9f0660..29f56776a 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -17,12 +17,14 @@
  */
 package org.apache.paimon.spark.commands
 
+import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
 import org.apache.paimon.data.BinaryRow
 import org.apache.paimon.index.PartitionIndex
+import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, 
SaveMode}
 import org.apache.paimon.spark.SparkRow
 import org.apache.paimon.spark.SparkUtils.createIOManager
-import org.apache.paimon.table.FileStoreTable
-import org.apache.paimon.table.sink.{BatchWriteBuilder, 
CommitMessageSerializer, DynamicBucketRow, InnerTableCommit, 
RowPartitionKeyExtractor}
+import org.apache.paimon.table.{FileStoreTable, Table}
+import org.apache.paimon.table.sink.{BatchWriteBuilder, 
CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
 import org.apache.paimon.types.RowType
 
 import org.apache.spark.TaskContext
@@ -38,12 +40,14 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 /** Used to write a [[DataFrame]] into a paimon table. */
-case class WriteIntoPaimonTable(table: FileStoreTable, overwrite: Boolean, 
data: DataFrame)
+case class WriteIntoPaimonTable(_table: FileStoreTable, saveMode: SaveMode, 
data: DataFrame)
   extends RunnableCommand
   with PaimonCommand {
 
   import WriteIntoPaimonTable._
 
+  private var table = _table
+
   private lazy val tableSchema = table.schema()
 
   private lazy val rowType = table.rowType()
@@ -52,13 +56,14 @@ case class WriteIntoPaimonTable(table: FileStoreTable, 
overwrite: Boolean, data:
 
   private lazy val serializer = new CommitMessageSerializer
 
-  if (overwrite) {
-    throw new UnsupportedOperationException("Overwrite is unsupported.");
-  }
-
   override def run(sparkSession: SparkSession): Seq[Row] = {
     import sparkSession.implicits._
 
+    val (dynamicPartitionOverwriteMode, overwritePartition) = parseSaveMode()
+    // use the extra options to rebuild the table object
+    table = table.copy(
+      Map(DYNAMIC_PARTITION_OVERWRITE.key() -> 
dynamicPartitionOverwriteMode.toString).asJava)
+
     val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
     val partitionCols = tableSchema.partitionKeys().asScala.map(col)
 
@@ -113,7 +118,11 @@ case class WriteIntoPaimonTable(table: FileStoreTable, 
overwrite: Boolean, data:
         .map(deserializeCommitMessage(serializer, _))
 
     try {
-      val tableCommit = writeBuilder.newCommit().asInstanceOf[InnerTableCommit]
+      val tableCommit = if (overwritePartition == null) {
+        writeBuilder.newCommit()
+      } else {
+        writeBuilder.withOverwrite(overwritePartition.asJava).newCommit()
+      }
       tableCommit.commit(commitMessages.toList.asJava)
     } catch {
       case e: Throwable => throw new RuntimeException(e);
@@ -122,8 +131,29 @@ case class WriteIntoPaimonTable(table: FileStoreTable, 
overwrite: Boolean, data:
     Seq.empty
   }
 
+  private def parseSaveMode(): (Boolean, Map[String, String]) = {
+    var dynamicPartitionOverwriteMode = false
+    val overwritePartition = saveMode match {
+      case InsertInto => null
+      case Overwrite(filter) =>
+        if (filter.isEmpty) {
+          Map.empty[String, String]
+        } else if (isTruncate(filter.get)) {
+          Map.empty[String, String]
+        } else {
+          convertFilterToMap(filter.get, tableSchema.logicalPartitionType())
+        }
+      case DynamicOverWrite =>
+        dynamicPartitionOverwriteMode = true
+        throw new UnsupportedOperationException("Dynamic Overwrite is 
unsupported for now.")
+    }
+    (dynamicPartitionOverwriteMode, overwritePartition)
+  }
+
   override def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): 
LogicalPlan =
     this.asInstanceOf[WriteIntoPaimonTable]
+
+  override def getTable: Table = table
 }
 
 object WriteIntoPaimonTable {
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala
new file mode 100644
index 000000000..832a96d83
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.WriteMode.{APPEND_ONLY, CHANGE_LOG}
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.scalactic.source.Position
+
+class InsertOverwriteTest extends PaimonSparkTestBase {
+
+  // 3: fixed bucket, -1: dynamic bucket
+  private val bucketModes = Seq(3, -1)
+
+  Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+    writeMode =>
+      bucketModes.foreach {
+        bucket =>
+          test(s"insert overwrite non-partitioned table: write-mode: 
$writeMode, bucket: $bucket") {
+            val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+              "'primary-key'='a,b',"
+            } else {
+              ""
+            }
+
+            spark.sql(
+              s"""
+                 |CREATE TABLE T (a INT, b INT, c STRING)
+                 |TBLPROPERTIES ($primaryKeysProp 
'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+                 |""".stripMargin)
+
+            spark.sql("INSERT INTO T values (1, 1, '1'), (2, 2, '2')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a, b"),
+              Row(1, 1, "1") :: Row(2, 2, "2") :: Nil)
+
+            spark.sql("INSERT OVERWRITE T VALUES (1, 3, '3'), (2, 4, '4')");
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a, b"),
+              Row(1, 3, "3") :: Row(2, 4, "4") :: Nil)
+          }
+      }
+  }
+
+  Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+    writeMode =>
+      bucketModes.foreach {
+        bucket =>
+          test(
+            s"insert overwrite single-partitioned table: write-mode: 
$writeMode, bucket: $bucket") {
+            val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+              "'primary-key'='a,b',"
+            } else {
+              ""
+            }
+
+            spark.sql(
+              s"""
+                 |CREATE TABLE T (a INT, b INT, c STRING)
+                 |TBLPROPERTIES ($primaryKeysProp 
'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+                 |PARTITIONED BY (a)
+                 |""".stripMargin)
+
+            spark.sql("INSERT INTO T values (1, 1, '1'), (2, 2, '2')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a, b"),
+              Row(1, 1, "1") :: Row(2, 2, "2") :: Nil)
+
+            // overwrite the whole table
+            spark.sql("INSERT OVERWRITE T VALUES (1, 3, '3'), (2, 4, '4')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a, b"),
+              Row(1, 3, "3") :: Row(2, 4, "4") :: Nil)
+
+            // overwrite the a=1 partition
+            spark.sql("INSERT OVERWRITE T PARTITION (a = 1) VALUES (5, '5'), 
(7, '7')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a, b"),
+              Row(1, 5, "5") :: Row(1, 7, "7") :: Row(2, 4, "4") :: Nil)
+
+          }
+      }
+  }
+
+  Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+    writeMode =>
+      bucketModes.foreach {
+        bucket =>
+          test(
+            s"insert overwrite mutil-partitioned table: write-mode: 
$writeMode, bucket: $bucket") {
+            val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+              "'primary-key'='a,pt1,pt2',"
+            } else {
+              ""
+            }
+
+            spark.sql(
+              s"""
+                 |CREATE TABLE T (a INT, b STRING, pt1 STRING, pt2 INT)
+                 |TBLPROPERTIES ($primaryKeysProp 
'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+                 |PARTITIONED BY (pt1, pt2)
+                 |""".stripMargin)
+
+            spark.sql(
+              "INSERT INTO T values (1, 'a', 'ptv1', 11), (2, 'b', 'ptv1', 
11), (3, 'c', 'ptv1', 22), (4, 'd', 'ptv2', 22)")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a"),
+              Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, 
"c", "ptv1", 22) :: Row(
+                4,
+                "d",
+                "ptv2",
+                22) :: Nil)
+
+            // overwrite the pt2=22 partition
+            spark.sql(
+              "INSERT OVERWRITE T PARTITION (pt2 = 22) VALUES (3, 'c2', 
'ptv1'), (4, 'd2', 'ptv3')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a"),
+              Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, 
"c2", "ptv1", 22) :: Row(
+                4,
+                "d2",
+                "ptv3",
+                22) :: Nil)
+
+            // overwrite the pt1=ptv3 partition
+            spark.sql("INSERT OVERWRITE T PARTITION (pt1 = 'ptv3') VALUES (4, 
'd3', 22)")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a"),
+              Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, 
"c2", "ptv1", 22) :: Row(
+                4,
+                "d3",
+                "ptv3",
+                22) :: Nil)
+
+            // overwrite the pt1=ptv1, pt2=11 partition
+            spark.sql("INSERT OVERWRITE T PARTITION (pt1 = 'ptv1', pt2=11) 
VALUES (5, 'e')")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a"),
+              Row(3, "c2", "ptv1", 22) :: Row(4, "d3", "ptv3", 22) :: Row(
+                5,
+                "e",
+                "ptv1",
+                11) :: Nil)
+
+            // overwrite the whole table
+            spark.sql(
+              "INSERT OVERWRITE T VALUES (1, 'a5', 'ptv1', 11), (3, 'c5', 
'ptv1', 22), (4, 'd5', 'ptv3', 22)")
+            checkAnswer(
+              spark.sql("SELECT * FROM T ORDER BY a"),
+              Row(1, "a5", "ptv1", 11) :: Row(3, "c5", "ptv1", 22) :: Row(
+                4,
+                "d5",
+                "ptv3",
+                22) :: Nil)
+          }
+      }
+  }
+
+  // Date/timestamp/boolean partition field types are not supported yet, so they are not tested here.
+  Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType).foreach {
+    dataType =>
+      test(s"insert overwrite table using $dataType as the partition field 
type") {
+        case class PartitionSQLAndValue(sql: Any, value: Any)
+
+        val (ptField, sv1, sv2) = dataType match {
+          case IntegerType =>
+            ("INT", PartitionSQLAndValue(1, 1), PartitionSQLAndValue(2, 2))
+          case LongType =>
+            ("LONG", PartitionSQLAndValue(1L, 1L), PartitionSQLAndValue(2L, 
2L))
+          case FloatType =>
+            ("FLOAT", PartitionSQLAndValue(12.3f, 12.3f), 
PartitionSQLAndValue(45.6f, 45.6f))
+          case DoubleType =>
+            ("DOUBLE", PartitionSQLAndValue(12.3d, 12.3), 
PartitionSQLAndValue(45.6d, 45.6))
+          case DecimalType =>
+            (
+              "DECIMAL(5, 2)",
+              PartitionSQLAndValue(11.222, 11.22),
+              PartitionSQLAndValue(66.777, 66.78))
+        }
+
+        spark.sql(s"""
+                     |CREATE TABLE T (a INT, b STRING, pt $ptField)
+                     |PARTITIONED BY (pt)
+                     |""".stripMargin)
+
+        spark.sql(s"INSERT INTO T SELECT 1, 'a', ${sv1.sql} UNION ALL SELECT 
2, 'b', ${sv2.sql}")
+        checkAnswer(
+          spark.sql("SELECT * FROM T ORDER BY a"),
+          Row(1, "a", sv1.value) :: Row(2, "b", sv2.value) :: Nil)
+
+        // overwrite the whole table
+        spark.sql(
+          s"INSERT OVERWRITE T SELECT 3, 'c', ${sv1.sql} UNION ALL SELECT 4, 
'd', ${sv2.sql}")
+        checkAnswer(
+          spark.sql("SELECT * FROM T ORDER BY a"),
+          Row(3, "c", sv1.value) :: Row(4, "d", sv2.value) :: Nil)
+
+        // overwrite the a=1 partition
+        spark.sql(s"INSERT OVERWRITE T PARTITION (pt = ${sv1.value}) VALUES 
(5, 'e'), (7, 'g')")
+        checkAnswer(
+          spark.sql("SELECT * FROM T ORDER BY a"),
+          Row(4, "d", sv2.value) :: Row(5, "e", sv1.value) :: Row(7, "g", 
sv1.value) :: Nil)
+      }
+  }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala
new file mode 100644
index 000000000..bc2c0f548
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.SparkCatalog
+
+import org.apache.spark.paimon.Utils
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.SharedSparkSession
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import java.io.File
+
+class PaimonSparkTestBase extends QueryTest with SharedSparkSession {
+
+  protected var tempDBDir: File = _
+
+  protected val tableName0: String = "T"
+
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+
+    tempDBDir = Utils.createTempDir
+    spark.conf.set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
+    spark.conf.set("spark.sql.catalog.paimon.warehouse", 
tempDBDir.getCanonicalPath)
+    spark.sql("CREATE DATABASE paimon.db")
+    spark.sql("USE paimon.db")
+    println(s"${tempDBDir.getCanonicalPath}")
+  }
+
+  override protected def afterAll(): Unit = {
+    try {
+      spark.sql("USE default")
+      spark.sql("DROP DATABASE paimon.db CASCADE")
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach() // per-test hook must chain to beforeEach, not beforeAll
+    spark.sql(s"DROP TABLE IF EXISTS $tableName0")
+  }
+
+  override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+      pos: Position): Unit = {
+    println(testName)
+    super.test(testName, testTags: _*)(testFun)(pos)
+  }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
similarity index 57%
rename from 
paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
rename to 
paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
index 876a6cf61..974bbf0c7 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- * http://www.apache.org/licenses/LICENSE-2.0
+ *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,29 +15,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.paimon
 
-package org.apache.paimon.spark;
+import org.apache.spark.util.{Utils => SparkUtils}
 
-import org.apache.paimon.table.FileStoreTable;
-
-import org.apache.spark.sql.connector.write.Write;
-import org.apache.spark.sql.connector.write.WriteBuilder;
+import java.io.File
 
 /**
- * Spark {@link WriteBuilder}.
- *
- * <p>TODO: Support overwrite.
+ * A wrapper exposing Spark objects and classes that are only accessible from within the [[org.apache.spark]] package.
  */
-public class SparkWriteBuilder implements WriteBuilder {
-
-    private final FileStoreTable table;
+object Utils {
 
-    public SparkWriteBuilder(FileStoreTable table) {
-        this.table = table;
-    }
+  def createTempDir: File = SparkUtils.createTempDir()
 
-    @Override
-    public Write build() {
-        return new SparkWrite(table);
-    }
 }

Reply via email to