Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20828#discussion_r180977228
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousMemoryStream.scala
 ---
    @@ -0,0 +1,220 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.continuous
    +
    +import java.{util => ju}
    +import java.util.Optional
    +import java.util.concurrent.ArrayBlockingQueue
    +import javax.annotation.concurrent.GuardedBy
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable.ListBuffer
    +import scala.reflect.ClassTag
    +
    +import org.json4s.NoTypeHints
    +import org.json4s.jackson.Serialization
    +
    +import org.apache.spark.SparkEnv
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, 
ThreadSafeRpcEndpoint}
    +import org.apache.spark.sql.{Dataset, Encoder, Row, SQLContext}
    +import org.apache.spark.sql.catalyst.encoders.encoderFor
    +import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    +import org.apache.spark.sql.execution.streaming._
    +import 
org.apache.spark.sql.execution.streaming.continuous.ContinuousMemoryStream.GetRecord
    +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, 
DataSourceOptions}
    +import org.apache.spark.sql.sources.v2.reader.{DataReader, 
DataReaderFactory, SupportsScanUnsafeRow}
    +import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, 
ContinuousReader, Offset, PartitionOffset}
    +import org.apache.spark.sql.types.StructType
    +import org.apache.spark.util.RpcUtils
    +
    +/**
    + * The overall strategy here is:
    + *  * ContinuousMemoryStream maintains a list of records for each 
partition. addData() will
    + *    distribute records evenly-ish across partitions.
    + *  * ContinuousMemoryStreamRecordBuffer is set up as an endpoint for 
executor-side
    + *    ContinuousMemoryStreamDataReader instances to poll. It returns the 
record at the specified
    + *    offset within the list, or null if that offset doesn't yet have a 
record.
    + */
    +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
    +  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with 
ContinuousReadSupport {
    +  private implicit val formats = Serialization.formats(NoTypeHints)
    +  val NUM_PARTITIONS = 2
    +
    +  protected val logicalPlan =
    +    StreamingRelationV2(this, "memory", Map(), attributes, 
None)(sqlContext.sparkSession)
    +
    +  // ContinuousReader implementation
    +
    +  @GuardedBy("this")
    +  private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
    +
    +  private val recordBuffer = new ContinuousMemoryStreamRecordBuffer()
    +
    +  private var startOffset: ContinuousMemoryStreamOffset = _
    +
    +  @volatile private var endpointRef: RpcEndpointRef = _
    +
    +  def addData(data: TraversableOnce[A]): Offset = synchronized {
    +    // Distribute data evenly among partition lists.
    +    data.toSeq.zipWithIndex.map {
    +      case (item, index) => records(index % NUM_PARTITIONS) += item
    +    }
    +
    +    // The new target offset is the offset where all records in all 
partitions have been processed.
    +    ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 
records(i).size)).toMap)
    +  }
    +
    +  override def setStartOffset(start: Optional[Offset]): Unit = 
synchronized {
    +    // Inferred initial offset is position 0 in each partition.
    +    startOffset = start.orElse {
    +      ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 
0)).toMap)
    +    }.asInstanceOf[ContinuousMemoryStreamOffset]
    +  }
    +
    +  override def getStartOffset: Offset = startOffset
    +
    +  override def deserializeOffset(json: String): 
ContinuousMemoryStreamOffset = {
    +    ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
    +  }
    +
    +  override def mergeOffsets(offsets: Array[PartitionOffset]): 
ContinuousMemoryStreamOffset = {
    +    ContinuousMemoryStreamOffset {
    +      offsets.map {
    +        case ContinuousMemoryStreamPartitionOffset(part, num) => (part, 
num)
    +      }.toMap
    +    }
    +  }
    +
    +  override def createDataReaderFactories(): 
ju.List[DataReaderFactory[Row]] = {
    +    synchronized {
    +      endpointRef =
    +        
recordBuffer.rpcEnv.setupEndpoint(ContinuousMemoryStream.recordBufferName(id), 
recordBuffer)
    +
    +      startOffset.partitionNums.map {
    +        case (part, index) =>
    +          val name = ContinuousMemoryStream.recordBufferName(id)
    +          new ContinuousMemoryStreamDataReaderFactory(name, part, index): 
DataReaderFactory[Row]
    +      }.toList.asJava
    +    }
    +  }
    +
    +  override def stop(): Unit = {
    +    if (endpointRef != null) recordBuffer.rpcEnv.stop(endpointRef)
    +  }
    +
    +  override def commit(end: Offset): Unit = {}
    +
    +  // ContinuousReadSupport implementation
    +  // This is necessary because of how StreamTest finds the source for 
AddDataMemory steps.
    +  def createContinuousReader(
    +      schema: Optional[StructType],
    +      checkpointLocation: String,
    +      options: DataSourceOptions): ContinuousReader = {
    +    this
    +  }
    +
    +  override def reset(): Unit = synchronized {
    +    records.foreach(_.clear())
    +    startOffset = ContinuousMemoryStreamOffset((0 until 
NUM_PARTITIONS).map(i => (i, 0)).toMap)
    +  }
    +
    +  /**
    +   * Endpoint for executors to poll for records.
    +   */
    +  private class ContinuousMemoryStreamRecordBuffer extends 
ThreadSafeRpcEndpoint {
    +    override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
    +
    +    override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = {
    +      case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
    +        ContinuousMemoryStream.this.synchronized {
    +          val buf = records(part)
    +
    +          val record =
    --- End diff --
    
    nit: Can this fit on a single line?
    `val record = if (buf.size <= index) None else Some(buf(index))` should fit.
    Also, there is an extra blank line above.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to