This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c43da68587 [SPARK-40956] SQL Equivalent for Dataframe overwrite 
command
5c43da68587 is described below

commit 5c43da6858721664318c3cdbcb051231b0e98175
Author: carlfu-db <[email protected]>
AuthorDate: Tue Nov 15 16:56:49 2022 +0800

    [SPARK-40956] SQL Equivalent for Dataframe overwrite command
    
    ### What changes were proposed in this pull request?
    
    This PR proposes adding the syntax
    
     INSERT INTO [TABLE] tbl REPLACE whereClause query
    to Spark SQL, as the SQL equivalent of the 
[dataframe.overwrite()](https://github.com/apache/spark/blob/35d00df9bba7238ad4f409999617fae4d04ddbfd/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala#L163)
 command.
    
    For example,
    
     INSERT INTO table1 REPLACE WHERE key = 3 SELECT * FROM table2
    will, in a single atomic operation, 1) delete the rows where key = 3 and 
2) insert the rows from table2.
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Added unit tests in 
[DataSourceV2SQLSuite.scala](https://github.com/apache/spark/pull/38404/commits/9429a6c3430e00f273406e755e27f2066e6d352c#diff-eeb429a8e3eb55228451c8dbc2fccca044836be608d62e9166561b005030c940)
    
    Closes #38404 from carlfu-db/replacewhere.
    
    Lead-authored-by: carlfu-db <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  1 +
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  6 +++
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 46 ++++++++++++++++++++++
 3 files changed, 53 insertions(+)

diff --git 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 7b673870af8..a3c5f4a7b07 100644
--- 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -319,6 +319,7 @@ query
 insertInto
     : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT 
EXISTS)?)?  identifierList?        #insertOverwriteTable
     | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? 
identifierList?                #insertIntoTable
+    | INSERT INTO TABLE? multipartIdentifier REPLACE whereClause               
                             #insertIntoReplaceWhere
     | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? 
createFileFormat?                            #insertOverwriteHiveDir
     | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider 
(OPTIONS options=propertyList)?        #insertOverwriteDir
     ;
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8edb1702028..af2097b5d0f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -261,6 +261,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] 
with SQLConfHelper wit
    * {{{
    *   INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? 
[identifierList]
    *   INSERT INTO [TABLE] tableIdentifier [partitionSpec]  [identifierList]
+   *   INSERT INTO [TABLE] tableIdentifier REPLACE whereClause
    *   INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
    *   INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS 
tablePropertyList]
    * }}}
@@ -288,6 +289,11 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] 
with SQLConfHelper wit
           query,
           overwrite = true,
           ifPartitionNotExists)
+      case ctx: InsertIntoReplaceWhereContext =>
+        OverwriteByExpression.byPosition(
+          createUnresolvedRelation(ctx.multipartIdentifier),
+          query,
+          expression(ctx.whereClause().booleanExpression()))
       case dir: InsertOverwriteDirContext =>
         val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
         InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 25faa34b697..de8612c3348 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2729,6 +2729,52 @@ class DataSourceV2SQLSuiteV1Filter extends 
DataSourceV2SQLSuite with AlterTableT
     }
   }
 
+  test("Overwrite: overwrite by expression: True") {
+    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, 
"c"))).toDF("id", "data")
+    df.createOrReplaceTempView("source")
+    val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, 
"f"))).toDF("id", "data")
+    df2.createOrReplaceTempView("source2")
+
+    val t = "testcat.tbl"
+    withTable(t) {
+      spark.sql(
+        s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY 
(id)")
+      spark.sql(s"INSERT INTO TABLE $t SELECT * FROM source")
+
+      checkAnswer(
+        spark.table(s"$t"),
+        Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+      spark.sql(s"INSERT INTO $t REPLACE WHERE TRUE SELECT * FROM source2")
+      checkAnswer(
+        spark.table(s"$t"),
+        Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+    }
+  }
+
+  test("Overwrite: overwrite by expression: id = 3") {
+    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, 
"c"))).toDF("id", "data")
+    df.createOrReplaceTempView("source")
+    val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, 
"f"))).toDF("id", "data")
+    df2.createOrReplaceTempView("source2")
+
+    val t = "testcat.tbl"
+    withTable(t) {
+      spark.sql(
+        s"CREATE TABLE $t (id bigint, data string) USING foo PARTITIONED BY 
(id)")
+      spark.sql(s"INSERT INTO TABLE $t SELECT * FROM source")
+
+      checkAnswer(
+        spark.table(s"$t"),
+        Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+      spark.sql(s"INSERT INTO $t REPLACE WHERE id = 3 SELECT * FROM source2")
+      checkAnswer(
+        spark.table(s"$t"),
+        Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, 
"f")))
+    }
+  }
+
   private def testNotSupportedV2Command(sqlCommand: String, sqlParams: 
String): Unit = {
     checkError(
       exception = intercept[AnalysisException] {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to