http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2cac865..7a007b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.Optional + import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} @@ -26,9 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport} -import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.{Clock, Utils} @@ -49,8 +51,8 @@ class MicroBatchExecution( @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty - private val readSupportToDataSourceMap = - MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])] + private val readerToDataSourceMap = + MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])] private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -89,19 +91,20 @@ class MicroBatchExecution( StreamingExecutionRelation(source, output)(sparkSession) }) case s @ StreamingRelationV2( - dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if + dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) => v2ToExecutionRelationMap.getOrElseUpdate(s, { // Materialize source to avoid creating it in every batch val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - val readSupport = dataSourceV2.createMicroBatchReadSupport( + val reader = dataSourceV2.createMicroBatchReader( + Optional.empty(), // user specified schema metadataPath, new DataSourceOptions(options.asJava)) nextSourceId += 1 - readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options - logInfo(s"Using MicroBatchReadSupport [$readSupport] from " + + readerToDataSourceMap(reader) = dataSourceV2 -> options + logInfo(s"Using MicroBatchReader [$reader] from " + s"DataSourceV2 named '$sourceName' [$dataSourceV2]") - StreamingExecutionRelation(readSupport, output)(sparkSession) + StreamingExecutionRelation(reader, output)(sparkSession) }) case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, 
v1Relation) => v2ToExecutionRelationMap.getOrElseUpdate(s, { @@ -341,19 +344,19 @@ class MicroBatchExecution( reportTimeTaken("getOffset") { (s, s.getOffset) } - case s: RateControlMicroBatchReadSupport => - updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("latestOffset") { - val startOffset = availableOffsets - .get(s).map(off => s.deserializeOffset(off.json)) - .getOrElse(s.initialOffset()) - (s, Option(s.latestOffset(startOffset))) - } - case s: MicroBatchReadSupport => + case s: MicroBatchReader => updateStatusMessage(s"Getting offsets from $s") - reportTimeTaken("latestOffset") { - (s, Option(s.latestOffset())) + reportTimeTaken("setOffsetRange") { + // Once v1 streaming source execution is gone, we can refactor this away. + // For now, we set the range here to get the source to infer the available end offset, + // get that offset, and then set the range again when we later execute. + s.setOffsetRange( + toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))), + Optional.empty()) } + + val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() } + (s, Option(currentOffset)) }.toMap availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get) @@ -393,8 +396,8 @@ class MicroBatchExecution( if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { case (src: Source, off) => src.commit(off) - case (readSupport: MicroBatchReadSupport, off) => - readSupport.commit(readSupport.deserializeOffset(off.json)) + case (reader: MicroBatchReader, off) => + reader.commit(reader.deserializeOffset(off.json)) case (src, _) => throw new IllegalArgumentException( s"Unknown source is found at constructNextBatch: $src") @@ -438,34 +441,30 @@ class MicroBatchExecution( s"${batch.queryExecution.logical}") logDebug(s"Retrieving data from $source: $current -> $available") Some(source -> batch.logicalPlan) - - // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but - // to be compatible with streaming source v1, we return a logical plan as a new batch here. - case (readSupport: MicroBatchReadSupport, available) - if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(readSupport).map { - off => readSupport.deserializeOffset(off.json) - } - val endOffset: OffsetV2 = available match { - case v1: SerializedOffset => readSupport.deserializeOffset(v1.json) + case (reader: MicroBatchReader, available) + if committedOffsets.get(reader).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json)) + val availableV2: OffsetV2 = available match { + case v1: SerializedOffset => reader.deserializeOffset(v1.json) case v2: OffsetV2 => v2 } - val startOffset = current.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset) - logDebug(s"Retrieving data from $readSupport: $current -> $endOffset") + reader.setOffsetRange( + toJava(current), + Optional.of(availableV2)) + logDebug(s"Retrieving data from $reader: $current -> $availableV2") - val (source, options) = readSupport match { + val (source, options) = reader match { // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2` // implementation. We provide a fake one here for explain. case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String] // Provide a fake value here just in case something went wrong, e.g. 
the reader gives // a wrong `equals` implementation. - case _ => readSupportToDataSourceMap.getOrElse(readSupport, { + case _ => readerToDataSourceMap.getOrElse(reader, { FakeDataSourceV2 -> Map.empty[String, String] }) } - Some(readSupport -> StreamingDataSourceV2Relation( - readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder)) + Some(reader -> StreamingDataSourceV2Relation( + reader.readSchema().toAttributes, source, options, reader)) case _ => None } } @@ -499,13 +498,13 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( + case s: StreamWriteSupport => + val writer = s.createStreamWriter( s"$runId", newAttributePlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan) + WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -533,7 +532,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: StreamWriteSupport => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } @@ -557,6 +556,10 @@ class MicroBatchExecution( awaitProgressLock.unlock() } } + + private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = { + Optional.ofNullable(scalaOption.orNull) + } } object MicroBatchExecution {
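
The MicroBatchExecution hunks above restore the two-phase offset negotiation: the engine first calls setOffsetRange with an empty end so the source infers its latest available offset, reads that back via getEndOffset(), and later sets the concrete range again before planning the batch. Below is a minimal, self-contained Scala sketch of that handshake; the Offset and MicroBatchReader traits are simplified stand-ins for the real interfaces in org.apache.spark.sql.sources.v2.reader.streaming, and CounterReader is hypothetical.

import java.util.Optional

object OffsetRangeHandshakeDemo extends App {
  // Simplified stand-ins for the restored v2 streaming interfaces; the real
  // types live in org.apache.spark.sql.sources.v2.reader.streaming.
  trait Offset { def json: String }
  case class LongOffset(offset: Long) extends Offset { def json: String = offset.toString }

  trait MicroBatchReader {
    def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit
    def getStartOffset: Offset
    def getEndOffset: Offset
    def deserializeOffset(json: String): Offset
    def commit(end: Offset): Unit
  }

  // Hypothetical reader over an in-memory counter whose latest value is `latest`.
  class CounterReader(latest: Long) extends MicroBatchReader {
    private var start: Offset = LongOffset(-1)
    private var end: Offset = LongOffset(-1)
    def setOffsetRange(s: Optional[Offset], e: Optional[Offset]): Unit = {
      start = s.orElse(LongOffset(-1))   // absent start: fall back to the initial offset
      end = e.orElse(LongOffset(latest)) // absent end: infer the latest available offset
    }
    def getStartOffset: Offset = start
    def getEndOffset: Offset = end
    def deserializeOffset(json: String): Offset = LongOffset(json.toLong)
    def commit(end: Offset): Unit = ()
  }

  // The driver-side pattern from the hunk above: set an open-ended range first,
  // read the inferred end offset, then set the concrete range for execution.
  val reader = new CounterReader(latest = 10)
  reader.setOffsetRange(Optional.empty(), Optional.empty())
  val available = reader.getEndOffset
  reader.setOffsetRange(Optional.of[Offset](LongOffset(3)), Optional.of(available))
  assert(reader.getStartOffset == LongOffset(3) && reader.getEndOffset == LongOffset(10))
}

The MemoryStream changes later in this commit follow the same fallback pattern: start.orElse(LongOffset(-1)) and end.orElse(currentOffset).
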
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index d4b5065..6a380ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -251,7 +251,7 @@ trait ProgressReporter extends Logging { // Check whether the streaming query's logical plan has only V2 data sources val allStreamingLeaves = logicalPlan.collect { case s: StreamingExecutionRelation => s } - allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] } + allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] } } if (onlyDataSourceV2Sources) { @@ -278,7 +278,7 @@ trait ProgressReporter extends Logging { new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]() lastExecution.executedPlan.collectLeaves().foreach { - case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] => + case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] => uniqueStreamingExecLeavesMap.put(s, s) case _ => } @@ -286,7 +286,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf => val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource] + val source = execLeaf.reader.asInstanceOf[BaseStreamingSource] source -> numRows }.toSeq logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala deleted file mode 100644 index 1be0716..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder} -import org.apache.spark.sql.types.StructType - -/** - * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to - * carry schema and offsets for streaming data sources. - */ -class SimpleStreamingScanConfigBuilder( - schema: StructType, - start: Offset, - end: Option[Offset] = None) - extends ScanConfigBuilder { - - override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end) -} - -case class SimpleStreamingScanConfig( - readSchema: StructType, - start: Offset, - end: Option[Offset]) - extends ScanConfig http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 4b696df..24195b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -83,7 +83,7 @@ case class StreamingExecutionRelation( // We have to pack in the V1 data source as a shim, for the case when a source implements // continuous processing (which is always V2) but only has V1 microbatch support. We don't -// know at read time whether the query is continuous or not, so we need to be able to +// know at read time whether the query is conntinuous or not, so we need to be able to // swap a V1 relation back in. /** * Used to link a [[DataSourceV2]] into a streaming @@ -113,7 +113,7 @@ case class StreamingRelationV2( * Used to link a [[DataSourceV2]] into a continuous processing execution. 
*/ case class ContinuousExecutionRelation( - source: ContinuousReadSupportProvider, + source: ContinuousReadSupport, extraOptions: Map[String, String], output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f..cfba100 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) } class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider + with StreamWriteSupport with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + options: DataSourceOptions): StreamWriter = { + new ConsoleWriter(schema, options) } def createRelation( http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala index b68f67e..554a0b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala @@ -21,13 +21,12 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader import org.apache.spark.util.NextIterator class ContinuousDataSourceRDDPartition( val index: Int, - val inputPartition: InputPartition) + val inputPartition: InputPartition[InternalRow]) extends Partition with 
Serializable { // This is semantically a lazy val - it's initialized once the first time a call to @@ -50,22 +49,15 @@ class ContinuousDataSourceRDD( sc: SparkContext, dataQueueSize: Int, epochPollIntervalMs: Long, - private val inputPartitions: Seq[InputPartition], - schema: StructType, - partitionReaderFactory: ContinuousPartitionReaderFactory) + private val readerInputPartitions: Seq[InputPartition[InternalRow]]) extends RDD[InternalRow](sc, Nil) { override protected def getPartitions: Array[Partition] = { - inputPartitions.zipWithIndex.map { + readerInputPartitions.zipWithIndex.map { case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition) }.toArray } - private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match { - case p: ContinuousDataSourceRDDPartition => p - case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split") - } - /** * Initialize the shared reader for this partition if needed, then read rows from it until * it returns null to signal the end of the epoch. @@ -77,12 +69,10 @@ class ContinuousDataSourceRDD( } val readerForPartition = { - val partition = castPartition(split) + val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition] if (partition.queueReader == null) { - val partitionReader = partitionReaderFactory.createReader( - partition.inputPartition) - partition.queueReader = new ContinuousQueuedDataReader( - partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs) + partition.queueReader = + new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs) } partition.queueReader @@ -103,6 +93,17 @@ class ContinuousDataSourceRDD( } override def getPreferredLocations(split: Partition): Seq[String] = { - castPartition(split).inputPartition.preferredLocations() + split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations() + } +} + +object ContinuousDataSourceRDD { + private[continuous] def getContinuousReader( + reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = { + reader match { + case r: ContinuousInputPartitionReader[InternalRow] => r + case _ => + throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}") + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ccca726..f104422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -29,12 +29,13 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2} import 
org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Clock, Utils} class ContinuousExecution( @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: StreamWriteSupport, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -52,7 +53,7 @@ class ContinuousExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq() + @volatile protected var continuousSources: Seq[ContinuousReader] = Seq() override protected def sources: Seq[BaseStreamingSource] = continuousSources // For use only in test harnesses. @@ -62,8 +63,7 @@ class ContinuousExecution( val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]() analyzedPlan.transform { case r @ StreamingRelationV2( - source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) => - // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration? + source: ContinuousReadSupport, _, extraReaderOptions, output, _) => toExecutionRelationMap.getOrElseUpdate(r, { ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession) }) @@ -148,7 +148,8 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - dataSource.createContinuousReadSupport( + dataSource.createContinuousReader( + java.util.Optional.empty[StructType](), metadataPath, new DataSourceOptions(extraReaderOptions.asJava)) } @@ -159,9 +160,9 @@ class ContinuousExecution( var insertedSourceId = 0 val withNewSources = logicalPlan transform { case ContinuousExecutionRelation(source, options, output) => - val readSupport = continuousSources(insertedSourceId) + val reader = continuousSources(insertedSourceId) insertedSourceId += 1 - val newOutput = readSupport.fullSchema().toAttributes + val newOutput = reader.readSchema().toAttributes assert(output.size == newOutput.size, s"Invalid reader: ${Utils.truncatedString(output, ",")} != " + @@ -169,10 +170,9 @@ class ContinuousExecution( replacements ++= output.zip(newOutput) val loggedOffset = offsets.offsets(0) - val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json)) - val startOffset = realOffset.getOrElse(readSupport.initialOffset) - val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset) - StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder) + val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json)) + reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull)) + StreamingDataSourceV2Relation(newOutput, source, options, reader) } // Rewire the plan to use the new attributes that were returned by the source. 
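
Both executions bridge Scala Options into the java.util.Optional parameters of the reverted API: MicroBatchExecution adds a private toJava helper, while ContinuousExecution above inlines Optional.ofNullable(realOffset.orNull) before calling setStartOffset. A hedged sketch of that bridging follows; the generic helpers are illustrative, not Spark API.

import java.util.Optional

object OptionBridgingDemo extends App {
  // Illustrative helpers, not Spark API; the diff inlines the same conversion
  // as Optional.ofNullable(realOffset.orNull).
  def toJava[A >: Null <: AnyRef](opt: Option[A]): Optional[A] =
    Optional.ofNullable(opt.orNull)

  def toScala[A](opt: Optional[A]): Option[A] =
    if (opt.isPresent) Some(opt.get) else None

  // A missing restored offset becomes Optional.empty, so a reader can fall back
  // to its initial offset with Optional.orElse, as setStartOffset implementations do.
  val restored: Option[String] = None
  assert(toJava(restored).orElse("initial-offset") == "initial-offset")
  assert(toScala(Optional.of("42")).contains("42"))
}
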
@@ -185,13 +185,17 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( + val writer = sink.createStreamWriter( s"$runId", triggerLogicalPlan.schema, outputMode, new DataSourceOptions(extraOptions.asJava)) val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan) + val reader = withSink.collect { + case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r + }.head + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, @@ -204,11 +208,6 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - val (readSupport, scanConfig) = lastExecution.executedPlan.collect { - case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] => - scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig - }.head - sparkSessionForQuery.sparkContext.setLocalProperty( StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString) sparkSessionForQuery.sparkContext.setLocalProperty( @@ -226,16 +225,14 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { triggerExecutor.execute(() => { startTrigger() - val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) && - state.compareAndSet(ACTIVE, RECONFIGURING) - if (shouldReconfigure) { + if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) { if (queryExecutionThread.isAlive) { queryExecutionThread.interrupt() } @@ -285,12 +282,10 @@ class ContinuousExecution( * Report ending partition offsets for the given reader at the given epoch. 
*/ def addOffset( - epoch: Long, - readSupport: ContinuousReadSupport, - partitionOffsets: Seq[PartitionOffset]): Unit = { + epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = { assert(continuousSources.length == 1, "only one continuous source supported currently") - val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray) + val globalOffset = reader.mergeOffsets(partitionOffsets.toArray) val oldOffset = synchronized { offsetLog.add(epoch, OffsetSeq.fill(globalOffset)) offsetLog.get(epoch - 1) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index 65c5fc6..ec1dabd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -25,9 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset import org.apache.spark.util.ThreadUtils /** @@ -38,14 +37,15 @@ import org.apache.spark.util.ThreadUtils * offsets across epochs. Each compute() should call the next() method here until null is returned. */ class ContinuousQueuedDataReader( - partitionIndex: Int, - reader: ContinuousPartitionReader[InternalRow], - schema: StructType, + partition: ContinuousDataSourceRDDPartition, context: TaskContext, dataQueueSize: Int, epochPollIntervalMs: Long) extends Closeable { + private val reader = partition.inputPartition.createPartitionReader() + // Important sequencing - we must get our starting point before the provider threads start running - private var currentOffset: PartitionOffset = reader.getOffset + private var currentOffset: PartitionOffset = + ContinuousDataSourceRDD.getContinuousReader(reader).getOffset /** * The record types in the read buffer. @@ -66,7 +66,7 @@ class ContinuousQueuedDataReader( epochMarkerExecutor.scheduleWithFixedDelay( epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS) - private val dataReaderThread = new DataReaderThread(schema) + private val dataReaderThread = new DataReaderThread dataReaderThread.setDaemon(true) dataReaderThread.start() @@ -113,7 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset)) + partition.index, EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -128,16 +128,16 @@ class ContinuousQueuedDataReader( /** * The data component of [[ContinuousQueuedDataReader]]. 
Pushes (row, offset) to the queue when - * a new row arrives to the [[ContinuousPartitionReader]]. + * a new row arrives to the [[InputPartitionReader]]. */ - class DataReaderThread(schema: StructType) extends Thread( + class DataReaderThread extends Thread( s"continuous-reader--${context.partitionId()}--" + s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging { @volatile private[continuous] var failureReason: Throwable = _ - private val toUnsafe = UnsafeProjection.create(schema) override def run(): Unit = { TaskContext.setTaskContext(context) + val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader) try { while (!shouldStop()) { if (!reader.next()) { @@ -149,9 +149,8 @@ class ContinuousQueuedDataReader( return } } - // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row - // before copy here. - queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset)) + + queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset)) } } catch { case _: InterruptedException => http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index a6cde2b..551e07c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -17,22 +17,24 @@ package org.apache.spark.sql.execution.streaming.continuous +import scala.collection.JavaConverters._ + import org.json4s.DefaultFormats import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport { +class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() @@ -54,18 +56,18 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json)) } - override def fullSchema(): StructType = RateStreamProvider.SCHEMA + override def readSchema(): StructType = RateStreamProvider.SCHEMA - override def newScanConfigBuilder(start: Offset): 
ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } + private var offset: Offset = _ - override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime) + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime)) + } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start + override def getStartOffset(): Offset = offset - val partitionStartMap = startOffset match { + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { + val partitionStartMap = offset match { case off: RateStreamOffset => off.partitionToValueAndRunTimeMs case off => throw new IllegalArgumentException( @@ -88,12 +90,8 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin i, numPartitions, perPartitionRate) - }.toArray - } - - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - RateStreamContinuousReaderFactory + .asInstanceOf[InputPartition[InternalRow]] + }.asJava } override def commit(end: Offset): Unit = {} @@ -120,23 +118,33 @@ case class RateStreamContinuousInputPartition( partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends InputPartition - -object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[RateStreamContinuousInputPartition] - new RateStreamContinuousPartitionReader( - p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond) + extends ContinuousInputPartition[InternalRow] { + + override def createContinuousReader( + offset: PartitionOffset): InputPartitionReader[InternalRow] = { + val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset] + require(rateStreamOffset.partition == partitionIndex, + s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}") + new RateStreamContinuousInputPartitionReader( + rateStreamOffset.currentValue, + rateStreamOffset.currentTimeMs, + partitionIndex, + increment, + rowsPerSecond) } + + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new RateStreamContinuousInputPartitionReader( + startValue, startTimeMs, partitionIndex, increment, rowsPerSecond) } -class RateStreamContinuousPartitionReader( +class RateStreamContinuousInputPartitionReader( startValue: Long, startTimeMs: Long, partitionIndex: Int, increment: Long, rowsPerSecond: Double) - extends ContinuousPartitionReader[InternalRow] { + extends ContinuousInputPartitionReader[InternalRow] { private var nextReadTime: Long = startTimeMs private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index 28ab244..56bfefd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.streaming.continuous import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.sql.Timestamp -import java.util.Calendar +import java.util.{Calendar, List => JList} import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.{DefaultFormats, NoTypeHints} @@ -33,26 +34,24 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{Offset => _, _} +import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** - * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This ContinuousReadSupport will *not* work in production applications due to - * multiple reasons, including no support for fault recovery. + * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This ContinuousReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. * * The driver maintains a socket connection to the host-port, keeps the received messages in * buckets and serves the messages to the executors via a RPC endpoint. */ -class TextSocketContinuousReadSupport(options: DataSourceOptions) - extends ContinuousReadSupport with Logging { - +class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging { implicit val defaultFormats: DefaultFormats = DefaultFormats private val host: String = options.get("host").get() @@ -74,8 +73,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) @GuardedBy("this") private var currentOffset: Int = -1 - // Exposed for tests. 
- private[spark] var startOffset: TextSocketOffset = _ + private var startOffset: TextSocketOffset = _ private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -96,16 +94,16 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) TextSocketOffset(Serialization.read[List[Int]](json)) } - override def initialOffset(): Offset = { - startOffset = TextSocketOffset(List.fill(numPartitions)(0)) - startOffset + override def setStartOffset(offset: java.util.Optional[Offset]): Unit = { + this.startOffset = offset + .orElse(TextSocketOffset(List.fill(numPartitions)(0))) + .asInstanceOf[TextSocketOffset] + recordEndpoint.setStartOffsets(startOffset.offsets) } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } + override def getStartOffset: Offset = startOffset - override def fullSchema(): StructType = { + override def readSchema(): StructType = { if (includeTimestamp) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -113,10 +111,8 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) } } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[TextSocketOffset] - recordEndpoint.setStartOffsets(startOffset.offsets) + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) @@ -136,13 +132,10 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) startOffset.offsets.zipWithIndex.map { case (offset, i) => - TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp) - }.toArray - } + TextSocketContinuousInputPartition( + endpointName, i, offset, includeTimestamp): InputPartition[InternalRow] + }.asJava - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - TextSocketReaderFactory } override def commit(end: Offset): Unit = synchronized { @@ -197,7 +190,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketContinuousReadSupport.this.synchronized { + TextSocketContinuousReader.this.synchronized { currentOffset += 1 val newData = (line, Timestamp.valueOf( @@ -228,30 +221,25 @@ case class TextSocketContinuousInputPartition( driverEndpointName: String, partitionId: Int, startOffset: Int, - includeTimestamp: Boolean) extends InputPartition - - -object TextSocketReaderFactory extends ContinuousPartitionReaderFactory { + includeTimestamp: Boolean) +extends InputPartition[InternalRow] { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[TextSocketContinuousInputPartition] - new TextSocketContinuousPartitionReader( - p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp) - } + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset, + includeTimestamp) } - /** * Continuous text socket input partition reader. * * Polls the driver endpoint for new records. 
*/ -class TextSocketContinuousPartitionReader( +class TextSocketContinuousInputPartitionReader( driverEndpointName: String, partitionId: Int, startOffset: Int, includeTimestamp: Boolean) - extends ContinuousPartitionReader[InternalRow] { + extends ContinuousInputPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index a08411d..967dbe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.DataWriter -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory} import org.apache.spark.util.Utils /** @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils * * We keep repeating prev.compute() and writing new epochs until the query is shut down. */ -class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory) +class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow]) extends RDD[Unit](prev) { override val partitioner = prev.partitioner @@ -51,7 +50,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat Utils.tryWithSafeFinallyAndFailureCallbacks(block = { try { val dataIterator = prev.compute(split, context) - dataWriter = writerFactory.createWriter( + dataWriter = writeTask.createDataWriter( context.partitionId(), context.taskAttemptId(), EpochTracker.getCurrentEpoch.get) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index 2238ce2..8877ebe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import 
org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + writer: StreamWriter, + reader: ContinuousReader, query: ContinuousExecution, epochCoordinatorId: String, startEpoch: Long, session: SparkSession, env: SparkEnv): RpcEndpointRef = synchronized { val coordinator = new EpochCoordinator( - writeSupport, readSupport, query, startEpoch, session, env.rpcEnv) + writer, reader, query, startEpoch, session, env.rpcEnv) val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator) logInfo("Registered EpochCoordinator endpoint") ref @@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, - readSupport: ContinuousReadSupport, + writer: StreamWriter, + reader: ContinuousReader, query: ContinuousExecution, startEpoch: Long, session: SparkSession, @@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator( s"and is ready to be committed. Committing epoch $epoch.") // Sequencing is important here. We must commit to the writer before recording the commit // in the query, or we will end up dropping the commit if we restart in the middle. - writeSupport.commit(epoch, messages.toArray) + writer.commit(epoch, messages.toArray) query.commit(epoch) } @@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator( partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochOffsets.size == numReaderPartitions) { logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets") - query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq) + query.addOffset(epoch, reader, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc..943c731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** * The logical plan for writing data in a continuous stream. 
*/ case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { + writer: StreamWriter, query: LogicalPlan) extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index c216b61..927d3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamWriter]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) +case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan) extends SparkPlan with Logging { override def children: Seq[SparkPlan] = Seq(query) override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = writer.createWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. " + + logInfo(s"Start processing data source writer: $writer. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf52ab..f81abdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution.streaming +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -31,8 +34,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -64,7 +67,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas addData(data.toTraversable) } - def fullSchema(): StructType = encoder.schema + def readSchema(): StructType = encoder.schema protected def logicalPlan: LogicalPlan @@ -77,7 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas * available. 
*/ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) - extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging { + extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging { protected val logicalPlan: LogicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession) @@ -119,22 +122,24 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" - override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) + override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { + synchronized { + startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset] + endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset] + } + } - override def initialOffset: OffsetV2 = LongOffset(-1) + override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong) - override def latestOffset(): OffsetV2 = { - if (currentOffset.offset == -1) null else currentOffset + override def getStartOffset: OffsetV2 = synchronized { + if (startOffset.offset == -1) null else startOffset } - override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + override def getEndOffset: OffsetV2 = synchronized { + if (endOffset.offset == -1) null else endOffset } - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOffset = sc.start.asInstanceOf[LongOffset] - val endOffset = sc.end.get.asInstanceOf[LongOffset] + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = startOffset.offset.toInt + 1 @@ -151,15 +156,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamInputPartition(block) - }.toArray + new MemoryStreamInputPartition(block): InputPartition[InternalRow] + }.asJava } } - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - MemoryStreamReaderFactory - } - private def generateDebugString( rows: Seq[UnsafeRow], startOrdinal: Int, @@ -200,12 +201,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition - -object MemoryStreamReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val records = partition.asInstanceOf[MemoryStreamInputPartition].records - new PartitionReader[InternalRow] { +class MemoryStreamInputPartition(records: Array[UnsafeRow]) + extends InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = { + new InputPartitionReader[InternalRow] { private var currentIndex = -1 override def next(): Boolean = { http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala deleted file mode 100644 index 833e62f..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} -import org.apache.spark.sql.types.StructType -​-/** Common methods used to create writes for the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { - - // Number of rows to display, by default 20 rows - protected val numRowsToShow = options.getInt("numRows", 20) - - // Truncate the displayed data if it is too long, by default it is true - protected val isTruncated = options.getBoolean("truncate", true) - - assert(SparkSession.getActiveSession.isDefined) - protected val spark = SparkSession.getActiveSession.get - - def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 - // behavior.
- printRows(messages, schema, s"Batch: $epochId") - } - - def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - protected def printRows( - commitMessages: Array[WriterCommitMessage], - schema: StructType, - printMessage: String): Unit = { - val rows = commitMessages.collect { - case PackedRowCommitMessage(rs) => rs - }.flatten - - // scalastyle:off println - println("-------------------------------------------") - println(printMessage) - println("-------------------------------------------") - // scalastyle:on println - Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) - .show(numRowsToShow, isTruncated) - } - - override def toString(): String = { - s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala new file mode 100644 index 0000000..fd45ba5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.types.StructType + +/** Common methods used to create writes for the console sink */ +class ConsoleWriter(schema: StructType, options: DataSourceOptions) + extends StreamWriter with Logging { + + // Number of rows to display, by default 20 rows + protected val numRowsToShow = options.getInt("numRows", 20) + + // Truncate the displayed data if it is too long, by default it is true + protected val isTruncated = options.getBoolean("truncate", true) + + assert(SparkSession.getActiveSession.isDefined) + protected val spark = SparkSession.getActiveSession.get + + def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2 + // behavior.
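The numRows and truncate options consulted above surface through the ordinary writeStream API; a minimal usage sketch, assuming df is a streaming DataFrame defined elsewhere:

  // Only these two options are read by the console writer.
  val query = df.writeStream
    .format("console")
    .option("numRows", 50)      // rows to display per batch; default 20
    .option("truncate", false)  // keep long cells intact; default true
    .start()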
+ printRows(messages, schema, s"Batch: $epochId") + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + protected def printRows( + commitMessages: Array[WriterCommitMessage], + schema: StructType, + printMessage: String): Unit = { + val rows = commitMessages.collect { + case PackedRowCommitMessage(rs) => rs + }.flatten + + // scalastyle:off println + println("-------------------------------------------") + println(printMessage) + println("-------------------------------------------") + // scalastyle:on println + Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) + .show(numRowsToShow, isTruncated) + } + + override def toString(): String = { + s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]" + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index dbcc448..4a32217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -17,22 +17,26 @@ package org.apache.spark.sql.execution.streaming.sources +import java.{util => ju} +import java.util.Optional import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.sql.{Encoder, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.{Offset => _, _} -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder} -import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** @@ -44,9 +48,7 @@ import org.apache.spark.util.RpcUtils * the specified offset within the list, or null if that offset doesn't yet have a record.
*/ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) - with ContinuousReadSupportProvider with ContinuousReadSupport { - + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) protected val logicalPlan = @@ -57,6 +59,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa @GuardedBy("this") private val records = Seq.fill(numPartitions)(new ListBuffer[A]) + @GuardedBy("this") + private var startOffset: ContinuousMemoryStreamOffset = _ + private val recordEndpoint = new ContinuousRecordEndpoint(records, this) @volatile private var endpointRef: RpcEndpointRef = _ @@ -70,8 +75,15 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } - override def initialOffset(): Offset = { - ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. + startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = synchronized { + startOffset } override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { @@ -86,40 +98,34 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa ) } - override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val startOffset = config.asInstanceOf[SimpleStreamingScanConfig] - .start.asInstanceOf[ContinuousMemoryStreamOffset] + override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { synchronized { val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id" endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint) startOffset.partitionNums.map { - case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index) - }.toArray + case (part, index) => + new ContinuousMemoryStreamInputPartition( + endpointName, part, index): InputPartition[InternalRow] + }.toList.asJava } } - override def createContinuousReaderFactory( - config: ScanConfig): ContinuousPartitionReaderFactory = { - ContinuousMemoryStreamReaderFactory - } - override def stop(): Unit = { if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef) } override def commit(end: Offset): Unit = {} - // ContinuousReadSupportProvider implementation + // ContinuousReadSupport implementation // This is necessary because of how StreamTest finds the source for AddDataMemory steps. - override def createContinuousReadSupport( + def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = this + options: DataSourceOptions): ContinuousReader = { + this + } } object ContinuousMemoryStream { @@ -135,16 +141,12 @@ object ContinuousMemoryStream { /** * An input partition for continuous memory stream. 
*/ -case class ContinuousMemoryStreamInputPartition( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, - startOffset: Int) extends InputPartition - -object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory { - override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = { - val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition] - new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset) - } + startOffset: Int) extends InputPartition[InternalRow] { + override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader = + new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset) } /** @@ -152,10 +154,10 @@ object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFact * * Polls the driver endpoint for new records. */ -class ContinuousMemoryStreamPartitionReader( +class ContinuousMemoryStreamInputPartitionReader( driverEndpointName: String, partition: Int, - startOffset: Int) extends ContinuousPartitionReader[InternalRow] { + startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] { private val endpoint = RpcUtils.makeDriverRef( driverEndpointName, SparkEnv.get.conf, http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala deleted file mode 100644 index 4218fd5..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
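To make the InputPartition contract above concrete, an executor task drains the reader returned by createPartitionReader roughly as below; this is a sketch, not code from the patch, and partition stands for any InputPartition[InternalRow]:

  val reader = partition.createPartitionReader()
  try {
    while (reader.next()) {   // advances; returns false when the partition is exhausted
      val row = reader.get()  // the current InternalRow, valid until the next call to next()
      // ... feed row into the query's physical plan ...
    }
  } finally {
    reader.close()            // InputPartitionReader extends Closeable
  }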
- */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.{ForeachWriter, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -/** - * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @param converter An object to convert internal rows to target type T. Either it can be - * a [[ExpressionEncoder]] or a direct converter function. - * @tparam T The expected type of the sink. - */ -case class ForeachWriteSupportProvider[T]( - writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) - } - - override def toString: String = "ForeachSink" - } - } -} - -object ForeachWriteSupportProvider { - def apply[T]( - writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { - writer match { - case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( - pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) - case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) - } - } -} - -case class ForeachWriterFactory[T]( - writer: ForeachWriter[T], - rowConverter: InternalRow => T) - extends StreamingDataWriterFactory { - override def createWriter( - partitionId: Int, - taskId: Long, - epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, rowConverter, partitionId, epochId) - } -} - -/** - * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] - * @param partitionId - * @param epochId - * @tparam T The type expected by the writer. - */ -class ForeachDataWriter[T]( - writer: ForeachWriter[T], - rowConverter: InternalRow => T, - partitionId: Int, - epochId: Long) - extends DataWriter[InternalRow] { - - // If open returns false, we should skip writing rows. 
- private val opened = writer.open(partitionId, epochId) - - override def write(record: InternalRow): Unit = { - if (!opened) return - - try { - writer.process(rowConverter(record)) - } catch { - case t: Throwable => - writer.close(t) - throw t - } - } - - override def commit(): WriterCommitMessage = { - writer.close(null) - ForeachWriterCommitMessage - } - - override def abort(): Unit = {} -} - -/** - * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. - */ -case object ForeachWriterCommitMessage extends WriterCommitMessage http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000..e8ce21c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. + * @tparam T The expected type of the sink. 
+ */ +case class ForeachWriterProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { + + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createWriterFactory(): DataWriterFactory[InternalRow] = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + + override def toString: String = "ForeachSink" + } + } +} + +object ForeachWriterProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriterProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriterProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( + writer: ForeachWriter[T], + rowConverter: InternalRow => T) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + taskId: Long, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T]( + writer: ForeachWriter[T], + rowConverter: InternalRow => T, + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(rowConverter(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage
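For reference, the user-facing entry point that ForeachWriterProvider backs is DataStreamWriter.foreach; a brief sketch, assuming streamingDs is a streaming Dataset[String] defined elsewhere:

  import org.apache.spark.sql.ForeachWriter

  val query = streamingDs.writeStream
    .foreach(new ForeachWriter[String] {
      // Returning false makes ForeachDataWriter above skip every row in this task.
      override def open(partitionId: Long, epochId: Long): Boolean = true
      override def process(value: String): Unit = println(value)
      // Receives the failure, or null on a clean commit (see commit() above).
      override def close(errorOrNull: Throwable): Unit = ()
    })
    .start()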