Repository: spark
Updated Branches:
  refs/heads/branch-1.3 8e75b0ed0 -> a21090ebe
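For reviewers skimming the diff, here is a condensed sketch of the SaveMode-aware saveAsTable / createExternalTable behaviour this patch wires into the Hive metastore path. It is adapted from the new MetastoreDataSourcesSuite tests further down; the table names and the /tmp path are illustrative only, TestHive stands in for a real HiveContext, and the pathless saveAsTable call assumes the default data source has been pointed at the JSON source, as the tests do.

    import org.apache.spark.sql.sources.SaveMode
    import org.apache.spark.sql.types.{StructType, StructField, StringType}
    import org.apache.spark.sql.hive.test.TestHive._

    // Build a small JSON-backed DataFrame, exactly as the new tests do.
    val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
    val df = jsonRDD(rdd)

    // Managed table: no "path" option, so the data lives under the Hive warehouse
    // directory and DROP TABLE deletes it.
    df.saveAsTable("savedJsonTable")

    // The default mode is ErrorIfExists; SaveMode.Ignore is a no-op when the table
    // already exists, SaveMode.Overwrite drops and recreates it, and appending to
    // an existing JSON table is still rejected by this patch.
    df.saveAsTable("savedJsonTable", SaveMode.Overwrite)

    // External table: an explicit "path" means DROP TABLE removes only the
    // metastore entry and leaves the files in place. Append on a table that does
    // not exist yet simply creates it.
    df.saveAsTable(
      "externalJsonTable",
      "org.apache.spark.sql.json",
      SaveMode.Append,
      Map("path" -> "/tmp/externalJsonTable"))   // illustrative path
    sql("DROP TABLE externalJsonTable")

    // Re-register the surviving files as a table without moving any data; a
    // user-supplied schema may project a subset of the stored fields.
    val schema = StructType(StructField("b", StringType, true) :: Nil)
    createExternalTable(
      "externalJsonTable",
      "org.apache.spark.sql.json",
      schema,
      Map("path" -> "/tmp/externalJsonTable"))
    sql("SELECT * FROM externalJsonTable").collect()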
http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala deleted file mode 100644 index 7c1d113..0000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ /dev/null @@ -1,455 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.test - -import java.io.File -import java.util.{Set => JavaSet} - -import scala.collection.mutable -import scala.language.implicitConversions - -import org.apache.hadoop.hive.ql.exec.FunctionRegistry -import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.serde2.RegexSerDe -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.hadoop.hive.serde2.avro.AvroSerDe - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.CacheTableCommand -import org.apache.spark.sql.hive._ -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.execution.HiveNativeCommand - -/* Implicit conversions */ -import scala.collection.JavaConversions._ - -// SPARK-3729: Test key required to check for initialization errors with config. -object TestHive - extends TestHiveContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) - -/** - * A locally running test instance of Spark's Hive execution engine. - * - * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. - * Calling [[reset]] will delete all tables and other state in the database, leaving the database - * in a "clean" state. - * - * TestHive is singleton object version of this class because instantiating multiple copies of the - * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of - * test cases that rely on TestHive must be serialized. - */ -class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { - self => - - // By clearing the port we force Spark to pick a new one. This allows us to rerun tests - // without restarting the JVM. 
- System.clearProperty("spark.hostPort") - CommandProcessorFactory.clean(hiveconf) - - hiveconf.set("hive.plan.serialization.format", "javaXML") - - lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath - lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath - - /** Sets up the system initially or after a RESET command */ - protected def configure(): Unit = { - setConf("javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$metastorePath;create=true") - setConf("hive.metastore.warehouse.dir", warehousePath) - Utils.registerShutdownDeleteDir(new File(warehousePath)) - Utils.registerShutdownDeleteDir(new File(metastorePath)) - } - - val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") - testTempDir.delete() - testTempDir.mkdir() - Utils.registerShutdownDeleteDir(testTempDir) - - // For some hive test case which contain ${system:test.tmp.dir} - System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) - - configure() // Must be called before initializing the catalog below. - - /** The location of the compiled hive distribution */ - lazy val hiveHome = envVarToFile("HIVE_HOME") - /** The location of the hive source code. */ - lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") - - // Override so we can intercept relative paths and rewrite them to point at hive. - override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) - - override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - /** Fewer partitions to speed up testing. */ - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt - override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - } - - /** - * Returns the value of specified environmental variable as a [[java.io.File]] after checking - * to ensure it exists - */ - private def envVarToFile(envVar: String): Option[File] = { - Option(System.getenv(envVar)).map(new File(_)) - } - - /** - * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the - * hive test cases assume the system is set up. 
- */ - private def rewritePaths(cmd: String): String = - if (cmd.toUpperCase contains "LOAD DATA") { - val testDataLocation = - hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) - cmd.replaceAll("\\.\\./\\.\\./", testDataLocation + "/") - } else { - cmd - } - - val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") - hiveFilesTemp.delete() - hiveFilesTemp.mkdir() - Utils.registerShutdownDeleteDir(hiveFilesTemp) - - val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { - new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) - } else { - new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + - File.separator + "resources") - } - - def getHiveFile(path: String): File = { - val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) - hiveDevHome - .map(new File(_, stripped)) - .filter(_.exists) - .getOrElse(new File(inRepoTests, stripped)) - } - - val describedTable = "DESCRIBE (\\w+)".r - - protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(hql)) { - def hiveExec() = runSqlHive(hql) - override def toString = hql + "\n" + super.toString - } - - /** - * Override QueryExecution with special debug workflow. - */ - class QueryExecution(logicalPlan: LogicalPlan) - extends super.QueryExecution(logicalPlan) { - override lazy val analyzed = { - val describedTables = logical match { - case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheTableCommand(tbl, _, _) => tbl :: Nil - case _ => Nil - } - - // Make sure any test tables referenced are loaded. - val referencedTables = - describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } - val referencedTestTables = referencedTables.filter(testTables.contains) - logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") - referencedTestTables.foreach(loadTestTable) - // Proceed with analysis. - analyzer(logical) - } - } - - case class TestTable(name: String, commands: (()=>Unit)*) - - protected[hive] implicit class SqlCmd(sql: String) { - def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit - } - - /** - * A list of test tables and the DDL required to initialize them. A test table is loaded on - * demand when a query are run against it. - */ - lazy val testTables = new mutable.HashMap[String, TestTable]() - def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) - - // The test tables that are defined in the Hive QTestUtil. 
- // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java - val hiveQTestUtilTables = Seq( - TestTable("src", - "CREATE TABLE src (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), - TestTable("src1", - "CREATE TABLE src1 (key INT, value STRING)".cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), - TestTable("srcpart", () => { - runSqlHive( - "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") - for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - runSqlHive( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') - """.stripMargin) - } - }), - TestTable("srcpart1", () => { - runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") - for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - runSqlHive( - s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') - """.stripMargin) - } - }), - TestTable("src_thrift", () => { - import org.apache.thrift.protocol.TBinaryProtocol - import org.apache.hadoop.hive.serde2.thrift.test.Complex - import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer - import org.apache.hadoop.mapred.SequenceFileInputFormat - import org.apache.hadoop.mapred.SequenceFileOutputFormat - - val srcThrift = new Table("default", "src_thrift") - srcThrift.setFields(Nil) - srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName) - // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat. - srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName) - srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName) - srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName) - srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName) - catalog.client.createTable(srcThrift) - - - runSqlHive( - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") - }), - TestTable("serdeins", - s"""CREATE TABLE serdeins (key INT, value STRING) - |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ('field.delim'='\\t') - """.stripMargin.cmd, - "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), - TestTable("sales", - s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) - |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' - |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), - TestTable("episodes", - s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' - |TBLPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | 
] - | }' - |) - """.stripMargin.cmd, - s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd - ), - // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING - // IS NOT YET SUPPORTED - TestTable("episodes_part", - s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) - |PARTITIONED BY (doctor_pt INT) - |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' - |STORED AS - |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' - |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' - |TBLPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | ] - | }' - |) - """.stripMargin.cmd, - // WORKAROUND: Required to pass schema to SerDe for partitioned tables. - // TODO: Pass this automatically from the table to partitions. - s""" - |ALTER TABLE episodes_part SET SERDEPROPERTIES ( - | 'avro.schema.literal'='{ - | "type": "record", - | "name": "episodes", - | "namespace": "testing.hive.avro.serde", - | "fields": [ - | { - | "name": "title", - | "type": "string", - | "doc": "episode title" - | }, - | { - | "name": "air_date", - | "type": "string", - | "doc": "initial date" - | }, - | { - | "name": "doctor", - | "type": "int", - | "doc": "main actor playing the Doctor in episode" - | } - | ] - | }' - |) - """.stripMargin.cmd, - s""" - INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) - SELECT title, air_date, doctor FROM episodes - """.cmd - ) - ) - - hiveQTestUtilTables.foreach(registerTestTable) - - private val loadedTables = new collection.mutable.HashSet[String] - - var cacheTables: Boolean = false - def loadTestTable(name: String) { - if (!(loadedTables contains name)) { - // Marks the table as loaded first to prevent infinite mutually recursive table loading. - loadedTables += name - logInfo(s"Loading test table $name") - val createCmds = - testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) - createCmds.foreach(_()) - - if (cacheTables) { - cacheTable(name) - } - } - } - - /** - * Records the UDFs present when the server starts, so we can delete ones that are created by - * tests. - */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames - - // Database default may not exist in 0.13.1, create it if not exist - HiveShim.createDefaultDBIfNeeded(this) - - /** - * Resets the test instance by deleting any tables that have been created. - * TODO: also clear out UDFs, views, etc. - */ - def reset() { - try { - // HACK: Hive is too noisy by default. 
- org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) - } - - cacheManager.clearCache() - loadedTables.clear() - catalog.cachedDataSourceTables.invalidateAll() - catalog.client.getAllTables("default").foreach { t => - logDebug(s"Deleting table $t") - val table = catalog.client.getTable("default", t) - - catalog.client.getIndexes("default", t, 255).foreach { index => - catalog.client.dropIndex("default", t, index.getIndexName, true) - } - - if (!table.isIndexTable) { - catalog.client.dropTable("default", t) - } - } - - catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logDebug(s"Dropping Database: $db") - catalog.client.dropDatabase(db, true, false, true) - } - - catalog.unregisterAllTables() - - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } - - // Some tests corrupt this value on purpose, which breaks the RESET call below. - hiveconf.set("fs.default.name", new File(".").toURI.toString) - // It is important that we RESET first as broken hooks that might have been set could break - // other sql exec here. - runSqlHive("RESET") - // For some reason, RESET does not reset the following variables... - // https://issues.apache.org/jira/browse/HIVE-9004 - runSqlHive("set hive.table.parameters.default=") - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") - // Lots of tests fail if we do not change the partition whitelist from the default. - runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - configure() - - runSqlHive("USE default") - - // Just loading src makes a lot of tests pass. This is because some tests do something like - // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our - // Analyzer and thus the test table auto-loading mechanism. - // Remove after we handle more DDL operations natively. 
- loadTestTable("src") - loadTestTable("srcpart") - } catch { - case e: Exception => - logError("FATAL ERROR: Failed to reset TestDB state.", e) - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 95dcacc..f6bea1c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.sources.ResolvedDataSource +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -105,7 +107,8 @@ case class CreateMetastoreDataSource( userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], - allowExisting: Boolean) extends RunnableCommand { + allowExisting: Boolean, + managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] @@ -120,7 +123,7 @@ case class CreateMetastoreDataSource( var isExternal = true val optionsWithPath = - if (!options.contains("path")) { + if (!options.contains("path") && managedIfNoPath) { isExternal = false options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) } else { @@ -141,22 +144,13 @@ case class CreateMetastoreDataSource( case class CreateMetastoreDataSourceAsSelect( tableName: String, provider: String, + mode: SaveMode, options: Map[String, String], - allowExisting: Boolean, query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] - - if (hiveContext.catalog.tableExists(tableName :: Nil)) { - if (allowExisting) { - return Seq.empty[Row] - } else { - sys.error(s"Table $tableName already exists.") - } - } - - val df = DataFrame(hiveContext, query) + var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { @@ -166,15 +160,82 @@ case class CreateMetastoreDataSourceAsSelect( options } - // Create the relation based on the data of df. - ResolvedDataSource(sqlContext, provider, optionsWithPath, df) + if (sqlContext.catalog.tableExists(Seq(tableName))) { + // Check if we need to throw an exception or just return. + mode match { + case SaveMode.ErrorIfExists => + sys.error(s"Table $tableName already exists. " + + s"If you want to append into it, please set mode to SaveMode.Append. " + + s"Or, if you want to overwrite it, please set mode to SaveMode.Overwrite.") + case SaveMode.Ignore => + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty[Row] + case SaveMode.Append => + // Check if the specified data source match the data source of the existing table. 
+ val resolved = + ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) + val createdRelation = LogicalRelation(resolved.relation) + EliminateAnalysisOperators(sqlContext.table(tableName).logicalPlan) match { + case l @ LogicalRelation(i: InsertableRelation) => + if (l.schema != createdRelation.schema) { + val errorDescription = + s"Cannot append to table $tableName because the schema of this " + + s"DataFrame does not match the schema of table $tableName." + val errorMessage = + s""" + |$errorDescription + |== Schemas == + |${sideBySide( + s"== Expected Schema ==" +: + l.schema.treeString.split("\\\n"), + s"== Actual Schema ==" +: + createdRelation.schema.treeString.split("\\\n")).mkString("\n")} + """.stripMargin + sys.error(errorMessage) + } else if (i != createdRelation.relation) { + val errorDescription = + s"Cannot append to table $tableName because the resolved relation does not " + + s"match the existing relation of $tableName. " + + s"You can use insertInto($tableName, false) to append this DataFrame to the " + + s"table $tableName and using its data source and options." + val errorMessage = + s""" + |$errorDescription + |== Relations == + |${sideBySide( + s"== Expected Relation ==" :: + l.toString :: Nil, + s"== Actual Relation ==" :: + createdRelation.toString :: Nil).mkString("\n")} + """.stripMargin + sys.error(errorMessage) + } + case o => + sys.error(s"Saving data in ${o.toString} is not supported.") + } + case SaveMode.Overwrite => + hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + // Need to create the table again. + createMetastoreTable = true + } + } else { + // The table does not exist. We need to create it in metastore. + createMetastoreTable = true + } - hiveContext.catalog.createDataSourceTable( - tableName, - None, - provider, - optionsWithPath, - isExternal) + val df = DataFrame(hiveContext, query) + + // Create the relation based on the data of df. + ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df) + + if (createMetastoreTable) { + hiveContext.catalog.createDataSourceTable( + tableName, + Some(df.schema), + provider, + optionsWithPath, + isExternal) + } Seq.empty[Row] } http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala new file mode 100644 index 0000000..840fbc1 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.test + +import java.io.File +import java.util.{Set => JavaSet} + +import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat} +import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.serde2.RegexSerDe +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.avro.AvroSerDe +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.CacheTableCommand +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} + +import scala.collection.mutable +import scala.language.implicitConversions + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +// SPARK-3729: Test key required to check for initialization errors with config. +object TestHive + extends TestHiveContext( + new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) + +/** + * A locally running test instance of Spark's Hive execution engine. + * + * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. + * Calling [[reset]] will delete all tables and other state in the database, leaving the database + * in a "clean" state. + * + * TestHive is singleton object version of this class because instantiating multiple copies of the + * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of + * test cases that rely on TestHive must be serialized. + */ +class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { + self => + + // By clearing the port we force Spark to pick a new one. This allows us to rerun tests + // without restarting the JVM. + System.clearProperty("spark.hostPort") + CommandProcessorFactory.clean(hiveconf) + + hiveconf.set("hive.plan.serialization.format", "javaXML") + + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure(): Unit = { + setConf("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + setConf("hive.metastore.warehouse.dir", warehousePath) + Utils.registerShutdownDeleteDir(new File(warehousePath)) + Utils.registerShutdownDeleteDir(new File(metastorePath)) + } + + val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") + testTempDir.delete() + testTempDir.mkdir() + Utils.registerShutdownDeleteDir(testTempDir) + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + + configure() // Must be called before initializing the catalog below. + + /** The location of the compiled hive distribution */ + lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ + lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + + // Override so we can intercept relative paths and rewrite them to point at hive. 
+ override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) + + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + + /** Fewer partitions to speed up testing. */ + protected[sql] override lazy val conf: SQLConf = new SQLConf { + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + } + + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + + /** + * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the + * hive test cases assume the system is set up. + */ + private def rewritePaths(cmd: String): String = + if (cmd.toUpperCase contains "LOAD DATA") { + val testDataLocation = + hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) + cmd.replaceAll("\\.\\./\\.\\./", testDataLocation + "/") + } else { + cmd + } + + val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + hiveFilesTemp.delete() + hiveFilesTemp.mkdir() + Utils.registerShutdownDeleteDir(hiveFilesTemp) + + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { + new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) + } else { + new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" + + File.separator + "resources") + } + + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "").replace('/', File.separatorChar) + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + val describedTable = "DESCRIBE (\\w+)".r + + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(hql)) { + def hiveExec() = runSqlHive(hql) + override def toString = hql + "\n" + super.toString + } + + /** + * Override QueryExecution with special debug workflow. + */ + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { + override lazy val analyzed = { + val describedTables = logical match { + case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.last } + val referencedTestTables = referencedTables.filter(testTables.contains) + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(loadTestTable) + // Proceed with analysis. + analyzer(logical) + } + } + + case class TestTable(name: String, commands: (()=>Unit)*) + + protected[hive] implicit class SqlCmd(sql: String) { + def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit + } + + /** + * A list of test tables and the DDL required to initialize them. A test table is loaded on + * demand when a query are run against it. + */ + lazy val testTables = new mutable.HashMap[String, TestTable]() + def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + // The test tables that are defined in the Hive QTestUtil. 
+ // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + val hiveQTestUtilTables = Seq( + TestTable("src", + "CREATE TABLE src (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + TestTable("src1", + "CREATE TABLE src1 (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + TestTable("srcpart", () => { + runSqlHive( + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("srcpart1", () => { + runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("src_thrift", () => { + import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} + import org.apache.thrift.protocol.TBinaryProtocol + + val srcThrift = new Table("default", "src_thrift") + srcThrift.setFields(Nil) + srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName) + // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat. + srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName) + srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName) + srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName) + srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName) + catalog.client.createTable(srcThrift) + + + runSqlHive( + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") + }), + TestTable("serdeins", + s"""CREATE TABLE serdeins (key INT, value STRING) + |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ('field.delim'='\\t') + """.stripMargin.cmd, + "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), + TestTable("sales", + s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), + TestTable("episodes", + s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + 
""".stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd + ), + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING + // IS NOT YET SUPPORTED + TestTable("episodes_part", + s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) + |PARTITIONED BY (doctor_pt INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + // WORKAROUND: Required to pass schema to SerDe for partitioned tables. + // TODO: Pass this automatically from the table to partitions. + s""" + |ALTER TABLE episodes_part SET SERDEPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s""" + INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) + SELECT title, air_date, doctor FROM episodes + """.cmd + ) + ) + + hiveQTestUtilTables.foreach(registerTestTable) + + private val loadedTables = new collection.mutable.HashSet[String] + + var cacheTables: Boolean = false + def loadTestTable(name: String) { + if (!(loadedTables contains name)) { + // Marks the table as loaded first to prevent infinite mutually recursive table loading. + loadedTables += name + logInfo(s"Loading test table $name") + val createCmds = + testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) + createCmds.foreach(_()) + + if (cacheTables) { + cacheTable(name) + } + } + } + + /** + * Records the UDFs present when the server starts, so we can delete ones that are created by + * tests. + */ + protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + + // Database default may not exist in 0.13.1, create it if not exist + HiveShim.createDefaultDBIfNeeded(this) + + /** + * Resets the test instance by deleting any tables that have been created. + * TODO: also clear out UDFs, views, etc. + */ + def reset() { + try { + // HACK: Hive is too noisy by default. 
+ org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + } + + cacheManager.clearCache() + loadedTables.clear() + catalog.cachedDataSourceTables.invalidateAll() + catalog.client.getAllTables("default").foreach { t => + logDebug(s"Deleting table $t") + val table = catalog.client.getTable("default", t) + + catalog.client.getIndexes("default", t, 255).foreach { index => + catalog.client.dropIndex("default", t, index.getIndexName, true) + } + + if (!table.isIndexTable) { + catalog.client.dropTable("default", t) + } + } + + catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => + logDebug(s"Dropping Database: $db") + catalog.client.dropDatabase(db, true, false, true) + } + + catalog.unregisterAllTables() + + FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.unregisterTemporaryUDF(udfName) + } + + // Some tests corrupt this value on purpose, which breaks the RESET call below. + hiveconf.set("fs.default.name", new File(".").toURI.toString) + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + // https://issues.apache.org/jira/browse/HIVE-9004 + runSqlHive("set hive.table.parameters.default=") + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. + runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + configure() + + runSqlHive("USE default") + + // Just loading src makes a lot of tests pass. This is because some tests do something like + // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our + // Analyzer and thus the test table auto-loading mechanism. + // Remove after we handle more DDL operations natively. + loadTestTable("src") + loadTestTable("srcpart") + } catch { + case e: Exception => + logError("FATAL ERROR: Failed to reset TestDB state.", e) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java new file mode 100644 index 0000000..9744a2a --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.sources.SaveMode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.QueryTest$; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.test.TestHive$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; + +public class JavaMetastoreDataSourcesSuite { + private transient JavaSparkContext sc; + private transient HiveContext sqlContext; + + String originalDefaultSource; + File path; + Path hiveManagedPath; + FileSystem fs; + DataFrame df; + + private void checkAnswer(DataFrame actual, List<Row> expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + sqlContext = TestHive$.MODULE$; + sc = new JavaSparkContext(sqlContext.sparkContext()); + + originalDefaultSource = sqlContext.conf().defaultDataSourceName(); + path = + Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); + if (path.exists()) { + path.delete(); + } + hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable")); + fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration()); + if (fs.exists(hiveManagedPath)){ + fs.delete(hiveManagedPath, true); + } + + List<String> jsonObjects = new ArrayList<String>(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); + } + JavaRDD<String> rdd = sc.parallelize(jsonObjects); + df = sqlContext.jsonRDD(rdd); + df.registerTempTable("jsonTable"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. 
+ sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable"); + sqlContext.sql("DROP TABLE IF EXISTS externalTable"); + } + + @Test + public void saveExternalTableAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options); + + checkAnswer(loadedDF, df.collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + df.collectAsList()); + } + + @Test + public void saveExternalTableWithSchemaAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + options.put("path", path.toString()); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + + List<StructField> fields = new ArrayList<>(); + fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); + StructType schema = DataTypes.createStructType(fields); + DataFrame loadedDF = + sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options); + + checkAnswer( + loadedDF, + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + checkAnswer( + sqlContext.sql("SELECT * FROM externalTable"), + sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList()); + } + + @Test + public void saveTableAndQueryIt() { + Map<String, String> options = new HashMap<String, String>(); + df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + + checkAnswer( + sqlContext.sql("SELECT * FROM javaSavedTable"), + df.collectAsList()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index ba39129..0270e63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql -import org.scalatest.FunSuite +import scala.collection.JavaConversions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -55,9 +53,36 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. * @param rdd the [[DataFrame]] to be executed - * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
*/ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { + QueryTest.checkAnswer(rdd, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(rdd, Seq(expectedAnswer)) + } + + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { + test(sqlString) { + checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + } + } +} + +object QueryTest { + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. + * @param rdd the [[DataFrame]] to be executed + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -73,18 +98,20 @@ class QueryTest extends PlanTest { } val sparkAnswer = try rdd.collect().toSeq catch { case e: Exception => - fail( + val errorMessage = s""" |Exception thrown while executing query: |${rdd.queryExecution} |== Exception == |$e |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin) + """.stripMargin + return Some(errorMessage) } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - fail(s""" + val errorMessage = + s""" |Results do not match for query: |${rdd.logicalPlan} |== Analyzed Plan == @@ -93,22 +120,21 @@ class QueryTest extends PlanTest { |${rdd.queryExecution.executedPlan} |== Results == |${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} - """.stripMargin) + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin + return Some(errorMessage) } - } - protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { - checkAnswer(rdd, Seq(expectedAnswer)) + return None } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) + def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = { + checkAnswer(rdd, expectedAnswer.toSeq) match { + case Some(errorMessage) => errorMessage + case None => null } } - } http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 869d01e..43da751 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,7 +19,11 @@ package 
org.apache.spark.sql.hive import java.io.File +import org.scalatest.BeforeAndAfter + import com.google.common.io.Files + +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ @@ -29,15 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._ case class TestData(key: Int, value: String) -class InsertIntoHiveTableSuite extends QueryTest { +class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.hive.test.TestHive.implicits._ val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerTempTable("testData") + + before { + // Since every we are doing tests for DDL statements, + // it is better to reset before every test. + TestHive.reset() + // Register the testData, which will be used in every test. + testData.registerTempTable("testData") + } test("insertInto() HiveTable") { - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. testData.insertInto("createAndInsertTest") @@ -68,16 +79,18 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("Double create fails when allowExisting = false") { - createTable[TestData]("doubleCreateAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[org.apache.hadoop.hive.ql.metadata.HiveException] { - createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false) - } + val message = intercept[QueryExecutionException] { + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + }.getMessage + + println("message!!!!" + message) } test("Double create does not fail when allowExisting = true") { - createTable[TestData]("createAndInsertTest") - createTable[TestData]("createAndInsertTest") + sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") + sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)") } test("SPARK-4052: scala.collection.Map as value type of MapType") { @@ -98,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest { } test("SPARK-4203:random partition directory order") { - createTable[TestData]("tmp_table") + sql("CREATE TABLE tmp_table (key int, value string)") val tmpDir = Files.createTempDir() sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ") sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table") http://git-wip-us.apache.org/repos/asf/spark/blob/a21090eb/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 9ce0589..f94aabd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.sql.sources.SaveMode import org.scalatest.BeforeAndAfterEach import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path +import 
org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql.catalyst.util import org.apache.spark.sql._ @@ -41,11 +43,11 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { override def afterEach(): Unit = { reset() - if (ctasPath.exists()) Utils.deleteRecursively(ctasPath) + if (tempPath.exists()) Utils.deleteRecursively(tempPath) } val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile + var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile test ("persistent JSON table") { sql( @@ -270,7 +272,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -297,7 +299,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -309,7 +311,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT * FROM jsonTable """.stripMargin) @@ -325,7 +327,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { |CREATE TABLE IF NOT EXISTS ctasJsonTable |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${ctasPath}' + | path '${tempPath}' |) AS |SELECT a FROM jsonTable """.stripMargin) @@ -400,38 +402,122 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("DROP TABLE jsonTable").collect().foreach(println) } - test("save and load table") { + test("save table") { val originalDefaultSource = conf.defaultDataSourceName - conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json") val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) val df = jsonRDD(rdd) + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + // Save the df as a managed table (by not specifiying the path). df.saveAsTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable"), df.collect()) - createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false) + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.saveAsTable("savedJsonTable", SaveMode.Append) + } + + // We can overwrite it. + df.saveAsTable("savedJsonTable", SaveMode.Overwrite) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) + assert(df.schema === table("savedJsonTable").schema) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + + // Create an external table by specifying the path. 
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM savedJsonTable"), + df.collect()) + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer( + jsonFile(tempPath.toString), + df.collect()) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + } + + test("create external table") { + val originalDefaultSource = conf.defaultDataSourceName + + val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) + val df = jsonRDD(rdd) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + df.saveAsTable( + "savedJsonTable", + "org.apache.spark.sql.json", + SaveMode.Append, + Map("path" -> tempPath.toString)) + + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + createExternalTable("createdJsonTable", tempPath.toString) assert(table("createdJsonTable").schema === df.schema) checkAnswer( sql("SELECT * FROM createdJsonTable"), df.collect()) - val message = intercept[RuntimeException] { - createTable("createdJsonTable", filePath.toString, false) + var message = intercept[RuntimeException] { + createExternalTable("createdJsonTable", filePath.toString) }.getMessage assert(message.contains("Table createdJsonTable already exists."), "We should complain that ctasJsonTable already exists") - createTable("createdJsonTable", filePath.toString, true) - // createdJsonTable should be not changed. - assert(table("createdJsonTable").schema === df.schema) + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") checkAnswer( - sql("SELECT * FROM createdJsonTable"), + jsonFile(tempPath.toString), df.collect()) - conf.setConf("spark.sql.default.datasource", originalDefaultSource) + // Try to specify the schema. + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable").collect()) + + sql("DROP TABLE createdJsonTable") + + message = intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage + assert( + message.contains("Option 'path' not specified"), + "We should complain that path is not specified.") + + sql("DROP TABLE savedJsonTable") + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
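For context, the control flow added to CreateMetastoreDataSourceAsSelect earlier in this diff boils down to the following. This is a simplified, standalone paraphrase (the schema/relation compatibility checks and error-message construction are elided), not code from the patch; CtasDecision and decide are hypothetical names used only for illustration.

    import org.apache.spark.sql.sources.SaveMode

    // Given whether the target table already exists and the requested SaveMode,
    // decide whether to write data, whether to (re)create the metastore entry,
    // and whether to drop the existing table first.
    case class CtasDecision(
        writeData: Boolean,
        createMetastoreTable: Boolean,
        dropExistingFirst: Boolean)

    def decide(tableExists: Boolean, mode: SaveMode): CtasDecision =
      if (!tableExists) {
        // No table yet: write the data and register it in the metastore.
        CtasDecision(writeData = true, createMetastoreTable = true, dropExistingFirst = false)
      } else mode match {
        case SaveMode.ErrorIfExists =>
          sys.error("Table already exists; use SaveMode.Append or SaveMode.Overwrite.")
        case SaveMode.Ignore =>
          // Table exists and mode is Ignore: do nothing.
          CtasDecision(writeData = false, createMetastoreTable = false, dropExistingFirst = false)
        case SaveMode.Append =>
          // Only after the existing relation's schema and data source are verified.
          CtasDecision(writeData = true, createMetastoreTable = false, dropExistingFirst = false)
        case SaveMode.Overwrite =>
          CtasDecision(writeData = true, createMetastoreTable = true, dropExistingFirst = true)
      }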