Repository: spark
Updated Branches:
  refs/heads/master 9ad77b303 -> 2333a34d3


[SPARK-22880][SQL] Add cascadeTruncate option to JDBC datasource

This commit adds the `cascadeTruncate` option to the JDBC datasource
API, for databases that support this functionality (PostgreSQL and
Oracle at the moment). This allows for applying a cascading truncate
that affects tables that have foreign key constraints on the table
being truncated.

## What changes were proposed in this pull request?

Add `cascadeTruncate` option to JDBC datasource API. Allow this to affect the
`TRUNCATE` query for databases that support this option.

## How was this patch tested?
Existing tests for `truncateQuery` were updated. Also, an additional test was
added to ensure that the correct syntax was applied, and that enabling the
config for databases that do not support this option does not result in
invalid queries.

Author: Daniel van der Ende <daniel.vandere...@gmail.com>

Closes #20057 from danielvdende/SPARK-22880.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2333a34d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2333a34d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2333a34d

Branch: refs/heads/master
Commit: 2333a34d390f2fa19b939b8007be0deb31f31d3c
Parents: 9ad77b3
Author: Daniel van der Ende <daniel.vandere...@gmail.com>
Authored: Fri Jul 20 13:03:57 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Fri Jul 20 13:03:57 2018 -0700

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   |  7 +++
 .../datasources/jdbc/JDBCOptions.scala          |  3 ++
 .../execution/datasources/jdbc/JdbcUtils.scala  |  7 ++-
 .../spark/sql/jdbc/AggregatedDialect.scala      | 13 +++++-
 .../apache/spark/sql/jdbc/DerbyDialect.scala    |  2 +
 .../apache/spark/sql/jdbc/JdbcDialects.scala    | 20 ++++++++-
 .../apache/spark/sql/jdbc/OracleDialect.scala   | 16 +++++++
 .../apache/spark/sql/jdbc/PostgresDialect.scala | 29 ++++++++----
 .../apache/spark/sql/jdbc/TeradataDialect.scala | 18 ++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   | 46 +++++++++++++++++---
 10 files changed, 140 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ad23dae..4bab58a 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1407,6 +1407,13 @@ the following case-insensitive options:
      This is a JDBC writer related option. When 
<code>SaveMode.Overwrite</code> is enabled, this option causes Spark to 
truncate an existing table instead of dropping and recreating it. This can be 
more efficient, and prevents the table metadata (e.g., indices) from being 
removed. However, it will not work in some cases, such as when the new data has 
a different schema. It defaults to <code>false</code>. This option applies only 
to writing.
    </td>
   </tr>
+  
+  <tr>
+    <td><code>cascadeTruncate</code></td>
+    <td>
+        This is a JDBC writer related option. If enabled and supported by the 
JDBC database (PostgreSQL and Oracle at the moment), this option allows 
execution of a <code>TRUNCATE TABLE t CASCADE</code> (in the case of PostgreSQL 
a <code>TRUNCATE TABLE ONLY t CASCADE</code> is executed to prevent 
inadvertently truncating descendant tables). This will affect other tables, and 
thus should be used with care. This option applies only to writing. It defaults 
to the default cascading truncate behaviour of the JDBC database in question, 
specified in the <code>isCascadingTruncateTable</code> method of each JdbcDialect.
+    </td>
+  </tr>
 
   <tr>
     <td><code>createTableOptions</code></td>

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index eea966d..574aed4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -157,6 +157,8 @@ class JDBCOptions(
   // ------------------------------------------------------------
   // if to truncate the table from the JDBC database
   val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
+
+  val isCascadeTruncate: Option[Boolean] = 
parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean)
   // the create table option , which can be table_options or partition_options.
   // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
   // TODO: to reuse the existing partition parameters for those partition 
specific options
@@ -225,6 +227,7 @@ object JDBCOptions {
   val JDBC_QUERY_TIMEOUT = newOption("queryTimeout")
   val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
   val JDBC_TRUNCATE = newOption("truncate")
+  val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate")
   val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
   val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
   val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 6cc7922..edea549 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -105,7 +105,12 @@ object JdbcUtils extends Logging {
     val statement = conn.createStatement
     try {
       statement.setQueryTimeout(options.queryTimeout)
-      statement.executeUpdate(dialect.getTruncateQuery(options.table))
+      val truncateQuery = if (options.isCascadeTruncate.isDefined) {
+        dialect.getTruncateQuery(options.table, options.isCascadeTruncate)
+      } else {
+        dialect.getTruncateQuery(options.table)
+      }
+      statement.executeUpdate(truncateQuery)
     } finally {
       statement.close()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
index 8b92c8b..3a3246a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala
@@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) 
extends JdbcDialect
     }
   }
 
-  override def getTruncateQuery(table: String): String = {
-    dialects.head.getTruncateQuery(table)
+  /**
+   * The SQL query used to truncate a table.
+   * @param table The table to truncate.
+   * @param cascade Whether or not to cascade the truncation. Default value is 
the
+   *                value of isCascadingTruncateTable()
+   * @return The SQL query to use for truncating a table
+   */
+  override def getTruncateQuery(
+      table: String,
+      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
+    dialects.head.getTruncateQuery(table, cascade)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 84f68e7..d13c29e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect {
       Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL))
     case _ => None
   }
+
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 83d87a1..f76c1fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp}
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types._
 
 /**
@@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable {
    * The SQL query that should be used to truncate a table. Dialects can 
override this method to
    * return a query that is suitable for a particular database. For 
PostgreSQL, for instance,
    * a different query is used to prevent "TRUNCATE" affecting other tables.
-   * @param table The name of the table.
+   * @param table The table to truncate
    * @return The SQL query to use for truncating a table
    */
   @Since("2.3.0")
   def getTruncateQuery(table: String): String = {
-    s"TRUNCATE TABLE $table"
+    getTruncateQuery(table, isCascadingTruncateTable)
+  }
+
+  /**
+   * The SQL query that should be used to truncate a table. Dialects can 
override this method to
+   * return a query that is suitable for a particular database. For 
PostgreSQL, for instance,
+   * a different query is used to prevent "TRUNCATE" affecting other tables.
+   * @param table The table to truncate
+   * @param cascade Whether or not to cascade the truncation
+   * @return The SQL query to use for truncating a table
+   */
+  @Since("2.4.0")
+  def getTruncateQuery(
+    table: String,
+    cascade: Option[Boolean] = isCascadingTruncateTable): String = {
+      s"TRUNCATE TABLE $table"
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index 6ef77f2..f4a6d0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect {
   }
 
   override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+
+  /**
+   * The SQL query used to truncate a table.
+   * @param table The table to truncate
+   * @param cascade Whether or not to cascade the truncation. Default value is 
the
+   *                value of isCascadingTruncateTable()
+   * @return The SQL query to use for truncating a table
+   */
+  override def getTruncateQuery(
+      table: String,
+      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
+    cascade match {
+      case Some(true) => s"TRUNCATE TABLE $table CASCADE"
+      case _ => s"TRUNCATE TABLE $table"
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 13a2035..f8d2bc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+
   /**
-  * The SQL query used to truncate a table. For Postgres, the default 
behaviour is to
-  * also truncate any descendant tables. As this is a (possibly unwanted) 
side-effect,
-  * the Postgres dialect adds 'ONLY' to truncate only the table in question
-  * @param table The name of the table.
-  * @return The SQL query to use for truncating a table
-  */
-  override def getTruncateQuery(table: String): String = {
-    s"TRUNCATE TABLE ONLY $table"
+   * The SQL query used to truncate a table. For Postgres, the default 
behaviour is to
+   * also truncate any descendant tables. As this is a (possibly unwanted) 
side-effect,
+   * the Postgres dialect adds 'ONLY' to truncate only the table in question
+   * @param table The table to truncate
+   * @param cascade Whether or not to cascade the truncation. Default value is 
the value of
+   *                isCascadingTruncateTable(). Cascading a truncation will 
truncate tables
+   *                with a foreign key relationship to the target table. 
However, it will not
+   *                truncate tables with an inheritance relationship to the 
target table, as
+   *                the truncate query always includes "ONLY" to prevent this 
behaviour.
+   * @return The SQL query to use for truncating a table
+   */
+  override def getTruncateQuery(
+      table: String,
+      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
+    cascade match {
+      case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE"
+      case _ => s"TRUNCATE TABLE ONLY $table"
+    }
   }
 
   override def beforeFetch(connection: Connection, properties: Map[String, 
String]): Unit = {
@@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect {
     }
   }
 
-  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 5749b79..6c17bd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect {
     case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
     case _ => None
   }
+
+  // Teradata does not support cascading a truncation
+  override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
+
+  /**
+   * The SQL query used to truncate a table. Teradata does not support the 
'TRUNCATE' syntax that
+   * other dialects use. Instead, we need to use a 'DELETE FROM' statement.
+   * @param table The table to truncate.
+   * @param cascade Whether or not to cascade the truncation. Default value is 
the
+   *                value of isCascadingTruncateTable(). Teradata does not 
support cascading a
+   *                'DELETE FROM' statement (and as mentioned, does not 
support 'TRUNCATE' syntax)
+   * @return The SQL query to use for truncating a table
+   */
+  override def getTruncateQuery(
+      table: String,
+      cascade: Option[Boolean] = isCascadingTruncateTable): String = {
+    s"DELETE FROM $table ALL"
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2333a34d/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0389273..09facb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -861,19 +861,51 @@ class JDBCSuite extends QueryTest
   }
 
   test("truncate table query by jdbc dialect") {
-    val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
-    val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
     val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
     val h2 = JdbcDialects.get(url)
     val derby = JdbcDialects.get("jdbc:derby:db")
+    val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
+
     val table = "weblogs"
     val defaultQuery = s"TRUNCATE TABLE $table"
     val postgresQuery = s"TRUNCATE TABLE ONLY $table"
-    assert(MySQL.getTruncateQuery(table) == defaultQuery)
-    assert(Postgres.getTruncateQuery(table) == postgresQuery)
-    assert(db2.getTruncateQuery(table) == defaultQuery)
-    assert(h2.getTruncateQuery(table) == defaultQuery)
-    assert(derby.getTruncateQuery(table) == defaultQuery)
+    val teradataQuery = s"DELETE FROM $table ALL"
+
+    Seq(mysql, db2, h2, derby).foreach{ dialect =>
+      assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery)
+    }
+
+    assert(postgres.getTruncateQuery(table) == postgresQuery)
+    assert(oracle.getTruncateQuery(table) == defaultQuery)
+    assert(teradata.getTruncateQuery(table) == teradataQuery)
+  }
+
+  test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") {
+    // cascade in a truncate should only be applied for databases that support 
this,
+    // even if the parameter is passed.
+    val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+    val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+    val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
+    val h2 = JdbcDialects.get(url)
+    val derby = JdbcDialects.get("jdbc:derby:db")
+    val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
+
+    val table = "weblogs"
+    val defaultQuery = s"TRUNCATE TABLE $table"
+    val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE"
+    val oracleQuery = s"TRUNCATE TABLE $table CASCADE"
+    val teradataQuery = s"DELETE FROM $table ALL"
+
+    Seq(mysql, db2, h2, derby).foreach{ dialect =>
+      assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery)
+    }
+    assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery)
+    assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery)
+    assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery)
   }
 
   test("Test DataFrame.where for Date and Timestamp") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to