Repository: spark
Updated Branches:
  refs/heads/master 04b401da8 -> 5b6cd65cd


[SPARK-5746][SQL] Check invalid cases for the write path of data source API

JIRA: https://issues.apache.org/jira/browse/SPARK-5746

liancheng marmbrus

Author: Yin Huai <yh...@databricks.com>

Closes #4617 from yhuai/insertOverwrite and squashes the following commits:

8e3019d [Yin Huai] Fix compilation error.
499e8e7 [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertOverwrite
e76e85a [Yin Huai] Address comments.
ac31b3c [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertOverwrite
f30bdad [Yin Huai] Use toDF.
99da57e [Yin Huai] Merge remote-tracking branch 'upstream/master' into insertOverwrite
6b7545c [Yin Huai] Add a pre write check to the data source API.
a88c516 [Yin Huai] DDLParser will take a parsing function to take care of CTAS statements.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b6cd65c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b6cd65c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b6cd65c

Branch: refs/heads/master
Commit: 5b6cd65cd611b1a46a7d5eb33139c6224b96264e
Parents: 04b401d
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Feb 16 15:51:59 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Feb 16 15:51:59 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  13 +-
 .../org/apache/spark/sql/DataFrameImpl.scala    |   8 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |   5 +-
 .../spark/sql/execution/SparkStrategies.scala   |  10 +-
 .../spark/sql/sources/DataSourceStrategy.scala  |   5 +-
 .../org/apache/spark/sql/sources/ddl.scala      |  29 ++-
 .../org/apache/spark/sql/sources/rules.scala    |  72 +++++-
 .../spark/sql/parquet/ParquetQuerySuite.scala   |  13 +-
 .../sql/sources/CreateTableAsSelectSuite.scala  |  28 +++
 .../spark/sql/sources/DataSourceTest.scala      |   3 +-
 .../spark/sql/sources/InsertIntoSuite.scala     | 176 ---------------
 .../apache/spark/sql/sources/InsertSuite.scala  | 218 +++++++++++++++++++
 .../org/apache/spark/sql/hive/HiveContext.scala |  12 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |   2 +-
 .../apache/spark/sql/hive/HiveStrategies.scala  |   8 +-
 15 files changed, 371 insertions(+), 231 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index aa4320b..fc37b8c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -50,7 +50,13 @@ class Analyzer(catalog: Catalog,
   /**
    * Override to provide additional rules for the "Resolution" batch.
    */
-  val extendedRules: Seq[Rule[LogicalPlan]] = Nil
+  val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
+  /**
+   * Override to provide additional rules for the "Check Analysis" batch.
+   * These rules will be evaluated after our built-in check rules.
+   */
+  val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil
 
   lazy val batches: Seq[Batch] = Seq(
     Batch("Resolution", fixedPoint,
@@ -64,9 +70,10 @@ class Analyzer(catalog: Catalog,
       UnresolvedHavingClauseAttributes ::
       TrimGroupingAliases ::
       typeCoercionRules ++
-      extendedRules : _*),
+      extendedResolutionRules : _*),
     Batch("Check Analysis", Once,
-      CheckResolution),
+      CheckResolution +:
+      extendedCheckRules: _*),
     Batch("Remove SubQueries", fixedPoint,
       EliminateSubQueries)
   )

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 500e3c9..3c1cf8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -67,7 +67,11 @@ private[sql] class DataFrameImpl protected[sql](
   @transient protected[sql] override val logicalPlan: LogicalPlan = 
queryExecution.logical match {
     // For various commands (like DDL) and queries with side effects, we force 
query optimization to
     // happen right away to let these side effects take place eagerly.
-    case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: 
WriteToFile =>
+    case _: Command |
+         _: InsertIntoTable |
+         _: CreateTableAsSelect[_] |
+         _: CreateTableUsingAsSelect |
+         _: WriteToFile =>
       LogicalRDD(queryExecution.analyzed.output, 
queryExecution.toRdd)(sqlContext)
     case _ =>
       queryExecution.logical
@@ -386,7 +390,7 @@ private[sql] class DataFrameImpl protected[sql](
       mode: SaveMode,
       options: Map[String, String]): Unit = {
     val cmd =
-      CreateTableUsingAsLogicalPlan(
+      CreateTableUsingAsSelect(
         tableName,
         source,
         temporary = false,

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1442250..d08c2d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -92,7 +92,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   protected[sql] lazy val analyzer: Analyzer =
     new Analyzer(catalog, functionRegistry, caseSensitive = true) {
-      override val extendedRules =
+      override val extendedResolutionRules =
+        sources.PreWriteCheck(catalog) ::
         sources.PreInsertCastAndRename ::
         Nil
     }
@@ -101,7 +102,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
 
   @transient
-  protected[sql] val ddlParser = new DDLParser
+  protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_))
 
   @transient
   protected[sql] val sqlParser = {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e915e0e..5281c75 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -319,18 +319,10 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         sys.error("allowExisting should be set to false when creating a 
temporary table.")
 
       case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, 
query) =>
-        val logicalPlan = sqlContext.parseSql(query)
-        val cmd =
-          CreateTempTableUsingAsSelect(tableName, provider, mode, opts, 
logicalPlan)
-        ExecutedCommand(cmd) :: Nil
-      case c: CreateTableUsingAsSelect if !c.temporary =>
-        sys.error("Tables created with SQLContext must be TEMPORARY. Use a 
HiveContext instead.")
-
-      case CreateTableUsingAsLogicalPlan(tableName, provider, true, mode, 
opts, query) =>
         val cmd =
           CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
         ExecutedCommand(cmd) :: Nil
-      case c: CreateTableUsingAsLogicalPlan if !c.temporary =>
+      case c: CreateTableUsingAsSelect if !c.temporary =>
         sys.error("Tables created with SQLContext must be TEMPORARY. Use a 
HiveContext instead.")
 
       case LogicalDescribeCommand(table, isExtended) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index a853385..67f3507 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -55,10 +55,7 @@ private[sql] object DataSourceStrategy extends Strategy {
       execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
 
     case i @ logical.InsertIntoTable(
-      l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) 
=>
-      if (partition.nonEmpty) {
-        sys.error(s"Insert into a partition is not allowed because $l is not 
partitioned.")
-      }
+      l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if 
part.isEmpty =>
       execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: 
Nil
 
     case _ => Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 1b5e8c2..dd8b3d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.{Row, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Row}
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -32,7 +32,8 @@ import org.apache.spark.util.Utils
 /**
  * A parser for foreign DDL commands.
  */
-private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
+private[sql] class DDLParser(
+    parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with 
Logging {
 
   def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
     try {
@@ -105,6 +106,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser 
with Logging {
    * AS SELECT ...
    */
   protected lazy val createTable: Parser[LogicalPlan] =
+    // TODO: Support database.table.
     (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
       tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> 
restInput).? ^^ {
       case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ 
query =>
@@ -128,12 +130,13 @@ private[sql] class DDLParser extends 
AbstractSparkSQLParser with Logging {
             SaveMode.ErrorIfExists
           }
 
+          val queryPlan = parseQuery(query.get)
           CreateTableUsingAsSelect(tableName,
             provider,
             temp.isDefined,
             mode,
             options,
-            query.get)
+            queryPlan)
         } else {
           val userSpecifiedSchema = columns.flatMap(fields => 
Some(StructType(fields)))
           CreateTableUsing(
@@ -345,21 +348,23 @@ private[sql] case class CreateTableUsing(
     allowExisting: Boolean,
     managedIfNoPath: Boolean) extends Command
 
+/**
+ * A node used to support CTAS statements and saveAsTable for the data source API.
+ * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer
+ * to analyze the logical plan that will be used to populate the table.
+ * This lets [[PreWriteCheck]] detect cases that are not allowed.
+ */
 private[sql] case class CreateTableUsingAsSelect(
     tableName: String,
     provider: String,
     temporary: Boolean,
     mode: SaveMode,
     options: Map[String, String],
-    query: String) extends Command
-
-private[sql] case class CreateTableUsingAsLogicalPlan(
-    tableName: String,
-    provider: String,
-    temporary: Boolean,
-    mode: SaveMode,
-    options: Map[String, String],
-    query: LogicalPlan) extends Command
+    child: LogicalPlan) extends UnaryNode {
+  override def output = Seq.empty[Attribute]
+  // TODO: Override resolved after we support databaseName.
+  // override lazy val resolved = databaseName != None && childrenResolved
+}
 
 private[sql] case class CreateTempTableUsing(
     tableName: String,

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
index 4ed22d3..36a9c0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.sources
 
+import org.apache.spark.sql.{SaveMode, AnalysisException}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias}
+import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, 
LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
@@ -26,11 +29,9 @@ import org.apache.spark.sql.types.DataType
  * A rule to do pre-insert data type casting and field renaming. Before we 
insert into
  * an [[InsertableRelation]], we will use this rule to make sure that
  * the columns to be inserted have the correct data type and fields have the 
correct names.
- * @param resolver The resolver used by the Analyzer.
  */
 private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       // Wait until children are resolved.
       case p: LogicalPlan if !p.childrenResolved => p
 
@@ -46,7 +47,6 @@ private[sql] object PreInsertCastAndRename extends 
Rule[LogicalPlan] {
         }
         castAndRenameChildOutput(i, l.output, child)
       }
-    }
   }
 
   /** If necessary, cast data types and rename fields to the expected types 
and names. */
@@ -74,3 +74,67 @@ private[sql] object PreInsertCastAndRename extends 
Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * A rule to do various checks before inserting into or writing to a data 
source table.
+ */
+private[sql] case class PreWriteCheck(catalog: Catalog) extends 
Rule[LogicalPlan] {
+  def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.foreach {
+      case i @ logical.InsertIntoTable(
+        l @ LogicalRelation(t: InsertableRelation), partition, query, 
overwrite) =>
+        // Right now, we do not support insert into a data source table with 
partition specs.
+        if (partition.nonEmpty) {
+          failAnalysis(s"Insert into a partition is not allowed because $l is 
not partitioned.")
+        } else {
+          // Get all input data source relations of the query.
+          val srcRelations = query.collect {
+            case LogicalRelation(src: BaseRelation) => src
+          }
+          if (srcRelations.exists(src => src == t)) {
+            failAnalysis(
+              "Cannot insert overwrite into table that is also being read 
from.")
+          } else {
+            // OK
+          }
+        }
+
+      case i @ logical.InsertIntoTable(
+        l: LogicalRelation, partition, query, overwrite) if 
!l.isInstanceOf[InsertableRelation] =>
+        // The relation in l is not an InsertableRelation.
+        failAnalysis(s"$l does not allow insertion.")
+
+      case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, 
query) =>
+        // When the SaveMode is Overwrite, we need to check if the table is an 
input table of
+        // the query. If so, we will throw an AnalysisException to let users 
know it is not allowed.
+        if (catalog.tableExists(Seq(tableName))) {
+          // Need to remove SubQuery operator.
+          EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match {
+            // Only do the check if the table is a data source table
+            // (the relation is a BaseRelation).
+            case l @ LogicalRelation(dest: BaseRelation) =>
+              // Get all input data source relations of the query.
+              val srcRelations = query.collect {
+                case LogicalRelation(src: BaseRelation) => src
+              }
+              if (srcRelations.exists(src => src == dest)) {
+                failAnalysis(
+                  s"Cannot overwrite table $tableName that is also being read 
from.")
+              } else {
+                // OK
+              }
+
+            case _ => // OK
+          }
+        } else {
+          // OK
+        }
+
+      case _ => // OK
+    }
+
+    plan
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index d066545..9318c15 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -38,21 +38,22 @@ class ParquetQuerySuiteBase extends QueryTest with 
ParquetTest {
 
   test("appending") {
     val data = (0 until 10).map(i => (i, i.toString))
+    createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
-      sql("INSERT INTO TABLE t SELECT * FROM t")
+      sql("INSERT INTO TABLE t SELECT * FROM tmp")
       checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
     }
+    catalog.unregisterTable(Seq("tmp"))
   }
 
-  // This test case will trigger the NPE mentioned in
-  // https://issues.apache.org/jira/browse/PARQUET-151.
-  // Update: This also triggers SPARK-5746, should re enable it when we get 
both fixed.
-  ignore("overwriting") {
+  test("overwriting") {
     val data = (0 until 10).map(i => (i, i.toString))
+    createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
-      sql("INSERT OVERWRITE TABLE t SELECT * FROM t")
+      sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
       checkAnswer(table("t"), data.map(Row.fromTuple))
     }
+    catalog.unregisterTable(Seq("tmp"))
   }
 
   test("self-join") {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 29caed9..6035541 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 
+import org.apache.spark.sql.AnalysisException
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.catalyst.util
@@ -157,4 +158,31 @@ class CreateTableAsSelectSuite extends DataSourceTest with 
BeforeAndAfterAll {
       """.stripMargin)
     }
   }
+
+  test("it is not allowed to write to a table while querying it.") {
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE jsonTable
+        |USING org.apache.spark.sql.json.DefaultSource
+        |OPTIONS (
+        |  path '${path.toString}'
+        |) AS
+        |SELECT a, b FROM jt
+      """.stripMargin)
+
+    val message = intercept[AnalysisException] {
+      sql(
+        s"""
+        |CREATE TEMPORARY TABLE jsonTable
+        |USING org.apache.spark.sql.json.DefaultSource
+        |OPTIONS (
+        |  path '${path.toString}'
+        |) AS
+        |SELECT a, b FROM jsonTable
+      """.stripMargin)
+    }.getMessage
+    assert(
+      message.contains("Cannot overwrite table "),
+      "Writing to a table while querying it should not be allowed.")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 53f5f74..0ec6881 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -29,7 +29,8 @@ abstract class DataSourceTest extends QueryTest with 
BeforeAndAfter {
     @transient
     override protected[sql] lazy val analyzer: Analyzer =
       new Analyzer(catalog, functionRegistry, caseSensitive = false) {
-        override val extendedRules =
+        override val extendedResolutionRules =
+          PreWriteCheck(catalog) ::
           PreInsertCastAndRename ::
           Nil
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
deleted file mode 100644
index 36e504e..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources
-
-import java.io.File
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util
-import org.apache.spark.util.Utils
-
-class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
-
-  import caseInsensisitiveContext._
-
-  var path: File = null
-
-  override def beforeAll: Unit = {
-    path = util.getTempFilePath("jsonCTAS").getCanonicalFile
-    val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, 
"b":"str${i}"}"""))
-    jsonRDD(rdd).registerTempTable("jt")
-    sql(
-      s"""
-        |CREATE TEMPORARY TABLE jsonTable (a int, b string)
-        |USING org.apache.spark.sql.json.DefaultSource
-        |OPTIONS (
-        |  path '${path.toString}'
-        |)
-      """.stripMargin)
-  }
-
-  override def afterAll: Unit = {
-    dropTempTable("jsonTable")
-    dropTempTable("jt")
-    if (path.exists()) Utils.deleteRecursively(path)
-  }
-
-  test("Simple INSERT OVERWRITE a JSONRelation") {
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
-      """.stripMargin)
-
-    checkAnswer(
-      sql("SELECT a, b FROM jsonTable"),
-      (1 to 10).map(i => Row(i, s"str$i"))
-    )
-  }
-
-  test("PreInsert casting and renaming") {
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt
-      """.stripMargin)
-
-    checkAnswer(
-      sql("SELECT a, b FROM jsonTable"),
-      (1 to 10).map(i => Row(i * 2, s"${i * 4}"))
-    )
-
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt
-      """.stripMargin)
-
-    checkAnswer(
-      sql("SELECT a, b FROM jsonTable"),
-      (1 to 10).map(i => Row(i * 4, s"${i * 6}"))
-    )
-  }
-
-  test("SELECT clause generating a different number of columns is not 
allowed.") {
-    val message = intercept[RuntimeException] {
-      sql(
-        s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
-      """.stripMargin)
-    }.getMessage
-    assert(
-      message.contains("generates the same number of columns as its schema"),
-      "SELECT clause generating a different number of columns should not be 
not allowed."
-    )
-  }
-
-  test("INSERT OVERWRITE a JSONRelation multiple times") {
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
-      """.stripMargin)
-
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
-      """.stripMargin)
-
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
-      """.stripMargin)
-
-    checkAnswer(
-      sql("SELECT a, b FROM jsonTable"),
-      (1 to 10).map(i => Row(i, s"str$i"))
-    )
-  }
-
-  test("INSERT INTO not supported for JSONRelation for now") {
-    intercept[RuntimeException]{
-      sql(
-        s"""
-        |INSERT INTO TABLE jsonTable SELECT a, b FROM jt
-      """.stripMargin)
-    }
-  }
-
-  test("Caching")  {
-    // Cached Query Execution
-    cacheTable("jsonTable")
-    assertCached(sql("SELECT * FROM jsonTable"))
-    checkAnswer(
-      sql("SELECT * FROM jsonTable"),
-      (1 to 10).map(i => Row(i, s"str$i")))
-
-    assertCached(sql("SELECT a FROM jsonTable"))
-    checkAnswer(
-      sql("SELECT a FROM jsonTable"),
-      (1 to 10).map(Row(_)).toSeq)
-
-    assertCached(sql("SELECT a FROM jsonTable WHERE a < 5"))
-    checkAnswer(
-      sql("SELECT a FROM jsonTable WHERE a < 5"),
-      (1 to 4).map(Row(_)).toSeq)
-
-    assertCached(sql("SELECT a * 2 FROM jsonTable"))
-    checkAnswer(
-      sql("SELECT a * 2 FROM jsonTable"),
-      (1 to 10).map(i => Row(i * 2)).toSeq)
-
-    assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a 
= y.a + 1"), 2)
-    checkAnswer(
-      sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 
1"),
-      (2 to 10).map(i => Row(i, i - 1)).toSeq)
-
-    // Insert overwrite and keep the same schema.
-    sql(
-      s"""
-        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt
-      """.stripMargin)
-    // jsonTable should be recached.
-    assertCached(sql("SELECT * FROM jsonTable"))
-    // The cached data is the new data.
-    checkAnswer(
-      sql("SELECT a, b FROM jsonTable"),
-      sql("SELECT a * 2, b FROM jt").collect())
-
-    // Verify uncaching
-    uncacheTable("jsonTable")
-    assertCached(sql("SELECT * FROM jsonTable"), 0)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
new file mode 100644
index 0000000..5682e5a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.util
+import org.apache.spark.util.Utils
+
+/**
+ * Tests for the data source write path: INSERT OVERWRITE into a JSON-backed temporary
+ * table, and the checks that reject invalid writes to data source relations.
+ */
+class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
+
+  import caseInsensisitiveContext._
+
+  /** Location of the data backing `jsonTable`; assigned in beforeAll, deleted in afterAll. */
+  var path: File = null
+
+  override def beforeAll(): Unit = {
+    path = util.getTempFilePath("jsonCTAS").getCanonicalFile
+    val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+    jsonRDD(rdd).registerTempTable("jt")
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE jsonTable (a int, b string)
+        |USING org.apache.spark.sql.json.DefaultSource
+        |OPTIONS (
+        |  path '${path.toString}'
+        |)
+      """.stripMargin)
+  }
+
+  override def afterAll(): Unit = {
+    dropTempTable("jsonTable")
+    dropTempTable("jt")
+    // path is still null if beforeAll failed before assigning it.
+    if (path != null && path.exists()) Utils.deleteRecursively(path)
+  }
+
+  test("Simple INSERT OVERWRITE a JSONRelation") {
+    sql(
+      """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i, s"str$i"))
+    )
+  }
+
+  test("PreInsert casting and renaming") {
+    // Columns are matched by position; the numeric a * 4 is stored in string column b.
+    sql(
+      """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, a * 4 FROM jt
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 2, s"${i * 4}"))
+    )
+
+    // Output aliases (A, c) need not match the table's column names (a, b).
+    sql(
+      """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 4 AS A, a * 6 as c FROM jt
+      """.stripMargin)
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 4, s"${i * 6}"))
+    )
+  }
+
+  test("SELECT clause generating a different number of columns is not allowed.") {
+    val message = intercept[RuntimeException] {
+      sql(
+        """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
+      """.stripMargin)
+    }.getMessage
+    assert(
+      message.contains("generates the same number of columns as its schema"),
+      "SELECT clause generating a different number of columns should not be allowed."
+    )
+  }
+
+  test("INSERT OVERWRITE a JSONRelation multiple times") {
+    // Each overwrite must replace the previous contents rather than append to them.
+    (1 to 3).foreach { _ =>
+      sql("INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt")
+    }
+
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      (1 to 10).map(i => Row(i, s"str$i"))
+    )
+  }
+
+  test("INSERT INTO not supported for JSONRelation for now") {
+    intercept[RuntimeException] {
+      sql(
+        """
+        |INSERT INTO TABLE jsonTable SELECT a, b FROM jt
+      """.stripMargin)
+    }
+  }
+
+  test("it is not allowed to write to a table while querying it.") {
+    val message = intercept[AnalysisException] {
+      sql(
+        """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable
+      """.stripMargin)
+    }.getMessage
+    assert(
+      message.contains("Cannot insert overwrite into table that is also being read from."),
+      "INSERT OVERWRITE to a table while querying it should not be allowed.")
+  }
+
+  test("Caching") {
+    // Start from a known state instead of relying on data written by earlier tests.
+    sql("INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt")
+    // Cached Query Execution
+    cacheTable("jsonTable")
+    assertCached(sql("SELECT * FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT * FROM jsonTable"),
+      (1 to 10).map(i => Row(i, s"str$i")))
+
+    assertCached(sql("SELECT a FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT a FROM jsonTable"),
+      (1 to 10).map(Row(_)))
+
+    assertCached(sql("SELECT a FROM jsonTable WHERE a < 5"))
+    checkAnswer(
+      sql("SELECT a FROM jsonTable WHERE a < 5"),
+      (1 to 4).map(Row(_)))
+
+    assertCached(sql("SELECT a * 2 FROM jsonTable"))
+    checkAnswer(
+      sql("SELECT a * 2 FROM jsonTable"),
+      (1 to 10).map(i => Row(i * 2)))
+
+    assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
+    checkAnswer(
+      sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
+      (2 to 10).map(i => Row(i, i - 1)))
+
+    // Insert overwrite and keep the same schema.
+    sql(
+      """
+        |INSERT OVERWRITE TABLE jsonTable SELECT a * 2, b FROM jt
+      """.stripMargin)
+    // jsonTable should be recached.
+    assertCached(sql("SELECT * FROM jsonTable"))
+    // The cached data is the new data.
+    checkAnswer(
+      sql("SELECT a, b FROM jsonTable"),
+      sql("SELECT a * 2, b FROM jt").collect())
+
+    // Verify uncaching
+    uncacheTable("jsonTable")
+    assertCached(sql("SELECT * FROM jsonTable"), 0)
+  }
+
+  test("it's not allowed to insert into a relation that is not an InsertableRelation") {
+    sql(
+      """
+        |CREATE TEMPORARY TABLE oneToTen
+        |USING org.apache.spark.sql.sources.SimpleScanSource
+        |OPTIONS (
+        |  From '1',
+        |  To '10'
+        |)
+      """.stripMargin)
+
+    try {
+      checkAnswer(
+        sql("SELECT * FROM oneToTen"),
+        (1 to 10).map(Row(_)))
+
+      val message = intercept[AnalysisException] {
+        sql(
+          """
+          |INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt
+          """.stripMargin)
+      }.getMessage
+      assert(
+        message.contains("does not allow insertion."),
+        "It is not allowed to insert into a table that is not an InsertableRelation.")
+    } finally {
+      dropTempTable("oneToTen")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 87b380f..6c55bc6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, 
EliminateSubQueries, Ov
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, 
QueryExecutionException, SetCommand}
 import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, 
HiveNativeCommand}
-import org.apache.spark.sql.sources.DataSourceStrategy
+import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy}
 import org.apache.spark.sql.types._
 
 /**
@@ -64,14 +64,17 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   override protected[sql] def executePlan(plan: LogicalPlan): 
this.QueryExecution =
     new this.QueryExecution(plan)
 
+  @transient // DDL parser for the "hiveql" dialect; CTAS queries go through HiveQl.parseSql
+  protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_))
+
   override def sql(sqlText: String): DataFrame = {
     val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
     // TODO: Create a framework for registering parsers instead of just 
hardcoding if statements.
     if (conf.dialect == "sql") {
       super.sql(substituted)
     } else if (conf.dialect == "hiveql") {
-      DataFrame(this,
-        ddlParser(sqlText, exceptionOnError = 
false).getOrElse(HiveQl.parseSql(substituted)))
+      val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false)
+      DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted)))
     }  else {
       sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 
'hiveql'")
     }
@@ -241,12 +244,13 @@ class HiveContext(sc: SparkContext) extends 
SQLContext(sc) {
   @transient
   override protected[sql] lazy val analyzer =
     new Analyzer(catalog, functionRegistry, caseSensitive = false) {
-      override val extendedRules =
+      override val extendedResolutionRules =
         catalog.ParquetConversions ::
         catalog.CreateTables ::
         catalog.PreInsertionCasts ::
         ExtractPythonUdfs ::
         ResolveUdtfsAlias ::
+        sources.PreWriteCheck(catalog) ::  // pre-write check for data source writes (SPARK-5746)
         sources.PreInsertCastAndRename ::
         Nil
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 580c570..72211fe 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -663,7 +663,7 @@ private[hive] case class MetastoreRelation
 }
 
 object HiveMetastoreTypes {
-  protected val ddlParser = new DDLParser
+  protected val ddlParser = new DDLParser(HiveQl.parseSql(_))
 
   def toDataType(metastoreType: String): DataType = synchronized {
     ddlParser.parseType(metastoreType)

http://git-wip-us.apache.org/repos/asf/spark/blob/5b6cd65c/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 965d159..d2c39ab 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{DescribeCommand => 
RunnableDescribeComman
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.sources.{CreateTableUsingAsLogicalPlan, 
CreateTableUsingAsSelect, CreateTableUsing}
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, 
CreateTableUsing}
 import org.apache.spark.sql.types.StringType
 
 
@@ -227,12 +227,6 @@ private[hive] trait HiveStrategies {
             tableName, userSpecifiedSchema, provider, opts, allowExisting, 
managedIfNoPath)) :: Nil
 
       case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, 
query) =>
-        val logicalPlan = hiveContext.parseSql(query)
-        val cmd =
-          CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, 
logicalPlan)
-        ExecutedCommand(cmd) :: Nil
-
-      case CreateTableUsingAsLogicalPlan(tableName, provider, false, mode, 
opts, query) =>
         val cmd =
           CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, 
query)
         ExecutedCommand(cmd) :: Nil


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to