Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/20828#discussion_r180977228 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousMemoryStream.scala --- @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.ArrayBlockingQueue +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.{Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousMemoryStream.GetRecord +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions} +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.RpcUtils + +/** + * The overall strategy here is: + * * ContinuousMemoryStream maintains a list of records for each partition. addData() will + * distribute records evenly-ish across partitions. + * * ContinuousMemoryStreamRecordBuffer is set up as an endpoint for executor-side + * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified + * offset within the list, or null if that offset doesn't yet have a record. 
+ */ +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) + extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { + private implicit val formats = Serialization.formats(NoTypeHints) + val NUM_PARTITIONS = 2 + + protected val logicalPlan = + StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) + + // ContinuousReader implementation + + @GuardedBy("this") + private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + + private val recordBuffer = new ContinuousMemoryStreamRecordBuffer() + + private var startOffset: ContinuousMemoryStreamOffset = _ + + @volatile private var endpointRef: RpcEndpointRef = _ + + def addData(data: TraversableOnce[A]): Offset = synchronized { + // Distribute data evenly among partition lists. + data.toSeq.zipWithIndex.map { + case (item, index) => records(index % NUM_PARTITIONS) += item + } + + // The new target offset is the offset where all records in all partitions have been processed. + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + } + + override def setStartOffset(start: Optional[Offset]): Unit = synchronized { + // Inferred initial offset is position 0 in each partition. 
+ startOffset = start.orElse { + ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + }.asInstanceOf[ContinuousMemoryStreamOffset] + } + + override def getStartOffset: Offset = startOffset + + override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json)) + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = { + ContinuousMemoryStreamOffset { + offsets.map { + case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num) + }.toMap + } + } + + override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = { + synchronized { + endpointRef = + recordBuffer.rpcEnv.setupEndpoint(ContinuousMemoryStream.recordBufferName(id), recordBuffer) + + startOffset.partitionNums.map { + case (part, index) => + val name = ContinuousMemoryStream.recordBufferName(id) + new ContinuousMemoryStreamDataReaderFactory(name, part, index): DataReaderFactory[Row] + }.toList.asJava + } + } + + override def stop(): Unit = { + if (endpointRef != null) recordBuffer.rpcEnv.stop(endpointRef) + } + + override def commit(end: Offset): Unit = {} + + // ContinuousReadSupport implementation + // This is necessary because of how StreamTest finds the source for AddDataMemory steps. + def createContinuousReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceOptions): ContinuousReader = { + this + } + + override def reset(): Unit = synchronized { + records.foreach(_.clear()) + startOffset = ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + } + + /** + * Endpoint for executors to poll for records. 
+ */ + private class ContinuousMemoryStreamRecordBuffer extends ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) => + ContinuousMemoryStream.this.synchronized { + val buf = records(part) + + val record = --- End diff -- nit: Can this fit on a single line? `val record = if (buf.size <= index) None else Some(buf(index))` should fit. Also, there is an extra blank line above.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org