http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 5267f5f..e9cc399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,6 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} private[sql] object DataSourceV2Utils extends Logging { @@ -55,4 +56,12 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } + + def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { + val name = ds match { + case register: DataSourceRegister => register.shortName() + case _ => ds.getClass.getName + } + throw new UnsupportedOperationException(name + " source does not support user-specified schema") + } }
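The new helper always throws; its [T] result type is there so the call can sit in expression position wherever a read/write support object is expected. A minimal sketch of a call site, assuming a hypothetical createReadSupport method (not part of this commit) that guards schema handling before creating the support object:

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider}
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport
import org.apache.spark.sql.types.StructType

// Hypothetical call site: reject a user-specified schema up front, since
// createMicroBatchReadSupport no longer takes an Optional[StructType] argument.
def createReadSupport(
    ds: DataSourceV2 with MicroBatchReadSupportProvider,
    userSpecifiedSchema: Option[StructType],
    metadataPath: String,
    options: DataSourceOptions): MicroBatchReadSupport = {
  if (userSpecifiedSchema.isDefined) {
    // Always throws UnsupportedOperationException; the message names the source via
    // DataSourceRegister.shortName() when implemented, else via its class name.
    DataSourceV2Utils.failForUserSpecifiedSchema[MicroBatchReadSupport](ds)
  } else {
    ds.createMicroBatchReadSupport(metadataPath, options)
  }
}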
http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 59ebb9b..c3f7b69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -23,15 +23,11 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.MicroBatchExecution import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** @@ -39,7 +35,8 @@ import org.apache.spark.util.Utils * specific logical plans, like [[org.apache.spark.sql.catalyst.plans.logical.AppendData]]. */ @deprecated("Use specific logical plans like AppendData instead", "2.4.0") -case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) extends LogicalPlan { +case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } @@ -47,46 +44,48 @@ case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) ext /** * The physical plan for writing data into data source v2. */ -case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan { +case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan) + extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writeTask = writer.createWriterFactory() - val useCommitCoordinator = writer.useCommitCoordinator + val writerFactory = writeSupport.createBatchWriterFactory() + val useCommitCoordinator = writeSupport.useCommitCoordinator val rdd = query.execute() val messages = new Array[WriterCommitMessage](rdd.partitions.length) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. 
" + s"The input RDD has ${messages.length} partitions.") try { sparkContext.runJob( rdd, (context: TaskContext, iter: Iterator[InternalRow]) => - DataWritingSparkTask.run(writeTask, context, iter, useCommitCoordinator), + DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator), rdd.partitions.indices, (index, message: WriterCommitMessage) => { messages(index) = message - writer.onDataWriterCommit(message) + writeSupport.onDataWriterCommit(message) } ) - logInfo(s"Data source writer $writer is committing.") - writer.commit(messages) - logInfo(s"Data source writer $writer committed.") + logInfo(s"Data source write support $writeSupport is committing.") + writeSupport.commit(messages) + logInfo(s"Data source write support $writeSupport committed.") } catch { case cause: Throwable => - logError(s"Data source writer $writer is aborting.") + logError(s"Data source write support $writeSupport is aborting.") try { - writer.abort(messages) + writeSupport.abort(messages) } catch { case t: Throwable => - logError(s"Data source writer $writer failed to abort.") + logError(s"Data source write support $writeSupport failed to abort.") cause.addSuppressed(t) throw new SparkException("Writing job failed.", cause) } - logError(s"Data source writer $writer aborted.") + logError(s"Data source write support $writeSupport aborted.") cause match { // Only wrap non fatal exceptions. case NonFatal(e) => throw new SparkException("Writing job aborted.", e) @@ -100,7 +99,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e object DataWritingSparkTask extends Logging { def run( - writeTask: DataWriterFactory[InternalRow], + writerFactory: DataWriterFactory, context: TaskContext, iter: Iterator[InternalRow], useCommitCoordinator: Boolean): WriterCommitMessage = { @@ -109,8 +108,7 @@ object DataWritingSparkTask extends Logging { val partId = context.partitionId() val taskId = context.taskAttemptId() val attemptId = context.attemptNumber() - val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0") - val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong) + val dataWriter = writerFactory.createWriter(partId, taskId) // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1c91ac..cf83ba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.Optional - import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -28,9 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -51,8 +49,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readerToDataSourceMap = - MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] + private val readSupportToDataSourceMap = + MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -91,20 +89,19 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val reader = dataSourceV2.createMicroBatchReader( - Optional.empty(), // user specified schema + val readSupport = dataSourceV2.createMicroBatchReadSupport( metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readerToDataSourceMap(reader) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReader [$reader] from " + + readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(reader, output)(sparkSession) + StreamingExecutionRelation(readSupport, output)(sparkSession) }) case s @ 
StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -340,19 +337,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: MicroBatchReader => + case s: RateControlMicroBatchReadSupport => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("setOffsetRange") { - // Once v1 streaming source execution is gone, we can refactor this away. - // For now, we set the range here to get the source to infer the available end offset, - // get that offset, and then set the range again when we later execute. - s.setOffsetRange( - toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), - Optional.empty()) + reportTimeTaken("latestOffset") { + val startOffset = availableOffsets + .get(s).map(off => s.deserializeOffset(off.json)) + .getOrElse(s.initialOffset()) + (s, Option(s.latestOffset(startOffset))) + } + case s: MicroBatchReadSupport => + updateStatusMessage(s"Getting offsets from $s") + reportTimeTaken("latestOffset") { + (s, Option(s.latestOffset())) } - - val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } - (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -392,8 +389,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (reader: MicroBatchReader, off) => - reader.commit(reader.deserializeOffset(off.json)) + case (readSupport: MicroBatchReadSupport, off) => + readSupport.commit(readSupport.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -437,30 +434,34 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - case (reader: MicroBatchReader, available) - if committedOffsets.get(reader).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) - val availableV2: OffsetV2 = available match { - case v1: SerializedOffset => reader.deserializeOffset(v1.json) + + // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but + // to be compatible with streaming source v1, we return a logical plan as a new batch here. + case (readSupport: MicroBatchReadSupport, available) + if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(readSupport).map { + off => readSupport.deserializeOffset(off.json) + } + val endOffset: OffsetV2 = available match { + case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - reader.setOffsetRange( - toJava(current), - Optional.of(availableV2)) - logDebug(s"Retrieving data from $reader: $current -> $availableV2") + val startOffset = current.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) + logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") - val (source, options) = reader match { + val (source, options) = readSupport match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. 
case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. the reader gives // a wrong `equals` implementation. - case _ => readerToDataSourceMap.getOrElse(reader, { + case _ => readSupportToDataSourceMap.getOrElse(readSupport, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(reader -> StreamingDataSourceV2Relation( - reader.readSchema().toAttributes, source, options, reader)) + Some(readSupport -> StreamingDataSourceV2Relation( + readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) case _ => None } } @@ -494,13 +495,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamWriteSupport => - val writer = s.createStreamWriter( + case s: StreamingWriteSupportProvider => + val writer = s.createStreamingWriteSupport( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) + WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -526,7 +527,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamWriteSupport => + case _: StreamingWriteSupportProvider => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } @@ -551,10 +552,6 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } - - private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { - Optional.ofNullable(scalaOption.orNull) - } } object MicroBatchExecution { http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index ae1bfa2..417b6b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods.parse import org.apache.spark.internal.Logging @@ -33,11 +32,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWritSupport import org.apache.spark.sql.sources.v2.CustomMetrics -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, SupportsCustomReaderMetrics} -import org.apache.spark.sql.sources.v2.writer.DataSourceWriter -import org.apache.spark.sql.sources.v2.writer.streaming.SupportsCustomWriterMetrics +import 
org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, SupportsCustomReaderMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWriteSupport, SupportsCustomWriterMetrics} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -201,7 +199,7 @@ trait ProgressReporter extends Logging { ) } - val customWriterMetrics = dataSourceWriter match { + val customWriterMetrics = extractWriteSupport() match { case Some(s: SupportsCustomWriterMetrics) => extractMetrics(() => Option(s.getCustomMetrics), s.onInvalidMetrics) @@ -238,13 +236,13 @@ trait ProgressReporter extends Logging { } /** Extract writer from the executed query plan. */ - private def dataSourceWriter: Option[DataSourceWriter] = { + private def extractWriteSupport(): Option[StreamingWriteSupport] = { if (lastExecution == null) return None lastExecution.executedPlan.collect { case p if p.isInstanceOf[WriteToDataSourceV2Exec] => - p.asInstanceOf[WriteToDataSourceV2Exec].writer + p.asInstanceOf[WriteToDataSourceV2Exec].writeSupport }.headOption match { - case Some(w: MicroBatchWriter) => Some(w.writer) + case Some(w: MicroBatchWritSupport) => Some(w.writeSupport) case _ => None } } @@ -303,7 +301,7 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } } if (onlyDataSourceV2Sources) { @@ -330,7 +328,7 @@ trait ProgressReporter extends Logging { new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => uniqueStreamingExecLeavesMap.put(s, s) case _ => } @@ -338,7 +336,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] + val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] source -> numRows }.toSeq logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala new file mode 100644 index 0000000..1be0716 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.types.StructType + +/** + * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to + * carry schema and offsets for streaming data sources. + */ +class SimpleStreamingScanConfigBuilder( + schema: StructType, + start: Offset, + end: Option[Offset] = None) + extends ScanConfigBuilder { + + override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) +} + +case class SimpleStreamingScanConfig( + readSchema: StructType, + start: Offset, + end: Option[Offset]) + extends ScanConfig http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 24195b5..4b696df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is conntinuous or not, so we need to be able to +// know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. 
*/ case class ContinuousExecutionRelation( - source: ContinuousReadSupport, + source: ContinuousReadSupportProvider, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index cfba100..9c5c16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamWriteSupport + with StreamingWriteSupportProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamWriter( + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new ConsoleWriter(schema, options) + options: DataSourceOptions): StreamingWriteSupport = { + new ConsoleWriteSupport(schema, options) } def createRelation( http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index 554a0b0..b68f67e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,12 +21,13 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition[InternalRow]) + val inputPartition: InputPartition) extends Partition with 
Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -49,15 +50,22 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val readerInputPartitions: Seq[InputPartition[InternalRow]]) + private val inputPartitions: Seq[InputPartition], + schema: StructType, + partitionReaderFactory: ContinuousPartitionReaderFactory) extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerInputPartitions.zipWithIndex.map { + inputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } + private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { + case p: ContinuousDataSourceRDDPartition => p + case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") + } + /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. @@ -69,10 +77,12 @@ class ContinuousDataSourceRDD( } val readerForPartition = { - val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] + val partition = castPartition(split) if (partition.queueReader == null) { - partition.queueReader = - new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) + val partitionReader = partitionReaderFactory.createReader( + partition.inputPartition) + partition.queueReader = new ContinuousQueuedDataReader( + partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -93,17 +103,6 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() - } -} - -object ContinuousDataSourceRDD { - private[continuous] def getContinuousReader( - reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { - reader match { - case r: ContinuousInputPartitionReader[InternalRow] => r - case _ => - throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") - } + castPartition(split).inputPartition.preferredLocations() } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 140cec6..4ddebb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,13 +29,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} import 
org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -43,7 +42,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamWriteSupport, + sink: StreamingWriteSupportProvider, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -53,7 +52,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -63,7 +62,8 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupport, _, extraReaderOptions, output, _) => + source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => + // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,8 +148,7 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReader( - java.util.Optional.empty[StructType](), + dataSource.createContinuousReadSupport( metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -160,9 +159,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val reader = continuousSources(insertedSourceId) + val readSupport = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = reader.readSchema().toAttributes + val newOutput = readSupport.fullSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -170,9 +169,10 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) - reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) - StreamingDataSourceV2Relation(newOutput, source, options, reader) + val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) + val startOffset = realOffset.getOrElse(readSupport.initialOffset) + val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) + StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) } // Rewire the plan to use the new attributes that were returned by the source. 
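Note the shape of the new contract in the hunk above: where the old ContinuousReader parked its start offset in mutable state via setStartOffset, the reworked API resolves the offset first and threads it through an immutable ScanConfig. A hedged sketch of that driver-side sequence, using only calls introduced in this commit (planContinuousScan itself is an illustrative helper, not a Spark method):

import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReaderFactory, ContinuousReadSupport, Offset}

// Illustrative: resolve the start offset, freeze it into a ScanConfig, then plan.
def planContinuousScan(
    readSupport: ContinuousReadSupport,
    restoredOffset: Option[Offset]): (Array[InputPartition], ContinuousPartitionReaderFactory) = {
  val startOffset = restoredOffset.getOrElse(readSupport.initialOffset)
  val scanConfig = readSupport.newScanConfigBuilder(startOffset).build()
  // Executor tasks later call factory.createReader(partition) to obtain a
  // ContinuousPartitionReader[InternalRow] per planned partition.
  val factory = readSupport.createContinuousReaderFactory(scanConfig)
  (readSupport.planInputPartitions(scanConfig), factory)
}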
@@ -185,17 +185,13 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamWriter( + val writer = sink.createStreamingWriteSupport( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) - val reader = withSink.collect { - case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r - }.head - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -208,6 +204,11 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } + val (readSupport, scanConfig) = lastExecution.executedPlan.collect { + case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => + scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig + }.head + sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across @@ -223,14 +224,16 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { + val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && + state.compareAndSet(ACTIVE, RECONFIGURING) + if (shouldReconfigure) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -280,10 +283,12 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, + readSupport: ContinuousReadSupport, + partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) + val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index ec1dabd..65c5fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -25,8 +25,9 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ThreadUtils /** @@ -37,15 +38,14 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partition: ContinuousDataSourceRDDPartition, + partitionIndex: Int, + reader: ContinuousPartitionReader[InternalRow], + schema: StructType, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { - private val reader = partition.inputPartition.createPartitionReader() - // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = - ContinuousDataSourceRDD.getContinuousReader(reader).getOffset + private var currentOffset: PartitionOffset = reader.getOffset /** * The record types in the read buffer. @@ -66,7 +66,7 @@ class ContinuousQueuedDataReader( epochMarkerExecutor.scheduleWithFixedDelay( epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - private val dataReaderThread = new DataReaderThread + private val dataReaderThread = new DataReaderThread(schema) dataReaderThread.setDaemon(true) dataReaderThread.start() @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) + partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -128,16 +128,16 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. 
Pushes (row, offset) to the queue when - * a new row arrives to the [[InputPartitionReader]]. + * a new row arrives to the [[ContinuousPartitionReader]]. */ - class DataReaderThread extends Thread( + class DataReaderThread(schema: StructType) extends Thread( s"continuous-reader--${context.partitionId()}--" + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { @volatile private[continuous] var failureReason: Throwable = _ + private val toUnsafe = UnsafeProjection.create(schema) override def run(): Unit = { TaskContext.setTaskContext(context) - val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) try { while (!shouldStop()) { if (!reader.next()) { @@ -149,8 +149,9 @@ class ContinuousQueuedDataReader( return } } - - queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) + // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row + // before copy here. + queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) } } catch { case _: InterruptedException => http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 551e07c..a6cde2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,24 +17,22 @@ package org.apache.spark.sql.execution.streaming.continuous -import scala.collection.JavaConverters._ - import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { +class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -56,18 +54,18 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - private var offset: Offset = _ + override def fullSchema(): StructType = RateStreamProvider.SCHEMA - override def 
setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) } - override def getStartOffset(): Offset = offset + override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val partitionStartMap = offset match { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + + val partitionStartMap = startOffset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -90,8 +88,12 @@ class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousR i, numPartitions, perPartitionRate) - .asInstanceOf[InputPartition[InternalRow]] - }.asJava + }.toArray + } + + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + RateStreamContinuousReaderFactory } override def commit(end: Offset): Unit = {} @@ -118,33 +120,23 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartition[InternalRow] { - - override def createContinuousReader( - offset: PartitionOffset): InputPartitionReader[InternalRow] = { - val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] - require(rateStreamOffset.partition == partitionIndex, - s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") - new RateStreamContinuousInputPartitionReader( - rateStreamOffset.currentValue, - rateStreamOffset.currentTimeMs, - partitionIndex, - increment, - rowsPerSecond) - } + extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamContinuousInputPartitionReader( - startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) +object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamContinuousInputPartition] + new RateStreamContinuousPartitionReader( + p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + } } -class RateStreamContinuousInputPartitionReader( +class RateStreamContinuousPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 1dbdfd5..28ab244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp -import java.util.{Calendar, List => JList} +import java.util.Calendar import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.{DefaultFormats, NoTypeHints} @@ -34,24 +33,26 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} +import org.apache.spark.sql.execution.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsDeprecatedScanRow} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** - * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and - * debugging. This ContinuousReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This ContinuousReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. */ -class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { +class TextSocketContinuousReadSupport(options: DataSourceOptions) + extends ContinuousReadSupport with Logging { + implicit val defaultFormats: DefaultFormats = DefaultFormats private val host: String = options.get("host").get() @@ -73,7 +74,8 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR @GuardedBy("this") private var currentOffset: Int = -1 - private var startOffset: TextSocketOffset = _ + // Exposed for tests. 
+ private[spark] var startOffset: TextSocketOffset = _ private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -94,16 +96,16 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR TextSocketOffset(Serialization.read[List[Int]](json)) } - override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { - this.startOffset = offset - .orElse(TextSocketOffset(List.fill(numPartitions)(0))) - .asInstanceOf[TextSocketOffset] - recordEndpoint.setStartOffsets(startOffset.offsets) + override def initialOffset(): Offset = { + startOffset = TextSocketOffset(List.fill(numPartitions)(0)) + startOffset } - override def getStartOffset: Offset = startOffset + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (includeTimestamp) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -111,8 +113,10 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -132,10 +136,13 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR startOffset.offsets.zipWithIndex.map { case (offset, i) => - TextSocketContinuousInputPartition( - endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] - }.asJava + TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) + }.toArray + } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + TextSocketReaderFactory } override def commit(end: Offset): Unit = synchronized { @@ -190,7 +197,7 @@ class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousR logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReader.this.synchronized { + TextSocketContinuousReadSupport.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -221,25 +228,30 @@ case class TextSocketContinuousInputPartition( driverEndpointName: String, partitionId: Int, startOffset: Int, - includeTimestamp: Boolean) -extends InputPartition[InternalRow] { + includeTimestamp: Boolean) extends InputPartition + + +object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, - includeTimestamp) + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[TextSocketContinuousInputPartition] + new TextSocketContinuousPartitionReader( + p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) + } } + /** * Continuous text socket input partition reader. * * Polls the driver endpoint for new records. 
*/ -class TextSocketContinuousInputPartitionReader( +class TextSocketContinuousPartitionReader( driverEndpointName: String, partitionId: Int, startOffset: Int, includeTimestamp: Boolean) - extends ContinuousInputPartitionReader[InternalRow] { + extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 967dbe2..a08411d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} +import org.apache.spark.sql.sources.v2.writer.DataWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory import org.apache.spark.util.Utils /** @@ -31,7 +32,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -50,7 +51,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writeTask.createDataWriter( + dataWriter = writerFactory.createWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 8877ebe..2238ce2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter 
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writer, reader, query, startEpoch, session, env.rpcEnv) + writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writer: StreamWriter, - reader: ContinuousReader, + writeSupport: StreamingWriteSupport, + readSupport: ContinuousReadSupport, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, messages.toArray) + writeSupport.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, reader, thisEpochOffsets.toSeq) + query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 943c731..7ad21cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** * The logical plan for writing data in a continuous stream. 
*/ case class WriteToContinuousDataSource( - writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { + writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 927d3a8..c216b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport /** - * The physical plan for writing data into a continuous processing [[StreamWriter]]. + * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. */ -case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writer.createWriterFactory() + val writerFactory = writeSupport.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source writer: $writer. " + + logInfo(s"Start processing data source write support: $writeSupport. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index f81abdc..adf52ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -34,8 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -67,7 +64,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def readSchema(): StructType = encoder.schema + def fullSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -80,7 +77,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -122,24 +119,22 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] - endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] - } - } - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def getStartOffset: OffsetV2 = synchronized { - if (startOffset.offset == -1) null else startOffset + override def initialOffset: OffsetV2 = LongOffset(-1) + + override def latestOffset(): OffsetV2 = { + if (currentOffset.offset == -1) null else currentOffset } - override def getEndOffset: OffsetV2 = synchronized { - if (endOffset.offset == -1) null else endOffset + override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOffset = sc.start.asInstanceOf[LongOffset] + val endOffset = sc.end.get.asInstanceOf[LongOffset] synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -156,11 +151,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block): InputPartition[InternalRow] - }.asJava + new MemoryStreamInputPartition(block) + }.toArray } } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + MemoryStreamReaderFactory + } + private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -201,10 +200,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(records: Array[UnsafeRow]) - extends InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new InputPartitionReader[InternalRow] { +class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition + +object MemoryStreamReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val records = partition.asInstanceOf[MemoryStreamInputPartition].records + new PartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala new file mode 100644 index 0000000..833e62f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.types.StructType + +/** Common methods used to create writes for the console sink */ +class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) + extends StreamingWriteSupport with Logging { + + // Number of rows to display, by default 20 rows + protected val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + protected val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + protected val spark = SparkSession.getActiveSession.get + + def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior. 
+ printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { + case PackedRowCommitMessage(rs) => rs + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(printMessage) + println("-------------------------------------------") + // scalastyle:on println + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) + .show(numRowsToShow, isTruncated) + } + + override def toString(): String = { + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala deleted file mode 100644 index fd45ba5..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.sql.types.StructType - -/** Common methods used to create writes for the the console sink */ -class ConsoleWriter(schema: StructType, options: DataSourceOptions) - extends StreamWriter with Logging { - - // Number of rows to display, by default 20 rows - protected val numRowsToShow = options.getInt("numRows", 20) - - // Truncate the displayed data if it is too long, by default it is true - protected val isTruncated = options.getBoolean("truncate", true) - - assert(SparkSession.getActiveSession.isDefined) - protected val spark = SparkSession.getActiveSession.get - - def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 - // behavior. 
- printRows(messages, schema, s"Batch: $epochId") - } - - def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - protected def printRows( - commitMessages: Array[WriterCommitMessage], - schema: StructType, - printMessage: String): Unit = { - val rows = commitMessages.collect { - case PackedRowCommitMessage(rs) => rs - }.flatten - - // scalastyle:off println - println("-------------------------------------------") - println(printMessage) - println("-------------------------------------------") - // scalastyle:off println - Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) - .show(numRowsToShow, isTruncated) - } - - override def toString(): String = { - s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 4a32217..dbcc448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,26 +17,22 @@ package org.apache.spark.sql.execution.streaming.sources -import java.{util => ju} -import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} +import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.util.RpcUtils /** @@ -48,7 +44,9 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record. 
*/ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + extends MemoryStreamBase[A](sqlContext) + with ContinuousReadSupportProvider with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -59,9 +57,6 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) - @GuardedBy("this") - private var startOffset: ContinuousMemoryStreamOffset = _ - private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -75,15 +70,8 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def setStartOffset(start: Optional[Offset]): Unit = synchronized { - // Inferred initial offset is position 0 in each partition. - startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) - }.asInstanceOf[ContinuousMemoryStreamOffset] - } - - override def getStartOffset: Offset = synchronized { - startOffset + override def initialOffset(): Offset = { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -98,34 +86,40 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] + .start.asInstanceOf[ContinuousMemoryStreamOffset] synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => - new ContinuousMemoryStreamInputPartition( - endpointName, part, index): InputPartition[InternalRow] - }.toList.asJava + case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) + }.toArray } } + override def createContinuousReaderFactory( + config: ScanConfig): ContinuousPartitionReaderFactory = { + ContinuousMemoryStreamReaderFactory + } + override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupport implementation + // ContinuousReadSupportProvider implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. - def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { - this - } + options: DataSourceOptions): ContinuousReadSupport = this } object ContinuousMemoryStream { @@ -141,12 +135,16 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. 
*/ -class ContinuousMemoryStreamInputPartition( +case class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition[InternalRow] { - override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = - new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) + startOffset: Int) extends InputPartition + +object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { + override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { + val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] + new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) + } } /** @@ -154,10 +152,10 @@ class ContinuousMemoryStreamInputPartition( * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamInputPartitionReader( +class ContinuousMemoryStreamPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { + startOffset: Int) extends ContinuousPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf,
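For reference, two minimal Scala sketches of the renamed API follow. Neither is part of the commit above: the interface and method names are taken only from the hunks in this diff, while SeqInputPartition, SeqReaderFactory, NoopCommitMessage, NoopWriterFactory, and NoopWriteSupport are hypothetical names introduced purely for illustration.

Read side: as the MemoryStream and ContinuousMemoryStream hunks show, the data for a split now rides on the InputPartition itself, and a shared PartitionReaderFactory downcasts the partition to build a PartitionReader[InternalRow]. A sketch of that pattern, assuming the org.apache.spark.sql.sources.v2.reader interfaces exactly as used above:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}

// Hypothetical partition that carries its rows directly, like MemoryStreamInputPartition.
case class SeqInputPartition(rows: Seq[InternalRow]) extends InputPartition

object SeqReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    val rows = partition.asInstanceOf[SeqInputPartition].rows
    new PartitionReader[InternalRow] {
      private val iter = rows.iterator
      private var current: InternalRow = _
      override def next(): Boolean = {
        if (iter.hasNext) { current = iter.next(); true } else false
      }
      override def get(): InternalRow = current
      override def close(): Unit = {}
    }
  }
}

Write side: StreamingWriteSupport (implemented above by ConsoleWriteSupport) splits per-task writing from epoch-level coordination. Executors obtain DataWriters from a StreamingDataWriterFactory (cf. writerFactory.createWriter(...) in ContinuousWriteRDD), while the driver-side EpochCoordinator calls commit and abort once per epoch. A no-op sketch under the same assumptions:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}

// Hypothetical commit message; WriterCommitMessage is just a serializable marker.
case object NoopCommitMessage extends WriterCommitMessage

// A top-level object, so the serialized factory does not capture the write support instance.
object NoopWriterFactory extends StreamingDataWriterFactory {
  // Called on executors once per partition, task attempt, and epoch.
  override def createWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] =
    new DataWriter[InternalRow] {
      override def write(record: InternalRow): Unit = ()  // discard the row
      override def commit(): WriterCommitMessage = NoopCommitMessage
      override def abort(): Unit = ()
    }
}

class NoopWriteSupport extends StreamingWriteSupport {
  override def createStreamingWriterFactory(): StreamingDataWriterFactory = NoopWriterFactory
  // Driver-side epoch hooks; per the sequencing comment in EpochCoordinator above,
  // commit(epoch, messages) runs before the epoch is recorded in the query.
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}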