GitHub user zsxwing commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19984#discussion_r158391247
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala ---
    @@ -0,0 +1,195 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.continuous
    +
    +import java.util.concurrent.atomic.AtomicLong
    +
    +import scala.collection.mutable
    +
    +import org.apache.spark.SparkEnv
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
    +import org.apache.spark.sql.SparkSession
    +import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
    +import org.apache.spark.sql.sources.v2.reader.{ContinuousReader, PartitionOffset}
    +import org.apache.spark.sql.sources.v2.writer.{ContinuousWriter, WriterCommitMessage}
    +import org.apache.spark.util.RpcUtils
    +
    +private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
    +
    +// Driver epoch trigger message
    +/**
    + * Atomically increment the current epoch and get the new value.
    + */
    +private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
    +
    +// Init messages
    +/**
    + * Set the reader and writer partition counts. Tasks may not be started until the coordinator
    + * has acknowledged these messages.
    + */
    +private[sql] case class SetReaderPartitions(numPartitions: Int) extends EpochCoordinatorMessage
    +case class SetWriterPartitions(numPartitions: Int) extends EpochCoordinatorMessage
    +
    +// Partition task messages
    +/**
    + * Get the current epoch.
    + */
    +private[sql] case object GetCurrentEpoch extends EpochCoordinatorMessage
    +/**
    + * Commit a partition at the specified epoch with the given message.
    + */
    +private[sql] case class CommitPartitionEpoch(
    +    partitionId: Int,
    +    epoch: Long,
    +    message: WriterCommitMessage) extends EpochCoordinatorMessage
    +/**
    + * Report that a partition is ending the specified epoch at the specified offset.
    + */
    +private[sql] case class ReportPartitionOffset(
    +    partitionId: Int,
    +    epoch: Long,
    +    offset: PartitionOffset) extends EpochCoordinatorMessage
    +
    +
    +/** Helper object used to create reference to [[EpochCoordinator]]. */
    +private[sql] object EpochCoordinatorRef extends Logging {
    +  private def endpointName(runId: String) = s"EpochCoordinator-$runId"
    +
    +  /**
    +   * Create a reference to a new [[EpochCoordinator]].
    +   */
    +  def create(
    +      writer: ContinuousWriter,
    +      reader: ContinuousReader,
    +      startEpoch: Long,
    +      queryId: String,
    +      runId: String,
    +      session: SparkSession,
    +      env: SparkEnv): RpcEndpointRef = synchronized {
    +    val coordinator = new EpochCoordinator(writer, reader, startEpoch, queryId, session, env.rpcEnv)
    +    val ref = env.rpcEnv.setupEndpoint(endpointName(runId), coordinator)
    +    logInfo("Registered EpochCoordinator endpoint")
    +    ref
    +  }
    +
    +  def get(runId: String, env: SparkEnv): RpcEndpointRef = synchronized {
    +    val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName(runId), env.conf, env.rpcEnv)
    +    logDebug("Retrieved existing EpochCoordinator endpoint")
    +    rpcEndpointRef
    +  }
    +}
    +
    +/**
    + * Handles three major epoch coordination tasks for continuous processing:
    + *
    + * * Maintains a local epoch counter (the "driver epoch"), incremented by IncrementAndGetEpoch
    + *   and pollable from executors by GetCurrentEpoch. Note that this epoch is *not* immediately
    + *   reflected anywhere in ContinuousExecution.
    + * * Collates ReportPartitionOffset messages, and forwards to ContinuousExecution when all
    + *   readers have ended a given epoch.
    + * * Collates CommitPartitionEpoch messages, and forwards to ContinuousExecution when all readers
    + *   have both committed and reported an end offset for a given epoch.
    + */
    +private[continuous] class EpochCoordinator(
    +    writer: ContinuousWriter,
    +    reader: ContinuousReader,
    +    startEpoch: Long,
    +    queryId: String,
    +    session: SparkSession,
    +    override val rpcEnv: RpcEnv)
    +  extends ThreadSafeRpcEndpoint with Logging {
    +
    +  private var numReaderPartitions: Int = _
    +  private var numWriterPartitions: Int = _
    +
    +  private var currentDriverEpoch = startEpoch
    +
    +  // (epoch, partition) -> message
    +  private val partitionCommits =
    +    mutable.Map[(Long, Int), WriterCommitMessage]()
    +  // (epoch, partition) -> offset
    +  private val partitionOffsets =
    +    mutable.Map[(Long, Int), PartitionOffset]()
    +
    +  private def resolveCommitsAtEpoch(epoch: Long) = {
    +    val thisEpochCommits =
    +      partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
    +    val nextEpochOffsets =
    +      partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
    +
    +    if (thisEpochCommits.size == numWriterPartitions &&
    +      nextEpochOffsets.size == numReaderPartitions) {
    +      logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.")
    +      val query = session.streams.get(queryId).asInstanceOf[StreamingQueryWrapper]
    --- End diff --
    
    Why not pass the query into `EpochCoordinator`'s constructor? Getting the query from
    StreamingQueryManager can race: the query may fail before we process the
    `CommitPartitionEpoch` messages, in which case `session.streams.get(queryId)` will
    return null.
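
    For illustration, a minimal sketch of that suggestion (an assumption about the shape of
    the fix, not the actual patch; the parameter name `query` is hypothetical, and the commit
    forwarding is elided because the rest of the method is not shown in this diff):

        private[continuous] class EpochCoordinator(
            writer: ContinuousWriter,
            reader: ContinuousReader,
            query: StreamingQueryWrapper,  // injected by the caller that created the query
            startEpoch: Long,
            session: SparkSession,
            override val rpcEnv: RpcEnv)
          extends ThreadSafeRpcEndpoint with Logging {

          private def resolveCommitsAtEpoch(epoch: Long): Unit = {
            // ...same collation of partitionCommits and partitionOffsets as above...
            // No session.streams.get(queryId) lookup here, so a query that has already
            // failed or been removed from StreamingQueryManager can no longer surface as null.
          }
        }

    If `queryId` is only used for that lookup, it could then be dropped from the constructor
    (and from `EpochCoordinatorRef.create`) as well.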


---
