This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 0b7332cab [spark] SQL: supports INSERT OVERWRITE syntax (#1459)
0b7332cab is described below
commit 0b7332cab22aa4a4fb39fa0992973f1b7db8c762
Author: Yann Byron <[email protected]>
AuthorDate: Mon Jul 3 19:20:55 2023 +0800
[spark] SQL: supports INSERT OVERWRITE syntax (#1459)
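For context, the SQL shapes this commit enables are sketched below (a minimal illustration mirroring the new InsertOverwriteTest; the table and values are made up for this note):

    // Assumes a Spark session already configured with the Paimon catalog.
    spark.sql("CREATE TABLE T (a INT, b INT, c STRING) PARTITIONED BY (a)")
    spark.sql("INSERT INTO T VALUES (1, 1, '1'), (2, 2, '2')")
    // Overwrite the whole table.
    spark.sql("INSERT OVERWRITE T VALUES (1, 3, '3'), (2, 4, '4')")
    // Overwrite only the a=1 partition; VALUES supplies the remaining columns.
    spark.sql("INSERT OVERWRITE T PARTITION (a = 1) VALUES (5, '5'), (7, '7')")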
---
paimon-spark/paimon-spark-common/pom.xml | 74 +++++++
.../apache/paimon/spark/SparkFilterConverter.java | 6 +
.../java/org/apache/paimon/spark/SparkTable.java | 1 +
.../org/apache/paimon/spark/SaveMode.scala} | 29 +--
.../scala/org/apache/paimon/spark/SparkWrite.scala | 5 +-
.../{SparkWrite.scala => SparkWriteBuilder.scala} | 28 +--
.../paimon/spark/commands/PaimonCommand.scala | 56 +++++-
.../spark/commands/WriteIntoPaimonTable.scala | 46 ++++-
.../paimon/spark/sql/InsertOverwriteTest.scala | 221 +++++++++++++++++++++
.../paimon/spark/sql/PaimonSparkTestBase.scala | 66 ++++++
.../scala/org/apache/spark/paimon/Utils.scala} | 26 +--
11 files changed, 490 insertions(+), 68 deletions(-)
diff --git a/paimon-spark/paimon-spark-common/pom.xml b/paimon-spark/paimon-spark-common/pom.xml
index f7d8a5223..c3c647086 100644
--- a/paimon-spark/paimon-spark-common/pom.xml
+++ b/paimon-spark/paimon-spark-common/pom.xml
@@ -99,6 +99,80 @@ under the License.
</exclusion>
</exclusions>
</dependency>
+
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_2.12</artifactId>
+ <version>${spark.version}</version>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.orc</groupId>
+ <artifactId>orc-core</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_2.12</artifactId>
+ <version>${spark.version}</version>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.logging.log4j</groupId>
+ <artifactId>log4j-slf4j2-impl</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.orc</groupId>
+ <artifactId>orc-core</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_2.12</artifactId>
+ <version>${spark.version}</version>
+ <classifier>tests</classifier>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.orc</groupId>
+ <artifactId>orc-core</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <version>3.1.0</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
index c6bab357f..4f7cee52c 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java
@@ -129,6 +129,12 @@ public class SparkFilterConverter {
filter + " is unsupported. Support Filters: " +
SUPPORT_FILTERS);
}
+ public Object convertLiteral(String field, Object value) {
+ int index = fieldIndex(field);
+ DataType type = rowType.getTypeAt(index);
+ return convertJavaObject(type, value);
+ }
+
private int fieldIndex(String field) {
int index = rowType.getFieldIndex(field);
// TODO: support nested field
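A hedged sketch of how the new convertLiteral helper is intended to be used (the row type and values here are illustrative, not from this patch):

    // partitionRowType: the Paimon RowType describing the table's partition keys.
    val converter = new SparkFilterConverter(partitionRowType)
    // Resolves the field's Paimon DataType by index, then converts the
    // Spark-side literal into Paimon's internal representation.
    val paimonValue = converter.convertLiteral("pt", Integer.valueOf(11))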
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
index 3b7939994..180e1fcec 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTable.java
@@ -79,6 +79,7 @@ public class SparkTable
Set<TableCapability> capabilities = new HashSet<>();
capabilities.add(TableCapability.BATCH_READ);
capabilities.add(TableCapability.V1_BATCH_WRITE);
+ capabilities.add(TableCapability.OVERWRITE_BY_FILTER);
return capabilities;
}
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
similarity index 56%
copy from paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
copy to paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
index 876a6cf61..b4230d00f 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SaveMode.scala
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,29 +15,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.paimon.spark
-package org.apache.paimon.spark;
+import org.apache.spark.sql.sources.Filter
-import org.apache.paimon.table.FileStoreTable;
+sealed trait SaveMode extends Serializable
-import org.apache.spark.sql.connector.write.Write;
-import org.apache.spark.sql.connector.write.WriteBuilder;
+object InsertInto extends SaveMode
-/**
- * Spark {@link WriteBuilder}.
- *
- * <p>TODO: Support overwrite.
- */
-public class SparkWriteBuilder implements WriteBuilder {
-
- private final FileStoreTable table;
-
- public SparkWriteBuilder(FileStoreTable table) {
- this.table = table;
- }
+case class Overwrite(filters: Option[Filter]) extends SaveMode
- @Override
- public Write build() {
- return new SparkWrite(table);
- }
-}
+object DynamicOverWrite extends SaveMode
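The sealed trait above gives the write path an exhaustive set of modes to match on. A minimal sketch of a consumer (mirroring parseSaveMode in WriteIntoPaimonTable further down):

    def describe(mode: SaveMode): String = mode match {
      case InsertInto         => "plain append"
      case Overwrite(None)    => "overwrite the whole table"
      case Overwrite(Some(f)) => s"overwrite the partitions matching $f"
      case DynamicOverWrite   => "dynamic partition overwrite (not yet supported)"
    }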
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
index 61962bb95..44b017414 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
@@ -17,7 +17,6 @@
*/
package org.apache.paimon.spark
-import org.apache.paimon.operation.Lock.Factory
import org.apache.paimon.spark.commands.WriteIntoPaimonTable
import org.apache.paimon.table.FileStoreTable
@@ -26,12 +25,12 @@ import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.sources.InsertableRelation
/** Spark {@link V1Write}, it is required to use v1 write for grouping by bucket. */
-class SparkWrite(val table: FileStoreTable) extends V1Write {
+class SparkWrite(val table: FileStoreTable, saveMode: SaveMode) extends V1Write {
override def toInsertableRelation: InsertableRelation = {
(data: DataFrame, overwrite: Boolean) =>
{
- WriteIntoPaimonTable(table, overwrite, data).run(data.sparkSession)
+ WriteIntoPaimonTable(table, saveMode, data).run(data.sparkSession)
}
}
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
similarity index 59%
copy from paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
copy to paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
index 61962bb95..5cf429af1 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWrite.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkWriteBuilder.scala
@@ -17,21 +17,25 @@
*/
package org.apache.paimon.spark
-import org.apache.paimon.operation.Lock.Factory
-import org.apache.paimon.spark.commands.WriteIntoPaimonTable
import org.apache.paimon.table.FileStoreTable
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.connector.write.V1Write
-import org.apache.spark.sql.sources.InsertableRelation
+import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}
+import org.apache.spark.sql.sources.{And, Filter}
-/** Spark {@link V1Write}, it is required to use v1 write for grouping by bucket. */
-class SparkWrite(val table: FileStoreTable) extends V1Write {
+private class SparkWriteBuilder(table: FileStoreTable) extends WriteBuilder with SupportsOverwrite {
- override def toInsertableRelation: InsertableRelation = {
- (data: DataFrame, overwrite: Boolean) =>
- {
- WriteIntoPaimonTable(table, overwrite, data).run(data.sparkSession)
- }
+ private var saveMode: SaveMode = InsertInto
+
+ override def build = new SparkWrite(table, saveMode)
+
+ override def overwrite(filters: Array[Filter]): WriteBuilder = {
+ val conjunctiveFilters = if (filters.nonEmpty) {
+ Some(filters.reduce((l, r) => And(l, r)))
+ } else {
+ None
+ }
+ this.saveMode = Overwrite(conjunctiveFilters)
+ this
}
+
}
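To illustrate the conjunctive reduction in overwrite(...): Spark hands the builder an array of partition predicates, which are folded into a single And before being recorded as the Overwrite mode (the filter values here are illustrative):

    import org.apache.spark.sql.sources.{And, EqualNullSafe, Filter}

    val filters: Array[Filter] = Array(EqualNullSafe("pt1", "ptv1"), EqualNullSafe("pt2", 11))
    val conjunctive: Option[Filter] =
      if (filters.nonEmpty) Some(filters.reduce((l, r) => And(l, r))) else None
    // conjunctive == Some(And(EqualNullSafe("pt1", "ptv1"), EqualNullSafe("pt2", 11)))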
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
index a22e3d559..ea951bb29 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/PaimonCommand.scala
@@ -17,21 +17,27 @@
*/
package org.apache.paimon.spark.commands
+import org.apache.paimon.predicate.PredicateBuilder
+import org.apache.paimon.spark.SparkFilterConverter
import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
import org.apache.paimon.table.sink.{CommitMessage, CommitMessageSerializer}
+import org.apache.paimon.types.RowType
+
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualNullSafe, Filter, Not, Or}
import java.io.IOException
/** Helper trait for all paimon commands. */
trait PaimonCommand {
- val table: Table
-
val BUCKET_COL = "_bucket_"
+ def getTable: Table
+
def isDynamicBucketTable: Boolean = {
- table.isInstanceOf[FileStoreTable] &&
- table.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
+ getTable.isInstanceOf[FileStoreTable] &&
+ getTable.asInstanceOf[FileStoreTable].bucketMode == BucketMode.DYNAMIC
}
def deserializeCommitMessage(
@@ -44,4 +50,46 @@ trait PaimonCommand {
throw new RuntimeException("Failed to deserialize CommitMessage's object", e)
}
}
+
+ /**
+   * For the 'INSERT OVERWRITE' semantics of SQL, Spark DataSourceV2 will call the
+   * `truncate` method, where the `AlwaysTrue` filter is used.
+ */
+ def isTruncate(filter: Filter): Boolean = {
+ val filters = splitConjunctiveFilters(filter)
+ filters.length == 1 && filters.head.isInstanceOf[AlwaysTrue]
+ }
+
+ /**
+   * For the 'INSERT OVERWRITE T PARTITION (partitionVal, ...)' semantics of SQL, Spark will
+ * transform `partitionVal`s to EqualNullSafe Filters.
+ */
+  def convertFilterToMap(filter: Filter, partitionRowType: RowType): Map[String, String] = {
+ val converter = new SparkFilterConverter(partitionRowType)
+ splitConjunctiveFilters(filter).map {
+ case EqualNullSafe(attribute, value) =>
+ if (isNestedFilterInValue(value)) {
+ throw new RuntimeException(
+              s"Complex partition values in EqualNullSafe are not supported when running `INSERT OVERWRITE`.")
+ } else {
+ (attribute, converter.convertLiteral(attribute, value).toString)
+ }
+ case _ =>
+ throw new RuntimeException(
+            s"Only EqualNullSafe should be used when running `INSERT OVERWRITE`.")
+ }.toMap
+ }
+
+ def splitConjunctiveFilters(filter: Filter): Seq[Filter] = {
+ filter match {
+ case And(filter1, filter2) =>
+ splitConjunctiveFilters(filter1) ++ splitConjunctiveFilters(filter2)
+ case other => other :: Nil
+ }
+ }
+
+ def isNestedFilterInValue(value: Any): Boolean = {
+ value.isInstanceOf[Filter]
+ }
+
}
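A worked example of the helpers above, assuming Spark has translated `INSERT OVERWRITE T PARTITION (pt1 = 'ptv1', pt2 = 11)` into EqualNullSafe filters:

    import org.apache.spark.sql.sources.{And, EqualNullSafe}

    val filter = And(EqualNullSafe("pt1", "ptv1"), EqualNullSafe("pt2", 11))
    // splitConjunctiveFilters(filter)
    //   => Seq(EqualNullSafe("pt1", "ptv1"), EqualNullSafe("pt2", 11))
    // convertFilterToMap(filter, partitionRowType)
    //   => Map("pt1" -> "ptv1", "pt2" -> "11")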
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
index 06c9f0660..29f56776a 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/WriteIntoPaimonTable.scala
@@ -17,12 +17,14 @@
*/
package org.apache.paimon.spark.commands
+import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
import org.apache.paimon.data.BinaryRow
import org.apache.paimon.index.PartitionIndex
+import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode}
import org.apache.paimon.spark.SparkRow
import org.apache.paimon.spark.SparkUtils.createIOManager
-import org.apache.paimon.table.FileStoreTable
-import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, InnerTableCommit, RowPartitionKeyExtractor}
+import org.apache.paimon.table.{FileStoreTable, Table}
+import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
import org.apache.paimon.types.RowType
import org.apache.spark.TaskContext
@@ -38,12 +40,14 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
/** Used to write a [[DataFrame]] into a paimon table. */
-case class WriteIntoPaimonTable(table: FileStoreTable, overwrite: Boolean, data: DataFrame)
+case class WriteIntoPaimonTable(_table: FileStoreTable, saveMode: SaveMode, data: DataFrame)
extends RunnableCommand
with PaimonCommand {
import WriteIntoPaimonTable._
+ private var table = _table
+
private lazy val tableSchema = table.schema()
private lazy val rowType = table.rowType()
@@ -52,13 +56,14 @@ case class WriteIntoPaimonTable(table: FileStoreTable, overwrite: Boolean, data:
private lazy val serializer = new CommitMessageSerializer
- if (overwrite) {
- throw new UnsupportedOperationException("Overwrite is unsupported.");
- }
-
override def run(sparkSession: SparkSession): Seq[Row] = {
import sparkSession.implicits._
+ val (dynamicPartitionOverwriteMode, overwritePartition) = parseSaveMode()
+ // use the extra options to rebuild the table object
+ table = table.copy(
+      Map(DYNAMIC_PARTITION_OVERWRITE.key() -> dynamicPartitionOverwriteMode.toString).asJava)
+
val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
val partitionCols = tableSchema.partitionKeys().asScala.map(col)
@@ -113,7 +118,11 @@ case class WriteIntoPaimonTable(table: FileStoreTable, overwrite: Boolean, data:
.map(deserializeCommitMessage(serializer, _))
try {
- val tableCommit = writeBuilder.newCommit().asInstanceOf[InnerTableCommit]
+ val tableCommit = if (overwritePartition == null) {
+ writeBuilder.newCommit()
+ } else {
+ writeBuilder.withOverwrite(overwritePartition.asJava).newCommit()
+ }
tableCommit.commit(commitMessages.toList.asJava)
} catch {
case e: Throwable => throw new RuntimeException(e);
@@ -122,8 +131,29 @@ case class WriteIntoPaimonTable(table: FileStoreTable, overwrite: Boolean, data:
Seq.empty
}
+ private def parseSaveMode(): (Boolean, Map[String, String]) = {
+ var dynamicPartitionOverwriteMode = false
+ val overwritePartition = saveMode match {
+ case InsertInto => null
+ case Overwrite(filter) =>
+ if (filter.isEmpty) {
+ Map.empty[String, String]
+ } else if (isTruncate(filter.get)) {
+ Map.empty[String, String]
+ } else {
+ convertFilterToMap(filter.get, tableSchema.logicalPartitionType())
+ }
+ case DynamicOverWrite =>
+ dynamicPartitionOverwriteMode = true
+        throw new UnsupportedOperationException("Dynamic Overwrite is unsupported for now.")
+ }
+ (dynamicPartitionOverwriteMode, overwritePartition)
+ }
+
  override def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan =
this.asInstanceOf[WriteIntoPaimonTable]
+
+ override def getTable: Table = table
}
object WriteIntoPaimonTable {
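The net effect of parseSaveMode on the commit path, sketched under the same assumptions (null marks a plain append, an empty map truncates the table, a non-empty map replaces the matching partitions):

    import scala.collection.JavaConverters._

    // InsertInto                 -> overwritePartition == null -> plain newCommit()
    // Overwrite(None)/AlwaysTrue -> Map.empty                  -> withOverwrite: truncate
    // Overwrite(partition spec)  -> Map("pt1" -> "ptv1")       -> withOverwrite: replace partitions
    val overwritePartition: Map[String, String] = Map("pt1" -> "ptv1") // illustrative
    val tableCommit =
      if (overwritePartition == null) writeBuilder.newCommit()
      else writeBuilder.withOverwrite(overwritePartition.asJava).newCommit()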
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala
new file mode 100644
index 000000000..832a96d83
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/InsertOverwriteTest.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.WriteMode.{APPEND_ONLY, CHANGE_LOG}
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.scalactic.source.Position
+
+class InsertOverwriteTest extends PaimonSparkTestBase {
+
+ // 3: fixed bucket, -1: dynamic bucket
+ private val bucketModes = Seq(3, -1)
+
+ Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+ writeMode =>
+ bucketModes.foreach {
+ bucket =>
+          test(s"insert overwrite non-partitioned table: write-mode: $writeMode, bucket: $bucket") {
+ val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+ "'primary-key'='a,b',"
+ } else {
+ ""
+ }
+
+ spark.sql(
+ s"""
+ |CREATE TABLE T (a INT, b INT, c STRING)
+                         |TBLPROPERTIES ($primaryKeysProp 'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+ |""".stripMargin)
+
+ spark.sql("INSERT INTO T values (1, 1, '1'), (2, 2, '2')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a, b"),
+ Row(1, 1, "1") :: Row(2, 2, "2") :: Nil)
+
+ spark.sql("INSERT OVERWRITE T VALUES (1, 3, '3'), (2, 4, '4')");
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a, b"),
+ Row(1, 3, "3") :: Row(2, 4, "4") :: Nil)
+ }
+ }
+ }
+
+ Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+ writeMode =>
+ bucketModes.foreach {
+ bucket =>
+ test(
+            s"insert overwrite single-partitioned table: write-mode: $writeMode, bucket: $bucket") {
+ val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+ "'primary-key'='a,b',"
+ } else {
+ ""
+ }
+
+ spark.sql(
+ s"""
+ |CREATE TABLE T (a INT, b INT, c STRING)
+                         |TBLPROPERTIES ($primaryKeysProp 'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+ |PARTITIONED BY (a)
+ |""".stripMargin)
+
+ spark.sql("INSERT INTO T values (1, 1, '1'), (2, 2, '2')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a, b"),
+ Row(1, 1, "1") :: Row(2, 2, "2") :: Nil)
+
+ // overwrite the whole table
+ spark.sql("INSERT OVERWRITE T VALUES (1, 3, '3'), (2, 4, '4')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a, b"),
+ Row(1, 3, "3") :: Row(2, 4, "4") :: Nil)
+
+ // overwrite the a=1 partition
+            spark.sql("INSERT OVERWRITE T PARTITION (a = 1) VALUES (5, '5'), (7, '7')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a, b"),
+ Row(1, 5, "5") :: Row(1, 7, "7") :: Row(2, 4, "4") :: Nil)
+
+ }
+ }
+ }
+
+ Seq(APPEND_ONLY, CHANGE_LOG).foreach {
+ writeMode =>
+ bucketModes.foreach {
+ bucket =>
+ test(
+            s"insert overwrite multi-partitioned table: write-mode: $writeMode, bucket: $bucket") {
+ val primaryKeysProp = if (writeMode == CHANGE_LOG) {
+ "'primary-key'='a,pt1,pt2',"
+ } else {
+ ""
+ }
+
+ spark.sql(
+ s"""
+ |CREATE TABLE T (a INT, b STRING, pt1 STRING, pt2 INT)
+                         |TBLPROPERTIES ($primaryKeysProp 'write-mode'='${writeMode.toString}', 'bucket'='$bucket')
+ |PARTITIONED BY (pt1, pt2)
+ |""".stripMargin)
+
+ spark.sql(
+            "INSERT INTO T values (1, 'a', 'ptv1', 11), (2, 'b', 'ptv1', 11), (3, 'c', 'ptv1', 22), (4, 'd', 'ptv2', 22)")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+            Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, "c", "ptv1", 22) :: Row(
+ 4,
+ "d",
+ "ptv2",
+ 22) :: Nil)
+
+ // overwrite the pt2=22 partition
+ spark.sql(
+            "INSERT OVERWRITE T PARTITION (pt2 = 22) VALUES (3, 'c2', 'ptv1'), (4, 'd2', 'ptv3')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+            Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, "c2", "ptv1", 22) :: Row(
+ 4,
+ "d2",
+ "ptv3",
+ 22) :: Nil)
+
+ // overwrite the pt1=ptv3 partition
+          spark.sql("INSERT OVERWRITE T PARTITION (pt1 = 'ptv3') VALUES (4, 'd3', 22)")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+            Row(1, "a", "ptv1", 11) :: Row(2, "b", "ptv1", 11) :: Row(3, "c2", "ptv1", 22) :: Row(
+ 4,
+ "d3",
+ "ptv3",
+ 22) :: Nil)
+
+ // overwrite the pt1=ptv1, pt2=11 partition
+          spark.sql("INSERT OVERWRITE T PARTITION (pt1 = 'ptv1', pt2=11) VALUES (5, 'e')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+ Row(3, "c2", "ptv1", 22) :: Row(4, "d3", "ptv3", 22) :: Row(
+ 5,
+ "e",
+ "ptv1",
+ 11) :: Nil)
+
+ // overwrite the whole table
+ spark.sql(
+            "INSERT OVERWRITE T VALUES (1, 'a5', 'ptv1', 11), (3, 'c5', 'ptv1', 22), (4, 'd5', 'ptv3', 22)")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+ Row(1, "a5", "ptv1", 11) :: Row(3, "c5", "ptv1", 22) :: Row(
+ 4,
+ "d5",
+ "ptv3",
+ 22) :: Nil)
+ }
+ }
+ }
+
+  // Cases where date/timestamp/bool is used as the partition field type are yet to be supported.
+ Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType).foreach {
+ dataType =>
+      test(s"insert overwrite table using $dataType as the partition field type") {
+ case class PartitionSQLAndValue(sql: Any, value: Any)
+
+ val (ptField, sv1, sv2) = dataType match {
+ case IntegerType =>
+ ("INT", PartitionSQLAndValue(1, 1), PartitionSQLAndValue(2, 2))
+ case LongType =>
+            ("LONG", PartitionSQLAndValue(1L, 1L), PartitionSQLAndValue(2L, 2L))
+ case FloatType =>
+            ("FLOAT", PartitionSQLAndValue(12.3f, 12.3f), PartitionSQLAndValue(45.6f, 45.6f))
+ case DoubleType =>
+            ("DOUBLE", PartitionSQLAndValue(12.3d, 12.3), PartitionSQLAndValue(45.6d, 45.6))
+ case DecimalType =>
+ (
+ "DECIMAL(5, 2)",
+ PartitionSQLAndValue(11.222, 11.22),
+ PartitionSQLAndValue(66.777, 66.78))
+ }
+
+ spark.sql(s"""
+ |CREATE TABLE T (a INT, b STRING, pt $ptField)
+ |PARTITIONED BY (pt)
+ |""".stripMargin)
+
+        spark.sql(s"INSERT INTO T SELECT 1, 'a', ${sv1.sql} UNION ALL SELECT 2, 'b', ${sv2.sql}")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+ Row(1, "a", sv1.value) :: Row(2, "b", sv2.value) :: Nil)
+
+ // overwrite the whole table
+ spark.sql(
+          s"INSERT OVERWRITE T SELECT 3, 'c', ${sv1.sql} UNION ALL SELECT 4, 'd', ${sv2.sql}")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+ Row(3, "c", sv1.value) :: Row(4, "d", sv2.value) :: Nil)
+
+ // overwrite the a=1 partition
+        spark.sql(s"INSERT OVERWRITE T PARTITION (pt = ${sv1.value}) VALUES (5, 'e'), (7, 'g')")
+ checkAnswer(
+ spark.sql("SELECT * FROM T ORDER BY a"),
+          Row(4, "d", sv2.value) :: Row(5, "e", sv1.value) :: Row(7, "g", sv1.value) :: Nil)
+ }
+ }
+}
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala
new file mode 100644
index 000000000..bc2c0f548
--- /dev/null
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonSparkTestBase.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.SparkCatalog
+
+import org.apache.spark.paimon.Utils
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.SharedSparkSession
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
+import java.io.File
+
+class PaimonSparkTestBase extends QueryTest with SharedSparkSession {
+
+ protected var tempDBDir: File = _
+
+ protected val tableName0: String = "T"
+
+ override protected def beforeAll(): Unit = {
+ super.beforeAll()
+
+ tempDBDir = Utils.createTempDir
+ spark.conf.set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
+    spark.conf.set("spark.sql.catalog.paimon.warehouse", tempDBDir.getCanonicalPath)
+ spark.sql("CREATE DATABASE paimon.db")
+ spark.sql("USE paimon.db")
+ println(s"${tempDBDir.getCanonicalPath}")
+ }
+
+ override protected def afterAll(): Unit = {
+ try {
+ spark.sql("USE default")
+ spark.sql("DROP DATABASE paimon.db CASCADE")
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ override protected def beforeEach(): Unit = {
+    super.beforeEach()
+ spark.sql(s"DROP TABLE IF EXISTS $tableName0")
+ }
+
+  override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = {
+ println(testName)
+ super.test(testName, testTags: _*)(testFun)(pos)
+ }
+}
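A minimal sketch of a suite built on this base class (the suite name and data are illustrative); the base wires a temporary warehouse into the `paimon.db` database so tests can issue plain SQL:

    import org.apache.spark.sql.Row

    class MySuite extends PaimonSparkTestBase {
      test("insert and read back") {
        spark.sql("CREATE TABLE T (a INT, b STRING)")
        spark.sql("INSERT INTO T VALUES (1, 'x')")
        checkAnswer(spark.sql("SELECT * FROM T ORDER BY a"), Row(1, "x") :: Nil)
      }
    }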
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
similarity index 57%
rename from paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
rename to paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
index 876a6cf61..974bbf0c7 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkWriteBuilder.java
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/spark/paimon/Utils.scala
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,29 +15,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.paimon
-package org.apache.paimon.spark;
+import org.apache.spark.util.{Utils => SparkUtils}
-import org.apache.paimon.table.FileStoreTable;
-
-import org.apache.spark.sql.connector.write.Write;
-import org.apache.spark.sql.connector.write.WriteBuilder;
+import java.io.File
/**
- * Spark {@link WriteBuilder}.
- *
- * <p>TODO: Support overwrite.
+ * A wrapper exposing objects and classes whose access is restricted to the [[org.apache.spark]] package.
*/
-public class SparkWriteBuilder implements WriteBuilder {
-
- private final FileStoreTable table;
+object Utils {
- public SparkWriteBuilder(FileStoreTable table) {
- this.table = table;
- }
+ def createTempDir: File = SparkUtils.createTempDir()
- @Override
- public Write build() {
- return new SparkWrite(table);
- }
}