This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5c43da68587 [SPARK-40956] SQL Equivalent for Dataframe overwrite command
5c43da68587 is described below
commit 5c43da6858721664318c3cdbcb051231b0e98175
Author: carlfu-db <[email protected]>
AuthorDate: Tue Nov 15 16:56:49 2022 +0800
[SPARK-40956] SQL Equivalent for Dataframe overwrite command
### What changes were proposed in this pull request?
Proposing the syntax
    INSERT INTO tbl REPLACE whereClause
for Spark SQL, as the equivalent of the
[dataframe.overwrite()](https://github.com/apache/spark/blob/35d00df9bba7238ad4f409999617fae4d04ddbfd/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala#L163)
command.
For example,
    INSERT INTO table1 REPLACE WHERE key = 3 SELECT * FROM table2
will, in a single atomic operation, 1) delete all rows in table1 with key = 3 and
2) insert all rows from table2.
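For reference, here is a minimal sketch of the existing DataFrame-API form that the new
syntax mirrors (assuming a SparkSession named spark and a V2 target table that supports
overwrite-by-filter; table and column names are placeholders):

    import org.apache.spark.sql.functions.col

    // DataFrame-API equivalent of:
    //   INSERT INTO table1 REPLACE WHERE key = 3 SELECT * FROM table2
    // Atomically deletes the rows matching the condition and appends
    // the rows of the source DataFrame.
    spark.table("table2")
      .writeTo("table1")
      .overwrite(col("key") === 3)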
### Why are the changes needed?
The conditional-overwrite operation is currently only reachable through the DataFrameWriterV2
API; this change makes the same atomic replace-where write available from plain SQL.
### Does this PR introduce _any_ user-facing change?
Yes, it adds the new SQL syntax INSERT INTO ... REPLACE WHERE ... described above.
### How was this patch tested?
Added unit tests in
[DataSourceV2SQLSuite.scala](https://github.com/apache/spark/pull/38404/commits/9429a6c3430e00f273406e755e27f2066e6d352c#diff-eeb429a8e3eb55228451c8dbc2fccca044836be608d62e9166561b005030c940).
Closes #38404 from carlfu-db/replacewhere.
Lead-authored-by: carlfu-db <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 1 +
.../spark/sql/catalyst/parser/AstBuilder.scala | 6 +++
.../spark/sql/connector/DataSourceV2SQLSuite.scala | 46 ++++++++++++++++++++++
3 files changed, 53 insertions(+)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 7b673870af8..a3c5f4a7b07 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -319,6 +319,7 @@ query
 insertInto
     : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList?    #insertOverwriteTable
     | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList?           #insertIntoTable
+    | INSERT INTO TABLE? multipartIdentifier REPLACE whereClause                                       #insertIntoReplaceWhere
     | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat?                    #insertOverwriteHiveDir
     | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir
     ;
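As a rough illustration of the new #insertIntoReplaceWhere alternative, these statement
shapes should now parse (table names t, s and the catalog path cat.db are placeholders,
not taken from the patch):

    // INSERT INTO TABLE? multipartIdentifier REPLACE whereClause
    spark.sql("INSERT INTO t REPLACE WHERE key = 3 SELECT * FROM s")        // bare identifier
    spark.sql("INSERT INTO TABLE t REPLACE WHERE key = 3 SELECT * FROM s")  // optional TABLE keyword
    spark.sql("INSERT INTO cat.db.t REPLACE WHERE true SELECT * FROM s")    // multipart identifier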
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8edb1702028..af2097b5d0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -261,6 +261,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
    * {{{
    *   INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
    *   INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList]
+   *   INSERT INTO [TABLE] tableIdentifier REPLACE whereClause
    *   INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
    *   INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
    * }}}
@@ -288,6 +289,11 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
           query,
           overwrite = true,
           ifPartitionNotExists)
+      case ctx: InsertIntoReplaceWhereContext =>
+        OverwriteByExpression.byPosition(
+          createUnresolvedRelation(ctx.multipartIdentifier),
+          query,
+          expression(ctx.whereClause().booleanExpression()))
       case dir: InsertOverwriteDirContext =>
         val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
         InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
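The new case arm routes REPLACE WHERE into the same OverwriteByExpression logical node
that the DataFrameWriterV2 path builds. A minimal sketch of checking that mapping through
the parser (table names are placeholders; this assumes a SparkSession named spark):

    import org.apache.spark.sql.catalyst.plans.logical.OverwriteByExpression

    // Parse only; no table needs to exist at this point.
    val plan = spark.sessionState.sqlParser.parsePlan(
      "INSERT INTO t REPLACE WHERE key = 3 SELECT * FROM s")
    assert(plan.isInstanceOf[OverwriteByExpression])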
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 25faa34b697..de8612c3348 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2729,6 +2729,52 @@ class DataSourceV2SQLSuiteV1Filter extends DataSourceV2SQLSuite with AlterTableT
     }
   }
+ test("Overwrite: overwrite by expression: True") {
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L,
"c"))).toDF("id", "data")
+ df.createOrReplaceTempView("source")
+ val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L,
"f"))).toDF("id", "data")
+ df2.createOrReplaceTempView("source2")
+
+ val t = "testcat.tbl"
+ withTable(t) {
+ spark.sql(
+ s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY
(id)")
+ spark.sql(s"INSERT INTO TABLE $t SELECT * FROM source")
+
+ checkAnswer(
+ spark.table(s"$t"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.sql(s"INSERT INTO $t REPLACE WHERE TRUE SELECT * FROM source2")
+ checkAnswer(
+ spark.table(s"$t"),
+ Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+ }
+ }
+
+ test("Overwrite: overwrite by expression: id = 3") {
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L,
"c"))).toDF("id", "data")
+ df.createOrReplaceTempView("source")
+ val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L,
"f"))).toDF("id", "data")
+ df2.createOrReplaceTempView("source2")
+
+ val t = "testcat.tbl"
+ withTable(t) {
+ spark.sql(
+ s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY
(id)")
+ spark.sql(s"INSERT INTO TABLE $t SELECT * FROM source")
+
+ checkAnswer(
+ spark.table(s"$t"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.sql(s"INSERT INTO $t REPLACE WHERE id = 3 SELECT * FROM source2")
+ checkAnswer(
+ spark.table(s"$t"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L,
"f")))
+ }
+ }
+
private def testNotSupportedV2Command(sqlCommand: String, sqlParams:
String): Unit = {
checkError(
exception = intercept[AnalysisException] {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]