Repository: spark
Updated Branches:
  refs/heads/master f362363d1 -> 35d9c8aa6


[SPARK-14747][SQL] Add assertStreaming/assertNoneStreaming checks in DataFrameWriter

## Problem

If an end user happens to write code that mixes continuous-query-oriented methods with non-continuous-query-oriented methods:

```scala
ctx.read
   .format("text")
   .stream("...")  // continuous query
   .write
   .text("...")    // non-continuous query; should be startStream() here
```

he/she would get this rather confusing exception:

> Exception in thread "main" java.lang.AssertionError: assertion failed: No plan for FileSource[./continuous_query_test_input]
>         at scala.Predef$.assert(Predef.scala:170)
>         at org.apache.spark.sql.catalyst.planning.QueryPlanner.plan(QueryPlanner.scala:59)
>         at org.apache.spark.sql.catalyst.planning.QueryPlanner.planLater(QueryPlanner.scala:54)
>         at ...
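
For reference, here is a minimal sketch of what the user presumably intended, with placeholder paths; the `option("checkpointLocation", ...)` and `startStream(path)` calls follow the usage shown in the tests added below:

```scala
ctx.read
   .format("text")
   .stream("...")                        // continuous query (streaming source)
   .write
   .option("checkpointLocation", "...")  // checkpoint directory for the continuous query
   .startStream("...")                   // streaming sink: startStream() instead of text()
```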

## What changes were proposed in this pull request?

This PR adds checks for continuous-query-oriented methods and non-continuous-query-oriented methods in `DataFrameWriter` (a short illustration follows the table):

<table>
<tr>
        <td align="center"></td>
        <td align="center"><strong>can be called on continuous query?</strong></td>
        <td align="center"><strong>can be called on non-continuous query?</strong></td>
</tr>
<tr>
        <td align="center">mode</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">trigger</td>
        <td align="center">yes</td>
        <td align="center"></td>
</tr>
<tr>
        <td align="center">format</td>
        <td align="center">yes</td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">option/options</td>
        <td align="center">yes</td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">partitionBy</td>
        <td align="center">yes</td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">bucketBy</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">sortBy</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">save</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">queryName</td>
        <td align="center">yes</td>
        <td align="center"></td>
</tr>
<tr>
        <td align="center">startStream</td>
        <td align="center">yes</td>
        <td align="center"></td>
</tr>
<tr>
        <td align="center">insertInto</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">saveAsTable</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">jdbc</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">json</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">parquet</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">orc</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">text</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
<tr>
        <td align="center">csv</td>
        <td align="center"></td>
        <td align="center">yes</td>
</tr>
</table>
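
As a quick illustration of the checks above (a sketch only; the variable names are made up, and the quoted messages are the `AnalysisException` messages added by this patch), calling a method from the wrong column now fails immediately with a clear message:

```scala
// Continuous query: non-continuous-only methods are rejected up front.
val streamingDF = ctx.read.format("text").stream("...")
streamingDF.write.mode("append")
// throws AnalysisException: "mode() can only be called on non-continuous queries;"

// Non-continuous query: continuous-only methods are rejected up front.
val batchDF = ctx.read.text("...")
batchDF.write.queryName("my-query")
// throws AnalysisException: "queryName() can only be called on continuous queries;"
```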

After this PR's change, the user would get this friendlier exception:
> Exception in thread "main" org.apache.spark.sql.AnalysisException: text() can only be called on non-continuous queries;
>         at org.apache.spark.sql.DataFrameWriter.assertNotStreaming(DataFrameWriter.scala:678)
>         at org.apache.spark.sql.DataFrameWriter.text(DataFrameWriter.scala:629)
>         at ss.SSDemo$.main(SSDemo.scala:47)

## How was this patch tested?

Dedicated unit tests were added.

Author: Liwei Lin <lwl...@gmail.com>

Closes #12521 from lw-lin/dataframe-writer-check.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35d9c8aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35d9c8aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35d9c8aa

Branch: refs/heads/master
Commit: 35d9c8aa69c650f33037813607dc939922c5fc27
Parents: f362363
Author: Liwei Lin <lwl...@gmail.com>
Authored: Mon May 2 16:48:20 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon May 2 16:48:20 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameWriter.scala  |  59 ++++++-
 .../streaming/DataFrameReaderWriterSuite.scala  | 156 +++++++++++++++++++
 2 files changed, 210 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35d9c8aa/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index a57d47d..a8f96a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -53,6 +53,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def mode(saveMode: SaveMode): DataFrameWriter = {
+    // mode() is used for non-continuous queries
+    // outputMode() is used for continuous queries
+    assertNotStreaming("mode() can only be called on non-continuous queries")
     this.mode = saveMode
     this
   }
@@ -67,6 +70,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def mode(saveMode: String): DataFrameWriter = {
+    // mode() is used for non-continuous queries
+    // outputMode() is used for continuous queries
+    assertNotStreaming("mode() can only be called on non-continuous queries")
     this.mode = saveMode.toLowerCase match {
       case "overwrite" => SaveMode.Overwrite
       case "append" => SaveMode.Append
@@ -103,6 +109,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   @Experimental
   def trigger(trigger: Trigger): DataFrameWriter = {
+    assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
   }
@@ -236,6 +243,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   def save(): Unit = {
     assertNotBucketed()
+    assertNotStreaming("save() can only be called on non-continuous queries")
     val dataSource = DataSource(
       df.sparkSession,
       className = source,
@@ -253,6 +261,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def queryName(queryName: String): DataFrameWriter = {
+    assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
   }
@@ -276,6 +285,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def startStream(): ContinuousQuery = {
+    assertNotBucketed
+    assertStreaming("startStream() can only be called on continuous queries")
+
     if (source == "memory") {
       val queryName =
         extraOptions.getOrElse(
@@ -348,6 +360,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
 
   private def insertInto(tableIdent: TableIdentifier): Unit = {
     assertNotBucketed()
+    assertNotStreaming("insertInto() can only be called on non-continuous queries")
     val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap)
     val overwrite = mode == SaveMode.Overwrite
 
@@ -446,6 +459,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }
 
   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
+    assertNotStreaming("saveAsTable() can only be called on non-continuous queries")
+
     val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent)
 
     (tableExists, mode) match {
@@ -486,6 +501,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+    assertNotStreaming("jdbc() can only be called on non-continuous queries")
+
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
@@ -542,7 +559,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def json(path: String): Unit = format("json").save(path)
+  def json(path: String): Unit = {
+    assertNotStreaming("json() can only be called on non-continuous queries")
+    format("json").save(path)
+  }
 
   /**
    * Saves the content of the [[DataFrame]] in Parquet format at the specified path.
@@ -558,7 +578,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def parquet(path: String): Unit = format("parquet").save(path)
+  def parquet(path: String): Unit = {
+    assertNotStreaming("parquet() can only be called on non-continuous queries")
+    format("parquet").save(path)
+  }
 
   /**
    * Saves the content of the [[DataFrame]] in ORC format at the specified path.
@@ -575,7 +598,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.5.0
    * @note Currently, this method can only be used together with `HiveContext`.
    */
-  def orc(path: String): Unit = format("orc").save(path)
+  def orc(path: String): Unit = {
+    assertNotStreaming("orc() can only be called on non-continuous queries")
+    format("orc").save(path)
+  }
 
   /**
    * Saves the content of the [[DataFrame]] in a text file at the specified path.
@@ -596,7 +622,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.6.0
    */
-  def text(path: String): Unit = format("text").save(path)
+  def text(path: String): Unit = {
+    assertNotStreaming("text() can only be called on non-continuous queries")
+    format("text").save(path)
+  }
 
   /**
    * Saves the content of the [[DataFrame]] in CSV format at the specified path.
@@ -620,7 +649,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def csv(path: String): Unit = format("csv").save(path)
+  def csv(path: String): Unit = {
+    assertNotStreaming("csv() can only be called on non-continuous queries")
+    format("csv").save(path)
+  }
 
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
@@ -641,4 +673,21 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   private var numBuckets: Option[Int] = None
 
   private var sortColumnNames: Option[Seq[String]] = None
+
+  ///////////////////////////////////////////////////////////////////////////////////////
+  // Helper functions
+  ///////////////////////////////////////////////////////////////////////////////////////
+
+  private def assertNotStreaming(errMsg: String): Unit = {
+    if (df.isStreaming) {
+      throw new AnalysisException(errMsg)
+    }
+  }
+
+  private def assertStreaming(errMsg: String): Unit = {
+    if (!df.isStreaming) {
+      throw new AnalysisException(errMsg)
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35d9c8aa/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 00efe21..c7b2b99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -368,4 +368,160 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
       "org.apache.spark.sql.streaming.test",
       Map.empty)
   }
+
+  private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
+
+  test("check trigger() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds")))
+    assert(e.getMessage == "trigger() can only be called on continuous queries;")
+  }
+
+  test("check queryName() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.queryName("queryName"))
+    assert(e.getMessage == "queryName() can only be called on continuous queries;")
+  }
+
+  test("check startStream() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.startStream())
+    assert(e.getMessage == "startStream() can only be called on continuous queries;")
+  }
+
+  test("check startStream(path) can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.startStream("non_exist_path"))
+    assert(e.getMessage == "startStream() can only be called on continuous queries;")
+  }
+
+  test("check mode(SaveMode) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.mode(SaveMode.Append))
+    assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+  }
+
+  test("check mode(string) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.mode("append"))
+    assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+  }
+
+  test("check bucketBy() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream())
+    assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+  }
+
+  test("check sortBy() can only be called on non-continuous queries;") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[IllegalArgumentException](w.sortBy("text").startStream())
+    assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+  }
+
+  test("check save(path) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.save("non_exist_path"))
+    assert(e.getMessage == "save() can only be called on non-continuous queries;")
+  }
+
+  test("check save() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.save())
+    assert(e.getMessage == "save() can only be called on non-continuous queries;")
+  }
+
+  test("check insertInto() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.insertInto("non_exsit_table"))
+    assert(e.getMessage == "insertInto() can only be called on non-continuous queries;")
+  }
+
+  test("check saveAsTable() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table"))
+    assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;")
+  }
+
+  test("check jdbc() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.jdbc(null, null, null))
+    assert(e.getMessage == "jdbc() can only be called on non-continuous queries;")
+  }
+
+  test("check json() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.json("non_exist_path"))
+    assert(e.getMessage == "json() can only be called on non-continuous queries;")
+  }
+
+  test("check parquet() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.parquet("non_exist_path"))
+    assert(e.getMessage == "parquet() can only be called on non-continuous queries;")
+  }
+
+  test("check orc() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.orc("non_exist_path"))
+    assert(e.getMessage == "orc() can only be called on non-continuous queries;")
+  }
+
+  test("check text() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.text("non_exist_path"))
+    assert(e.getMessage == "text() can only be called on non-continuous queries;")
+  }
+
+  test("check csv() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.csv("non_exist_path"))
+    assert(e.getMessage == "csv() can only be called on non-continuous queries;")
+  }
 }

