http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala deleted file mode 100644 index f091615..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.text - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat - -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration - -/** - * A data source for reading text files. 
- */ -class DefaultSource extends FileFormat with DataSourceRegister { - - override def shortName(): String = "text" - - private def verifySchema(schema: StructType): Unit = { - if (schema.size != 1) { - throw new AnalysisException( - s"Text data source supports only a single column, and you have ${schema.size} columns.") - } - val tpe = schema(0).dataType - if (tpe != StringType) { - throw new AnalysisException( - s"Text data source supports only a string column, but you have ${tpe.simpleString}.") - } - } - - override def inferSchema( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType)) - - override def prepareWrite( - sparkSession: SparkSession, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory = { - verifySchema(dataSchema) - - val conf = job.getConfiguration - val compressionCodec = options.get("compression").map(CompressionCodecs.getCodecClassName) - compressionCodec.foreach { codec => - CompressionCodecs.setCodecConfiguration(conf, codec) - } - - new OutputWriterFactory { - override def newInstance( - path: String, - bucketId: Option[Int], - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter = { - if (bucketId.isDefined) { - throw new AnalysisException("Text doesn't support bucketing") - } - new TextOutputWriter(path, dataSchema, context) - } - } - } - - override def buildReader( - sparkSession: SparkSession, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String], - hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - val broadcastedHadoopConf = - sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - - (file: PartitionedFile) => { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow - } - } - } -} - -class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter { - - private[this] val buffer = new Text() - - private val recordWriter: RecordWriter[NullWritable, Text] = { - new TextOutputFormat[NullWritable, Text]() { - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = context.getTaskAttemptID - val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension") - } - }.getRecordWriter(context) - } - - override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - - override protected[sql] def writeInternal(row: InternalRow): Unit = { - val utf8string = row.getUTF8String(0) - buffer.set(utf8string.getBytes) - recordWriter.write(NullWritable.get(), buffer) - } - - override def close(): Unit = { - recordWriter.close(context) - } -}
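The deletion above and the addition below are the same class: `DefaultSource` is renamed to `TextFileFormat`, and `format("text")` keeps working because resolution goes through the `shortName()` registration rather than the class name (Spark discovers `DataSourceRegister` implementations via `java.util.ServiceLoader`, with an alias map for old fully-qualified names). As a rough, hedged illustration of that idea only, here is a minimal standalone sketch; `SimpleRegister`, `TextProvider`, `ProviderLookup`, and the alias entries are hypothetical names for this example, not Spark APIs.

```scala
// Minimal sketch of short-name based provider lookup with a backward-compat alias map.
// None of these names are Spark APIs; they only illustrate the resolution idea.
trait SimpleRegister {
  def shortName(): String
}

class TextProvider extends SimpleRegister {
  override def shortName(): String = "text"
}

class ParquetProvider extends SimpleRegister {
  override def shortName(): String = "parquet"
}

object ProviderLookup {
  // Old fully-qualified names keep resolving after a rename via an alias map.
  private val backwardCompatAliases: Map[String, String] = Map(
    "example.datasources.text.DefaultSource" -> classOf[TextProvider].getName)

  private val registered: Seq[SimpleRegister] = Seq(new TextProvider, new ParquetProvider)

  def lookupProvider(name: String): Class[_] = {
    val resolved = backwardCompatAliases.getOrElse(name, name)
    registered.find(_.shortName() == resolved) match {
      case Some(provider) => provider.getClass
      case None => Class.forName(resolved) // fall back to a fully-qualified class name
    }
  }
}

object LookupExample extends App {
  // Both the short name and the old class name resolve to TextProvider.
  println(ProviderLookup.lookupProvider("text"))
  println(ProviderLookup.lookupProvider("example.datasources.text.DefaultSource"))
}
```

The same pattern applies to the json, parquet, jdbc, and orc renames later in this commit: tests and service files switch to the new class names, while short names and legacy package names continue to resolve.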
http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala new file mode 100644 index 0000000..d9525ef --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.text + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{NullWritable, Text} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat + +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A data source for reading text files. 
+ */ +class TextFileFormat extends FileFormat with DataSourceRegister { + + override def shortName(): String = "text" + + private def verifySchema(schema: StructType): Unit = { + if (schema.size != 1) { + throw new AnalysisException( + s"Text data source supports only a single column, and you have ${schema.size} columns.") + } + val tpe = schema(0).dataType + if (tpe != StringType) { + throw new AnalysisException( + s"Text data source supports only a string column, but you have ${tpe.simpleString}.") + } + } + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = Some(new StructType().add("value", StringType)) + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + verifySchema(dataSchema) + + val conf = job.getConfiguration + val compressionCodec = options.get("compression").map(CompressionCodecs.getCodecClassName) + compressionCodec.foreach { codec => + CompressionCodecs.setCodecConfiguration(conf, codec) + } + + new OutputWriterFactory { + override def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + if (bucketId.isDefined) { + throw new AnalysisException("Text doesn't support bucketing") + } + new TextOutputWriter(path, dataSchema, context) + } + } + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + + new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRow + } + } + } +} + +class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) + extends OutputWriter { + + private[this] val buffer = new Text() + + private val recordWriter: RecordWriter[NullWritable, Text] = { + new TextOutputFormat[NullWritable, Text]() { + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val configuration = context.getConfiguration + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = context.getTaskAttemptID + val split = taskAttemptId.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension") + } + }.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = { + val utf8string = row.getUTF8String(0) + buffer.set(utf8string.getBytes) + recordWriter.write(NullWritable.get(), buffer) + } + + override def close(): Unit = { + recordWriter.close(context) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 46213a2..500d8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1320,7 +1320,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, - className = classOf[DefaultSource].getCanonicalName, + className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() val d2 = DataSource( @@ -1328,7 +1328,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, - className = classOf[DefaultSource].getCanonicalName, + className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() assert(d1 === d2) }) http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index c43b142..6db6492 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -375,7 +375,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -390,7 +390,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -401,7 +401,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Metastore schema contains additional non-nullable fields. 
assert(intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -412,7 +412,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Conflicting non-nullable field names intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -426,7 +426,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -439,7 +439,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. assert(intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 4f6df54..320aaea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -27,37 +27,37 @@ class ResolvedDataSourceSuite extends SparkFunSuite { test("jdbc") { assert( getProvidingClass("jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) assert( getProvidingClass("org.apache.spark.sql.jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) } test("json") { assert( getProvidingClass("json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) assert( getProvidingClass("org.apache.spark.sql.json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) } test("parquet") { assert( getProvidingClass("parquet") === - 
classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.parquet") === - classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) assert( getProvidingClass("org.apache.spark.sql.parquet") === - classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) } test("error message for unknown data sources") { http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala deleted file mode 100644 index e6c0ce9..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ /dev/null @@ -1,542 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming.test - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration._ - -import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.util.Utils - -object LastOptions { - - var mockStreamSourceProvider = mock(classOf[StreamSourceProvider]) - var mockStreamSinkProvider = mock(classOf[StreamSinkProvider]) - var parameters: Map[String, String] = null - var schema: Option[StructType] = null - var partitionColumns: Seq[String] = Nil - - def clear(): Unit = { - parameters = null - schema = null - partitionColumns = null - reset(mockStreamSourceProvider) - reset(mockStreamSinkProvider) - } -} - -/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. 
*/ -class DefaultSource extends StreamSourceProvider with StreamSinkProvider { - - private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) - - override def sourceSchema( - spark: SQLContext, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): (String, StructType) = { - LastOptions.parameters = parameters - LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) - ("dummySource", fakeSchema) - } - - override def createSource( - spark: SQLContext, - metadataPath: String, - schema: Option[StructType], - providerName: String, - parameters: Map[String, String]): Source = { - LastOptions.parameters = parameters - LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.createSource( - spark, metadataPath, schema, providerName, parameters) - new Source { - override def schema: StructType = fakeSchema - - override def getOffset: Option[Offset] = Some(new LongOffset(0)) - - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - import spark.implicits._ - - Seq[Int]().toDS().toDF() - } - } - } - - override def createSink( - spark: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { - LastOptions.parameters = parameters - LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) - new Sink { - override def addBatch(batchId: Long, data: DataFrame): Unit = {} - } - } -} - -class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { - import testImplicits._ - - private def newMetadataDir = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - - after { - spark.streams.active.foreach(_.stop()) - } - - test("resolve default source") { - spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - .write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .startStream() - .stop() - } - - test("resolve full class") { - spark.read - .format("org.apache.spark.sql.streaming.test.DefaultSource") - .stream() - .write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .startStream() - .stop() - } - - test("options") { - val map = new java.util.HashMap[String, String] - map.put("opt3", "3") - - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .option("opt1", "1") - .options(Map("opt2" -> "2")) - .options(map) - .stream() - - assert(LastOptions.parameters("opt1") == "1") - assert(LastOptions.parameters("opt2") == "2") - assert(LastOptions.parameters("opt3") == "3") - - LastOptions.clear() - - df.write - .format("org.apache.spark.sql.streaming.test") - .option("opt1", "1") - .options(Map("opt2" -> "2")) - .options(map) - .option("checkpointLocation", newMetadataDir) - .startStream() - .stop() - - assert(LastOptions.parameters("opt1") == "1") - assert(LastOptions.parameters("opt2") == "2") - assert(LastOptions.parameters("opt3") == "3") - } - - test("partitioning") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - - df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .startStream() - .stop() - assert(LastOptions.partitionColumns == Nil) - - df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .partitionBy("a") - .startStream() - 
.stop() - assert(LastOptions.partitionColumns == Seq("a")) - - withSQLConf("spark.sql.caseSensitive" -> "false") { - df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .partitionBy("A") - .startStream() - .stop() - assert(LastOptions.partitionColumns == Seq("a")) - } - - intercept[AnalysisException] { - df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .partitionBy("b") - .startStream() - .stop() - } - } - - test("stream paths") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .stream("/test") - - assert(LastOptions.parameters("path") == "/test") - - LastOptions.clear() - - df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .startStream("/test") - .stop() - - assert(LastOptions.parameters("path") == "/test") - } - - test("test different data types for options") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .option("intOpt", 56) - .option("boolOpt", false) - .option("doubleOpt", 6.7) - .stream("/test") - - assert(LastOptions.parameters("intOpt") == "56") - assert(LastOptions.parameters("boolOpt") == "false") - assert(LastOptions.parameters("doubleOpt") == "6.7") - - LastOptions.clear() - df.write - .format("org.apache.spark.sql.streaming.test") - .option("intOpt", 56) - .option("boolOpt", false) - .option("doubleOpt", 6.7) - .option("checkpointLocation", newMetadataDir) - .startStream("/test") - .stop() - - assert(LastOptions.parameters("intOpt") == "56") - assert(LastOptions.parameters("boolOpt") == "false") - assert(LastOptions.parameters("doubleOpt") == "6.7") - } - - test("unique query names") { - - /** Start a query with a specific name */ - def startQueryWithName(name: String = ""): ContinuousQuery = { - spark.read - .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .queryName(name) - .startStream() - } - - /** Start a query without specifying a name */ - def startQueryWithoutName(): ContinuousQuery = { - spark.read - .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .startStream() - } - - /** Get the names of active streams */ - def activeStreamNames: Set[String] = { - val streams = spark.streams.active - val names = streams.map(_.name).toSet - assert(streams.length === names.size, s"names of active queries are not unique: $names") - names - } - - val q1 = startQueryWithName("name") - - // Should not be able to start another query with the same name - intercept[IllegalArgumentException] { - startQueryWithName("name") - } - assert(activeStreamNames === Set("name")) - - // Should be able to start queries with other names - val q3 = startQueryWithName("another-name") - assert(activeStreamNames === Set("name", "another-name")) - - // Should be able to start queries with auto-generated names - val q4 = startQueryWithoutName() - assert(activeStreamNames.contains(q4.name)) - - // Should not be able to start a query with same auto-generated name - intercept[IllegalArgumentException] { - startQueryWithName(q4.name) - } - - // Should be able to start query with that name after stopping the previous query - q1.stop() - val q5 = startQueryWithName("name") - 
assert(activeStreamNames.contains("name")) - spark.streams.active.foreach(_.stop()) - } - - test("trigger") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream("/test") - - var q = df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(10.seconds)) - .startStream() - q.stop() - - assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) - - q = df.write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) - .startStream() - q.stop() - - assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) - } - - test("source metadataPath") { - LastOptions.clear() - - val checkpointLocation = newMetadataDir - - val df1 = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - - val df2 = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - - val q = df1.union(df2).write - .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", checkpointLocation) - .trigger(ProcessingTime(10.seconds)) - .startStream() - q.stop() - - verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - checkpointLocation + "/sources/0", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) - - verify(LastOptions.mockStreamSourceProvider).createSource( - spark.sqlContext, - checkpointLocation + "/sources/1", - None, - "org.apache.spark.sql.streaming.test", - Map.empty) - } - - private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - - test("check trigger() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) - assert(e.getMessage == "trigger() can only be called on continuous queries;") - } - - test("check queryName() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.queryName("queryName")) - assert(e.getMessage == "queryName() can only be called on continuous queries;") - } - - test("check startStream() can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.startStream()) - assert(e.getMessage == "startStream() can only be called on continuous queries;") - } - - test("check startStream(path) can only be called on continuous queries") { - val df = spark.read.text(newTextInput) - val w = df.write.option("checkpointLocation", newMetadataDir) - val e = intercept[AnalysisException](w.startStream("non_exist_path")) - assert(e.getMessage == "startStream() can only be called on continuous queries;") - } - - test("check mode(SaveMode) can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.mode(SaveMode.Append)) - assert(e.getMessage == "mode() can only be called on non-continuous queries;") - } - - test("check mode(string) can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = 
intercept[AnalysisException](w.mode("append")) - assert(e.getMessage == "mode() can only be called on non-continuous queries;") - } - - test("check bucketBy() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream()) - assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") - } - - test("check sortBy() can only be called on non-continuous queries;") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[IllegalArgumentException](w.sortBy("text").startStream()) - assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") - } - - test("check save(path) can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.save("non_exist_path")) - assert(e.getMessage == "save() can only be called on non-continuous queries;") - } - - test("check save() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.save()) - assert(e.getMessage == "save() can only be called on non-continuous queries;") - } - - test("check insertInto() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.insertInto("non_exsit_table")) - assert(e.getMessage == "insertInto() can only be called on non-continuous queries;") - } - - test("check saveAsTable() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table")) - assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;") - } - - test("check jdbc() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.jdbc(null, null, null)) - assert(e.getMessage == "jdbc() can only be called on non-continuous queries;") - } - - test("check json() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.json("non_exist_path")) - assert(e.getMessage == "json() can only be called on non-continuous queries;") - } - - test("check parquet() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.parquet("non_exist_path")) - assert(e.getMessage == "parquet() can only be called on non-continuous queries;") - } - - test("check orc() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.orc("non_exist_path")) - assert(e.getMessage == "orc() can only be called on non-continuous queries;") - } - - test("check text() can only be called on non-continuous queries") { - 
val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.text("non_exist_path")) - assert(e.getMessage == "text() can only be called on non-continuous queries;") - } - - test("check csv() can only be called on non-continuous queries") { - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - val w = df.write - val e = intercept[AnalysisException](w.csv("non_exist_path")) - assert(e.getMessage == "csv() can only be called on non-continuous queries;") - } - - test("ConsoleSink can be correctly loaded") { - LastOptions.clear() - val df = spark.read - .format("org.apache.spark.sql.streaming.test") - .stream() - - val cq = df.write - .format("console") - .option("checkpointLocation", newMetadataDir) - .trigger(ProcessingTime(2.seconds)) - .startStream() - - cq.awaitTermination(2000L) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 6238b74..f3262f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -41,7 +41,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { path.delete() val hadoopConf = spark.sparkContext.hadoopConfiguration - val fileFormat = new parquet.DefaultSource() + val fileFormat = new parquet.ParquetFileFormat() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { val df = spark @@ -73,7 +73,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { path.delete() val hadoopConf = spark.sparkContext.hadoopConfiguration - val fileFormat = new parquet.DefaultSource() + val fileFormat = new parquet.ParquetFileFormat() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { val df = spark http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala new file mode 100644 index 0000000..288f6dc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataFrameReaderWriterSuite.scala @@ -0,0 +1,541 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.test + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ + +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.Utils + +object LastOptions { + + var mockStreamSourceProvider = mock(classOf[StreamSourceProvider]) + var mockStreamSinkProvider = mock(classOf[StreamSinkProvider]) + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var partitionColumns: Seq[String] = Nil + + def clear(): Unit = { + parameters = null + schema = null + partitionColumns = null + reset(mockStreamSourceProvider) + reset(mockStreamSinkProvider) + } +} + +/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */ +class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) + ("dummySource", fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.createSource( + spark, metadataPath, schema, providerName, parameters) + new Source { + override def schema: StructType = fakeSchema + + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + import spark.implicits._ + + Seq[Int]().toDS().toDF() + } + } + } + + override def createSink( + spark: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String]): Sink = { + LastOptions.parameters = parameters + LastOptions.partitionColumns = partitionColumns + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) + new Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} + } + } +} + +class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + private def newMetadataDir = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath + + after { + spark.streams.active.foreach(_.stop()) + } + + test("resolve default source") { + spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + } + + test("resolve full class") { + spark.read + .format("org.apache.spark.sql.streaming.test.DefaultSource") + .stream() + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + } + + test("options") { + val map = new java.util.HashMap[String, String] + 
map.put("opt3", "3") + + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .stream() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("partitioning") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream() + .stop() + assert(LastOptions.partitionColumns == Nil) + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("a") + .startStream() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + + withSQLConf("spark.sql.caseSensitive" -> "false") { + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("A") + .startStream() + .stop() + assert(LastOptions.partitionColumns == Seq("a")) + } + + intercept[AnalysisException] { + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .partitionBy("b") + .startStream() + .stop() + } + } + + test("stream paths") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .stream("/test") + + assert(LastOptions.parameters("path") == "/test") + + LastOptions.clear() + + df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .startStream("/test") + .stop() + + assert(LastOptions.parameters("path") == "/test") + } + + test("test different data types for options") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .stream("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.clear() + df.write + .format("org.apache.spark.sql.streaming.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .option("checkpointLocation", newMetadataDir) + .startStream("/test") + .stop() + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } + + test("unique query names") { + + /** Start a query with a specific name */ + def startQueryWithName(name: String = ""): ContinuousQuery = { + spark.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .queryName(name) + .startStream() + } + + /** Start a query without specifying a name */ + def startQueryWithoutName(): ContinuousQuery = { + spark.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + .write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", 
newMetadataDir) + .startStream() + } + + /** Get the names of active streams */ + def activeStreamNames: Set[String] = { + val streams = spark.streams.active + val names = streams.map(_.name).toSet + assert(streams.length === names.size, s"names of active queries are not unique: $names") + names + } + + val q1 = startQueryWithName("name") + + // Should not be able to start another query with the same name + intercept[IllegalArgumentException] { + startQueryWithName("name") + } + assert(activeStreamNames === Set("name")) + + // Should be able to start queries with other names + val q3 = startQueryWithName("another-name") + assert(activeStreamNames === Set("name", "another-name")) + + // Should be able to start queries with auto-generated names + val q4 = startQueryWithoutName() + assert(activeStreamNames.contains(q4.name)) + + // Should not be able to start a query with same auto-generated name + intercept[IllegalArgumentException] { + startQueryWithName(q4.name) + } + + // Should be able to start query with that name after stopping the previous query + q1.stop() + val q5 = startQueryWithName("name") + assert(activeStreamNames.contains("name")) + spark.streams.active.foreach(_.stop()) + } + + test("trigger") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + + var q = df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) + + q = df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) + } + + test("source metadataPath") { + LastOptions.clear() + + val checkpointLocation = newMetadataDir + + val df1 = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val df2 = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val q = df1.union(df2).write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", checkpointLocation) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + verify(LastOptions.mockStreamSourceProvider).createSource( + spark.sqlContext, + checkpointLocation + "/sources/0", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + + verify(LastOptions.mockStreamSourceProvider).createSource( + spark.sqlContext, + checkpointLocation + "/sources/1", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + } + + private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath + + test("check trigger() can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) + assert(e.getMessage == "trigger() can only be called on continuous queries;") + } + + test("check queryName() can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.queryName("queryName")) + assert(e.getMessage == "queryName() can only be called on continuous queries;") + } + + test("check startStream() can only be called on continuous queries") { + val df = 
spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.startStream()) + assert(e.getMessage == "startStream() can only be called on continuous queries;") + } + + test("check startStream(path) can only be called on continuous queries") { + val df = spark.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.startStream("non_exist_path")) + assert(e.getMessage == "startStream() can only be called on continuous queries;") + } + + test("check mode(SaveMode) can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.mode(SaveMode.Append)) + assert(e.getMessage == "mode() can only be called on non-continuous queries;") + } + + test("check mode(string) can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.mode("append")) + assert(e.getMessage == "mode() can only be called on non-continuous queries;") + } + + test("check bucketBy() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream()) + assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + } + + test("check sortBy() can only be called on non-continuous queries;") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.sortBy("text").startStream()) + assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + } + + test("check save(path) can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.save("non_exist_path")) + assert(e.getMessage == "save() can only be called on non-continuous queries;") + } + + test("check save() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.save()) + assert(e.getMessage == "save() can only be called on non-continuous queries;") + } + + test("check insertInto() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.insertInto("non_exsit_table")) + assert(e.getMessage == "insertInto() can only be called on non-continuous queries;") + } + + test("check saveAsTable() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table")) + assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;") + } + + test("check jdbc() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.jdbc(null, null, null)) + assert(e.getMessage == "jdbc() 
can only be called on non-continuous queries;") + } + + test("check json() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.json("non_exist_path")) + assert(e.getMessage == "json() can only be called on non-continuous queries;") + } + + test("check parquet() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.parquet("non_exist_path")) + assert(e.getMessage == "parquet() can only be called on non-continuous queries;") + } + + test("check orc() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.orc("non_exist_path")) + assert(e.getMessage == "orc() can only be called on non-continuous queries;") + } + + test("check text() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.text("non_exist_path")) + assert(e.getMessage == "text() can only be called on non-continuous queries;") + } + + test("check csv() can only be called on non-continuous queries") { + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.csv("non_exist_path")) + assert(e.getMessage == "csv() can only be called on non-continuous queries;") + } + + test("ConsoleSink can be correctly loaded") { + LastOptions.clear() + val df = spark.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val cq = df.write + .format("console") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime(2.seconds)) + .startStream() + + cq.awaitTermination(2000L) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 4a774fb..32aa13f 100644 --- a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1 @@ -org.apache.spark.sql.hive.orc.DefaultSource +org.apache.spark.sql.hive.orc.OrcFileFormat http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 86ab152..b377a20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.command.CreateTableAsSelectLogicalPlan import 
org.apache.spark.sql.execution.datasources.{Partition => _, _} -import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetDefaultSource, ParquetRelation} -import org.apache.spark.sql.hive.orc.{DefaultSource => OrcDefaultSource} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.hive.orc.OrcFileFormat import org.apache.spark.sql.types._ @@ -281,7 +281,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val inferredSchema = defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()) inferredSchema.map { inferred => - ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred) + ParquetFileFormat.mergeMetastoreParquetSchema(metastoreSchema, inferred) }.getOrElse(metastoreSchema) } else { defaultSource.inferSchema(sparkSession, options, fileCatalog.allFiles()).get @@ -348,13 +348,13 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new ParquetDefaultSource() - val fileFormatClass = classOf[ParquetDefaultSource] + val defaultSource = new ParquetFileFormat() + val fileFormatClass = classOf[ParquetFileFormat] val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging val options = Map( - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, - ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + ParquetFileFormat.MERGE_SCHEMA -> mergeSchema.toString, + ParquetFileFormat.METASTORE_TABLE_NAME -> TableIdentifier( relation.tableName, Some(relation.databaseName) ).unquotedString @@ -400,8 +400,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { - val defaultSource = new OrcDefaultSource() - val fileFormatClass = classOf[OrcDefaultSource] + val defaultSource = new OrcFileFormat() + val fileFormatClass = classOf[OrcFileFormat] val options = Map[String, String]() convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc") http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala new file mode 100644 index 0000000..f119817 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/spark/blob/b3ee53b8/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
new file mode 100644
index 0000000..f119817
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.orc
+
+import java.net.URI
+import java.util.Properties
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.io.orc._
+import org.apache.hadoop.hive.serde2.objectinspector.{SettableStructObjectInspector, StructObjectInspector}
+import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils}
+import org.apache.hadoop.io.{NullWritable, Writable}
+import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{HadoopRDD, RDD}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
+import org.apache.spark.sql.sources.{Filter, _}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * [[FileFormat]] for reading ORC files. If this is moved or renamed, please update
+ * [[DataSource]]'s backwardCompatibilityMap.
+ */
+private[sql] class OrcFileFormat
+  extends FileFormat with DataSourceRegister with Serializable {
+
+  override def shortName(): String = "orc"
+
+  override def toString: String = "ORC"
+
+  override def inferSchema(
+      sparkSession: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] = {
+    OrcFileOperator.readSchema(
+      files.map(_.getPath.toUri.toString),
+      Some(sparkSession.sessionState.newHadoopConf())
+    )
+  }
+
+  override def prepareWrite(
+      sparkSession: SparkSession,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory = {
+    val orcOptions = new OrcOptions(options)
+
+    val configuration = job.getConfiguration
+
+    configuration.set(OrcRelation.ORC_COMPRESSION, orcOptions.compressionCodec)
+    configuration match {
+      case conf: JobConf =>
+        conf.setOutputFormat(classOf[OrcOutputFormat])
+      case conf =>
+        conf.setClass(
+          "mapred.output.format.class",
+          classOf[OrcOutputFormat],
+          classOf[MapRedOutputFormat[_, _]])
+    }
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          bucketId: Option[Int],
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new OrcOutputWriter(path, bucketId, dataSchema, context)
+      }
+    }
+  }
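// Not part of the diff: a hedged usage sketch of the write path above, given some
// DataFrame `df` and a hypothetical output path. The "compression" option is what
// OrcOptions picks up before prepareWrite sets "orc.compress" on the job configuration.
df.write
  .format("orc")
  .option("compression", "snappy")
  .save("/tmp/events_orc")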
+
+  override def buildReader(
+      sparkSession: SparkSession,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String],
+      hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+    if (sparkSession.sessionState.conf.orcFilterPushDown) {
+      // Sets pushed predicates
+      OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f =>
+        hadoopConf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
+        hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
+      }
+    }
+
+    val broadcastedHadoopConf =
+      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+    (file: PartitionedFile) => {
+      val conf = broadcastedHadoopConf.value.value
+
+      // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
+      // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file
+      // using the given physical schema. Instead, we simply return an empty iterator.
+      val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf))
+      if (maybePhysicalSchema.isEmpty) {
+        Iterator.empty
+      } else {
+        val physicalSchema = maybePhysicalSchema.get
+        OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema)
+
+        val orcRecordReader = {
+          val job = Job.getInstance(conf)
+          FileInputFormat.setInputPaths(job, file.filePath)
+
+          val fileSplit = new FileSplit(
+            new Path(new URI(file.filePath)), file.start, file.length, Array.empty
+          )
+          // Custom OrcRecordReader is used to get
+          // ObjectInspector during recordReader creation itself and can
+          // avoid NameNode call in unwrapOrcStructs per file.
+          // Specifically would be helpful for partitioned datasets.
+          val orcReader = OrcFile.createReader(
+            new Path(new URI(file.filePath)), OrcFile.readerOptions(conf))
+          new SparkOrcNewRecordReader(orcReader, conf, fileSplit.getStart, fileSplit.getLength)
+        }
+
+        // Unwraps `OrcStruct`s to `UnsafeRow`s
+        OrcRelation.unwrapOrcStructs(
+          conf,
+          requiredSchema,
+          Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]),
+          new RecordReaderIterator[OrcStruct](orcRecordReader))
+      }
+    }
+  }
+}
+
+private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
+  extends HiveInspectors {
+
+  def serialize(row: InternalRow): Writable = {
+    wrapOrcStruct(cachedOrcStruct, structOI, row)
+    serializer.serialize(cachedOrcStruct, structOI)
+  }
+
+  private[this] val serializer = {
+    val table = new Properties()
+    table.setProperty("columns", dataSchema.fieldNames.mkString(","))
+    table.setProperty("columns.types", dataSchema.map(_.dataType.catalogString).mkString(":"))
+
+    val serde = new OrcSerde
+    serde.initialize(conf, table)
+    serde
+  }
+
+  // Object inspector converted from the schema of the relation to be serialized.
+  private[this] val structOI = {
+    val typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(dataSchema.catalogString)
+    OrcStruct.createObjectInspector(typeInfo.asInstanceOf[StructTypeInfo])
+      .asInstanceOf[SettableStructObjectInspector]
+  }
+
+  private[this] val cachedOrcStruct = structOI.create().asInstanceOf[OrcStruct]
+
+  private[this] def wrapOrcStruct(
+      struct: OrcStruct,
+      oi: SettableStructObjectInspector,
+      row: InternalRow): Unit = {
+    val fieldRefs = oi.getAllStructFieldRefs
+    var i = 0
+    while (i < fieldRefs.size) {
+
+      oi.setStructFieldData(
+        struct,
+        fieldRefs.get(i),
+        wrap(
+          row.get(i, dataSchema(i).dataType),
+          fieldRefs.get(i).getFieldObjectInspector,
+          dataSchema(i).dataType))
+      i += 1
+    }
+  }
+}
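// Not part of the diff: a hedged sketch of how buildReader's predicate push-down is exercised
// (config key, path, and column assumed). With the flag on, a filter like this is turned into
// a SearchArgument by OrcFilters.createFilter and handed to the reader as "sarg.pushdown".
import org.apache.spark.sql.functions.col

spark.conf.set("spark.sql.orc.filterPushDown", "true")
val events = spark.read.format("orc").load("/tmp/events_orc")
events.filter(col("event_id") > 100).show()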
+
+private[orc] class OrcOutputWriter(
+    path: String,
+    bucketId: Option[Int],
+    dataSchema: StructType,
+    context: TaskAttemptContext)
+  extends OutputWriter {
+
+  private[this] val conf = context.getConfiguration
+
+  private[this] val serializer = new OrcSerializer(dataSchema, conf)
+
+  // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this
+  // flag to decide whether `OrcRecordWriter.close()` needs to be called.
+  private var recordWriterInstantiated = false
+
+  private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
+    recordWriterInstantiated = true
+    val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID")
+    val taskAttemptId = context.getTaskAttemptID
+    val partition = taskAttemptId.getTaskID.getId
+    val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
+    val compressionExtension = {
+      val name = conf.get(OrcRelation.ORC_COMPRESSION)
+      OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "")
+    }
+    // It has the `.orc` extension at the end because (de)compression tools
+    // such as gunzip would not be able to decompress this as the compression
+    // is not applied on this whole file but on each "stream" in ORC format.
+    val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc"
+
+    new OrcOutputFormat().getRecordWriter(
+      new Path(path, filename).getFileSystem(conf),
+      conf.asInstanceOf[JobConf],
+      new Path(path, filename).toString,
+      Reporter.NULL
+    ).asInstanceOf[RecordWriter[NullWritable, Writable]]
+  }
+
+  override def write(row: Row): Unit =
+    throw new UnsupportedOperationException("call writeInternal")
+
+  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+    recordWriter.write(NullWritable.get(), serializer.serialize(row))
+  }
+
+  override def close(): Unit = {
+    if (recordWriterInstantiated) {
+      recordWriter.close(Reporter.NULL)
+    }
+  }
+}
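// Not part of the diff: a standalone illustration of the part-file name assembled above
// (all values hypothetical). %05d zero-pads the task partition id, and the compression
// extension is inserted before the fixed ".orc" suffix.
val partition = 7
val uniqueWriteJobId = "1a2b3c4d"
val bucketString = ""
val compressionExtension = ".snappy"
val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc"
// filename == "part-r-00007-1a2b3c4d.snappy.orc"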
+
+private[orc] case class OrcTableScan(
+    @transient sparkSession: SparkSession,
+    attributes: Seq[Attribute],
+    filters: Array[Filter],
+    @transient inputPaths: Seq[FileStatus])
+  extends Logging
+  with HiveInspectors {
+
+  def execute(): RDD[InternalRow] = {
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    val conf = job.getConfiguration
+
+    // Figure out the actual schema from the ORC source (without partition columns) so that we
+    // can pick the correct ordinals. Note that this assumes that all files have the same schema.
+    val orcFormat = new OrcFileFormat
+    val dataSchema =
+      orcFormat
+        .inferSchema(sparkSession, Map.empty, inputPaths)
+        .getOrElse(sys.error("Failed to read schema from target ORC files."))
+
+    // Tries to push down filters if ORC filter push-down is enabled
+    if (sparkSession.sessionState.conf.orcFilterPushDown) {
+      OrcFilters.createFilter(dataSchema, filters).foreach { f =>
+        conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo)
+        conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true)
+      }
+    }
+
+    // Sets requested columns
+    OrcRelation.setRequiredColumns(conf, dataSchema, StructType.fromAttributes(attributes))
+
+    if (inputPaths.isEmpty) {
+      // The input paths have probably been pruned away; return an empty RDD.
+      return sparkSession.sparkContext.emptyRDD[InternalRow]
+    }
+
+    FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*)
+
+    val inputFormatClass =
+      classOf[OrcInputFormat]
+        .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]]
+
+    val rdd = sparkSession.sparkContext.hadoopRDD(
+      conf.asInstanceOf[JobConf],
+      inputFormatClass,
+      classOf[NullWritable],
+      classOf[Writable]
+    ).asInstanceOf[HadoopRDD[NullWritable, Writable]]
+
+    val wrappedConf = new SerializableConfiguration(conf)
+
+    rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) =>
+      val writableIterator = iterator.map(_._2)
+      val maybeStructOI = OrcFileOperator.getObjectInspector(split.getPath.toString, Some(conf))
+      OrcRelation.unwrapOrcStructs(
+        wrappedConf.value,
+        StructType.fromAttributes(attributes),
+        maybeStructOI,
+        writableIterator
+      )
+    }
+  }
+}
+
+private[orc] object OrcTableScan {
+  // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public.
+  private[orc] val SARG_PUSHDOWN = "sarg.pushdown"
+}
+
+private[orc] object OrcRelation extends HiveInspectors {
+  // References to Hive's classes should be kept to a minimum here.
+  val ORC_COMPRESSION = "orc.compress"
+
+  // The extensions for ORC compression codecs
+  val extensionsForCompressionCodecNames = Map(
+    "NONE" -> "",
+    "SNAPPY" -> ".snappy",
+    "ZLIB" -> ".zlib",
+    "LZO" -> ".lzo")
+
+  def unwrapOrcStructs(
+      conf: Configuration,
+      dataSchema: StructType,
+      maybeStructOI: Option[StructObjectInspector],
+      iterator: Iterator[Writable]): Iterator[InternalRow] = {
+    val deserializer = new OrcSerde
+    val mutableRow = new SpecificMutableRow(dataSchema.map(_.dataType))
+    val unsafeProjection = UnsafeProjection.create(dataSchema)
+
+    def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = {
+      val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map {
+        case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal
+      }.unzip
+
+      val unwrappers = fieldRefs.map(unwrapperFor)
+
+      iterator.map { value =>
+        val raw = deserializer.deserialize(value)
+        var i = 0
+        while (i < fieldRefs.length) {
+          val fieldValue = oi.getStructFieldData(raw, fieldRefs(i))
+          if (fieldValue == null) {
+            mutableRow.setNullAt(fieldOrdinals(i))
+          } else {
+            unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
+          }
+          i += 1
+        }
+        unsafeProjection(mutableRow)
+      }
+    }
+
+    maybeStructOI.map(unwrap).getOrElse(Iterator.empty)
+  }
+
+  def setRequiredColumns(
+      conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = {
+    val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer)
+    val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip
+    HiveShim.appendReadColumns(conf, sortedIDs, sortedNames)
+  }
+}
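Taken together, the new file gives ORC the same FileFormat-based read path that Parquet uses. A hedged end-to-end sketch (application name and path hypothetical): selecting a narrow set of columns is what eventually reaches setRequiredColumns, so only those ORC streams are read.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("orc-read-sketch")
  .enableHiveSupport()
  .getOrCreate()

val events = spark.read.format("orc").load("/tmp/events_orc")
events.select("event_id", "ts").show()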
