This is an automated email from the ASF dual-hosted git repository.
HeartSaVioR pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 44750f2a064a [SPARK-57141][SS][RTM][STREAMINGSHUFFLE][PART3] Add StreamingShuffleManager and MultiShuffleManager
44750f2a064a is described below
commit 44750f2a064a57d9179e30a0049673c534c305f0
Author: Boyang Jerry Peng <[email protected]>
AuthorDate: Mon Jun 8 07:13:16 2026 +0900
[SPARK-57141][SS][RTM][STREAMINGSHUFFLE][PART3] Add StreamingShuffleManager and MultiShuffleManager
### What changes were proposed in this pull request?
This is **part 3** of a multi-PR effort to add *streaming shuffle* to
Spark — a push-based shuffle used by Real-Time Mode (RTM) structured streaming,
where writer tasks push records
directly to reader tasks over the network instead of writing map output
to disk for readers to pull.
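To make the push-versus-pull distinction concrete, here is a toy, in-process sketch. It is not the Netty wire protocol from Part 1. The object name `PushShuffleToy`, the `Data`/`Termination` messages (loosely modeled on the data and termination control messages) and the use of `LinkedBlockingQueue` as a stand-in for the network are all hypothetical. Each writer partitions its records and pushes every record to the owning reader's inbox as soon as it is produced. Each reader consumes records as they arrive, with no map output staged on disk first.

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer

// Toy, in-process model of push-based shuffle. All names are hypothetical; in-memory
// queues stand in for the network connections used by the real implementation.
object PushShuffleToy {
  sealed trait Msg
  final case class Data(writerId: Int, value: Int) extends Msg
  final case class Termination(writerId: Int) extends Msg

  /** Writers push each value straight to its reader's inbox; readers consume as values arrive. */
  def run(inputs: Seq[Seq[Int]], numReaders: Int): Map[Int, Seq[Int]] = {
    val inboxes = Vector.fill(numReaders)(new LinkedBlockingQueue[Msg]())
    val results = Vector.fill(numReaders)(ArrayBuffer.empty[Int])

    // Each reader stops after it has received a Termination from every writer.
    val readers = (0 until numReaders).map { r =>
      new Thread(() => {
        var remaining = inputs.size
        while (remaining > 0) {
          inboxes(r).take() match {
            case Data(_, v)     => results(r) += v
            case Termination(_) => remaining -= 1
          }
        }
      })
    }
    // Each writer picks a reader by value and pushes immediately. Termination goes to
    // every reader, after all of this writer's data, so FIFO order guarantees completeness.
    val writers = inputs.zipWithIndex.map { case (values, w) =>
      new Thread(() => {
        values.foreach(v => inboxes(Math.floorMod(v, numReaders)).put(Data(w, v)))
        inboxes.foreach(_.put(Termination(w)))
      })
    }
    (readers ++ writers).foreach(_.start())
    (readers ++ writers).foreach(_.join())
    (0 until numReaders).map(r => r -> results(r).toSeq).toMap
  }
}
```

Records from different writers may interleave at a reader. Records from a single writer arrive in the order that writer pushed them.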
This PR adds the shuffle-manager layer that later PRs plug into:
- **`StreamingShuffleManager`** — a `ShuffleManager` implementation for
streaming shuffle. `getWriter`/`getReader` are intentionally stubbed in this PR
(they throw
`UnsupportedOperationException`) and are implemented in the push-path /
pull-path PRs that follow.
- **`MultiShuffleManager`** — routes each shuffle to either the batch
`SortShuffleManager` or the `StreamingShuffleManager`, based on a per-query
local property, so a single application
can mix batch and streaming shuffle.
- **`TaskContextAwareLogging`** — a `Logging` mixin that prefixes log
lines with queryId / shuffleId / stageId / taskId.
- **`SparkEnv`** — exposes the `StreamingShuffleOutputTracker` (added in
part 2) to executors, and initializes it **only** when the configured shuffle
manager is `StreamingShuffleManager`
or `MultiShuffleManager`.
- Two streaming-shuffle error conditions
(`STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER`,
`STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE`) and the `STREAMING_QUERY_ID` log
key.
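The prefix that `TaskContextAwareLogging.formatMessage` builds can be restated without a `TaskContext`. In the sketch below, `LogPrefixSketch` is a hypothetical name and the ids are passed as plain values. As in the trait, the query id is cut to its first 5 characters and dropped if empty, and the parts appear in the order queryId, shuffleId, stageId, taskId. In the trait, the value labeled `taskId` comes from `TaskContext.partitionId()`.

```scala
// Hypothetical standalone restatement of TaskContextAwareLogging.formatMessage.
object LogPrefixSketch {
  def format(
      msg: String,
      queryId: Option[String],
      shuffleId: Option[Int],
      stageId: Option[Int],
      taskId: Option[Int]): String = {
    // Keep only the first 5 characters of the query id; omit it entirely if empty.
    val q = queryId.map(_.take(5)).filter(_.nonEmpty).map(id => s"[queryId = $id] ").getOrElse("")
    val sh = shuffleId.map(id => s"[shuffleId = $id] ").getOrElse("")
    val st = stageId.map(id => s"[stageId = $id] ").getOrElse("")
    val t = taskId.map(id => s"[taskId = $id] ").getOrElse("")
    // Same ordering as the trait: queryId, shuffleId, stageId, taskId, then the message.
    s"$q$sh$st$t$msg"
  }
}
```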
The full PR stack:
- **Part 1** (SPARK-56674, *merged*) — streaming shuffle wire protocol
(Netty messages).
- **Part 2** (SPARK-56962, *merged*) — `StreamingShuffleOutputTracker`
(driver-side writer-location coordination).
- **Part 3** (*this PR*) — shuffle-manager layer
(`StreamingShuffleManager` + `MultiShuffleManager`), logging mixin, and
SparkEnv tracker wiring.
- **Part 4** — `StreamingShuffleWriter` + server-side Netty handler (push
path).
- **Part 5** — `StreamingShuffleReader` + client-side Netty handler (pull
path).
- **Part 6** — register streaming shuffles with the tracker in
`DAGScheduler` (activation).
- **Part 7** — end-to-end `StreamingShuffleSuite`.
- **Part 8** — documentation.
### Why are the changes needed?
Real-Time Mode / low-latency continuous queries need shuffle data to flow
continuously between stages. The default sort shuffle (write map output to
disk, then have reducers pull it) adds
latency that is unacceptable for these workloads. Streaming shuffle
instead pushes records directly from writer tasks to reader tasks.
This PR lands the manager layer that the writer and reader
implementations attach to, plus `MultiShuffleManager` so batch stages keep
using the sort shuffle while streaming stages use the
streaming shuffle within the same application.
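The routing rule in `MultiShuffleManager` can be sketched on its own. In the sketch, `RoutingSketch` is a hypothetical name, and the string labels `"streaming"` and `"sort"` stand in for the real `StreamingShuffleManager` and `SortShuffleManager` instances. The property key and the exact `"true"` comparison match `MultiShuffleManager.isStreamingShuffleEnabled`. `computeIfAbsent` makes the first decision for a shuffle id stick until the shuffle is unregistered.

```scala
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for MultiShuffleManager's routing; runs without Spark.
object RoutingSketch {
  // Same key and exact "true" comparison as MultiShuffleManager.isStreamingShuffleEnabled.
  val StreamingShuffleEnabledProperty = "streaming.concurrent.stages.enabled"

  def isStreamingShuffleEnabled(props: Properties): Boolean =
    "true" == props.getProperty(StreamingShuffleEnabledProperty)
}

class RoutingSketch {
  // shuffleId -> chosen manager; the first decision for a shuffle is reused afterwards,
  // even if the caller's properties change.
  private val shuffleIdToManager = new ConcurrentHashMap[Int, String]()

  def managerFor(shuffleId: Int, props: Properties): String =
    shuffleIdToManager.computeIfAbsent(shuffleId, _ =>
      if (RoutingSketch.isStreamingShuffleEnabled(props)) "streaming" else "sort")

  // Mirrors unregisterShuffle dropping the cached choice for a shuffle.
  def unregister(shuffleId: Int): Unit = shuffleIdToManager.remove(shuffleId)
}
```

In the real class, the `Properties` come from the active `SparkContext` on the driver, or from the `TaskContext` on executors. Each call also holds a lock on the map.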
### Does this PR introduce _any_ user-facing change?
No. The new shuffle managers are opt-in via `spark.shuffle.manager` and
are not the default; `getWriter`/`getReader` are still stubbed in this PR, so
the feature is not yet usable
end-to-end (completed in later PRs). The `StreamingShuffleOutputTracker`
is initialized only when one of the new managers is configured, so there is no
change to the default (sort
shuffle) path — this is covered by tests.
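The gating in `SparkEnv.initializeStreamingShuffleOutputTracker` reduces to a small decision. The real code resolves the class name with `ShuffleManager.getShuffleManagerClassName(conf)` and checks `SparkContext.isDriver(executorId)`. In this hypothetical sketch (`TrackerGatingSketch`, `TrackerRole`), both inputs are plain parameters. `None` means no tracker is created, which is why the default sort-shuffle path is unaffected.

```scala
// Hypothetical restatement of the tracker gating; class names are compared as strings.
object TrackerGatingSketch {
  sealed trait TrackerRole
  case object Master extends TrackerRole // driver side; real code also registers the RPC endpoint
  case object Worker extends TrackerRole // executor side; real code holds a ref to that endpoint

  private val StreamingManagers = Set(
    "org.apache.spark.shuffle.streaming.StreamingShuffleManager",
    "org.apache.spark.shuffle.streaming.MultiShuffleManager")

  /** None: no tracker is created for this shuffle manager. */
  def trackerRole(shuffleManagerClassName: String, isDriver: Boolean): Option[TrackerRole] =
    if (!StreamingManagers.contains(shuffleManagerClassName)) None
    else if (isDriver) Some(Master)
    else Some(Worker)
}
```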
### How was this patch tested?
New unit suites:
- **`StreamingShuffleManagerSuite`** — `getWriterId` for data/termination
messages and the unexpected-message-type error; `getQueryId` resolution and
failure; `registerShuffle` handle
type; and SparkEnv gating (tracker is present for
`StreamingShuffleManager`, absent for an explicitly configured sort manager and for the default manager).
- **`MultiShuffleManagerSuite`** — per-query streaming-vs-batch routing,
the enable property, and SparkEnv gating for `MultiShuffleManager`.
14 tests, all passing. `SparkThrowableSuite` validates the two new error
conditions.
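The `getQueryId` resolution order exercised by `StreamingShuffleManagerSuite` can be restated over a plain map of local properties. `QueryIdSketch` is a hypothetical name. The real method reads from a `TaskContext` and throws `SparkException.internalError`; the sketch throws `IllegalStateException` so it runs without Spark. The streaming query id wins, the SQL execution id is the batch fallback, and if neither is set the call fails.

```scala
// Hypothetical restatement of StreamingShuffleManager.getQueryId over a Map.
object QueryIdSketch {
  val QueryIdKey = "sql.streaming.queryId"
  val SqlExecutionIdKey = "spark.sql.execution.id"

  def getQueryId(localProps: Map[String, String]): String =
    localProps.get(QueryIdKey)                   // streaming query id, if present
      .orElse(localProps.get(SqlExecutionIdKey)) // otherwise the SQL execution id
      .getOrElse(throw new IllegalStateException(
        "Streaming shuffle requires the query id or SQL execution id local property to be set"))
}
```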
### Was this patch authored or co-authored using generative AI tooling?
Co-authored with Claude Code (Claude Opus 4.8)
Closes #56196 from jerrypeng/stack/streaming-shuffle-pr3-managers.
Authored-by: Boyang Jerry Peng <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
(cherry picked from commit a6ac0b8109c02969d685908c37062566653918cc)
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../java/org/apache/spark/internal/LogKeys.java | 1 +
.../src/main/resources/error/error-conditions.json | 12 ++
.../src/main/scala/org/apache/spark/SparkEnv.scala | 44 ++++++
.../shuffle/streaming/MultiShuffleManager.scala | 154 +++++++++++++++++++++
.../streaming/StreamingShuffleManager.scala | 130 +++++++++++++++++
.../streaming/TaskContextAwareLogging.scala | 109 +++++++++++++++
.../streaming/MultiShuffleManagerSuite.scala | 71 ++++++++++
.../streaming/StreamingShuffleManagerSuite.scala | 125 +++++++++++++++++
8 files changed, 646 insertions(+)
diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index d8ce9d025af9..37064bf77631 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -794,6 +794,7 @@ public enum LogKeys implements LogKey {
STREAMING_DATA_SOURCE_NAME,
STREAMING_OFFSETS_END,
STREAMING_OFFSETS_START,
+ STREAMING_QUERY_ID,
STREAMING_QUERY_PROGRESS,
STREAMING_SOURCE,
STREAMING_TABLE,
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index a21f9aa08521..bf5b3092511f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -7206,6 +7206,18 @@
},
"sqlState" : "0A000"
},
+ "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER" : {
+ "message" : [
+ "Streaming shuffle <messageType> between writer <writerId> and reader <readerId> expected to have sequence number <expSeqNum>, but the actual sequence number is <actSeqNum>. Please verify that the messages are sent in order."
+ ],
+ "sqlState" : "XXKST"
+ },
+ "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE" : {
+ "message" : [
+ "Unexpected message type <messageType> encountered during streaming shuffle."
+ ],
+ "sqlState" : "XXKST"
+ },
"STREAMING_STATEFUL_OPERATOR_MISSING_STATE_DIRECTORY" : {
"message" : [
"Cannot restart streaming query with stateful operators because the state directory is empty or missing.",
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4e56c88501ed..9c4abdf66579 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -46,6 +46,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.streaming.{MultiShuffleManager, StreamingShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.udf.worker.UDFWorkerSpecification
import org.apache.spark.udf.worker.core.{UDFDispatcherFactory, UDFDispatcherManager, WorkerDispatcher}
@@ -181,6 +182,7 @@ class SparkEnv (
pythonWorkers.values.foreach(_.stop())
udfDispatcherManager.foreach(_.close())
mapOutputTracker.stop()
+ _streamingShuffleOutputTracker.foreach(_.stop())
if (shuffleManager != null) {
shuffleManager.stop()
}
@@ -299,6 +301,48 @@ class SparkEnv (
// Signal that the ShuffleManager has been initialized
shuffleManagerInitLatch.countDown()
}
+ initializeStreamingShuffleOutputTracker()
+ }
+
+ // Holds the streaming shuffle output tracker, which is only present when the configured
+ // shuffle manager requires it (i.e., StreamingShuffleManager or MultiShuffleManager).
+ @volatile private var _streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+ None
+
+ def streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+ _streamingShuffleOutputTracker
+
+ /**
+ * Initialize the StreamingShuffleOutputTracker if the configured shuffle manager requires one
+ * and one does not already exist. This method is idempotent -- calling it multiple times is safe.
+ */
+ private def initializeStreamingShuffleOutputTracker(): Unit = {
+ if (_streamingShuffleOutputTracker.isDefined) {
+ return
+ }
+
+ val shuffleManagerName = ShuffleManager.getShuffleManagerClassName(conf)
+ if (shuffleManagerName == classOf[StreamingShuffleManager].getName
+ || shuffleManagerName == classOf[MultiShuffleManager].getName) {
+ val tracker = if (SparkContext.isDriver(executorId)) {
+ new StreamingShuffleOutputTrackerMaster(conf)
+ } else {
+ new StreamingShuffleOutputTrackerWorker(conf)
+ }
+
+ if (SparkContext.isDriver(executorId)) {
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(
+ StreamingShuffleOutputTracker.ENDPOINT_NAME,
+ new StreamingShuffleOutputTrackerMasterEndpoint(
+ rpcEnv,
+ tracker.asInstanceOf[StreamingShuffleOutputTrackerMaster],
+ conf))
+ } else {
+ tracker.trackerEndpoint = RpcUtils.makeDriverRef(
+ StreamingShuffleOutputTracker.ENDPOINT_NAME, conf, rpcEnv)
+ }
+ _streamingShuffleOutputTracker = Some(tracker)
+ }
}
private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala
new file mode 100644
index 000000000000..9e63c9375955
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.{ShuffleBlockResolver, ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter, ShuffleWriteMetricsReporter, ShuffleWriter}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.isStreamingShuffleEnabled
+
+class MultiShuffleHandle(
+ val streamingShuffleHandle: ShuffleHandle,
+ val otherShuffleHandle: ShuffleHandle)
+ extends ShuffleHandle(streamingShuffleHandle.shuffleId)
+
+object MultiShuffleManager {
+ // Streaming shuffle is used for queries running in Real-Time Mode (concurrent stages), gated by
+ // the same per-query local property that the RTM micro-batch execution sets.
+ // TODO(SPARK-57000): once ConcurrentStageDAGScheduler is merged (apache/spark#56055), reference
+ // ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY here (and delegate to
+ // ConcurrentStageDAGScheduler.isConcurrentStagesEnabled) instead of hardcoding the property.
+ val STREAMING_SHUFFLE_ENABLED_PROPERTY = "streaming.concurrent.stages.enabled"
+
+ def isStreamingShuffleEnabled(properties: Properties): Boolean =
+ "true" == properties.getProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY)
+}
+
+/* This shuffle manager allows real-time queries that depend on streaming shuffle
+and normal queries that depend on sort shuffle to coexist in a cluster. Because the
+shuffle manager can currently only be configured at the cluster level, use this shuffle
+manager if you want to run batch and real-time queries at the same time.
+ */
+class MultiShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+ // To make sure the type of shuffle manager used for a shuffle is the same during its lifetime
+ private val shuffleIdToManager = new ConcurrentHashMap[Int, ShuffleManager]()
+ private var streamingShuffleManager: Option[StreamingShuffleManager] = None
+ private var sortShuffleManager: Option[SortShuffleManager] = None
+
+ private def shuffleManager(shuffleId: Int): ShuffleManager = {
+ shuffleIdToManager.computeIfAbsent(shuffleId, _ => {
+ val properties = SparkContext.getActive.map(_.getLocalProperties)
+ .orElse(Option(TaskContext.get()).map(_.getLocalProperties))
+ .getOrElse(throw SparkException.internalError(
+ "Cannot determine streaming shuffle routing: no active SparkContext or TaskContext"))
+ if (isStreamingShuffleEnabled(properties)) {
+ if (streamingShuffleManager.isEmpty) {
+ streamingShuffleManager = Some(new StreamingShuffleManager)
+ }
+ streamingShuffleManager.get
+ } else {
+ if (sortShuffleManager.isEmpty) {
+ sortShuffleManager = Some(new SortShuffleManager(conf))
+ }
+ sortShuffleManager.get
+ }
+ })
+ }
+
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ shuffleIdToManager.synchronized {
+ shuffleManager(shuffleId).registerShuffle(shuffleId, dependency)
+ }
+ }
+
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ shuffleIdToManager.synchronized {
+ shuffleManager(handle.shuffleId).getWriter(handle, mapId, context, metrics)
+ }
+ }
+
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ shuffleIdToManager.synchronized {
+ shuffleManager(handle.shuffleId).getReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
+ }
+ }
+
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ shuffleIdToManager.synchronized {
+ val manager = shuffleIdToManager.get(shuffleId)
+ // A shuffle is unregistered when its ShuffleDependency is garbage collected, by which
+ // time the context might no longer be active. If there is no cached shuffle manager for
+ // this shuffle, none of registerShuffle, getWriter, or getReader was ever invoked for it,
+ // so there is no state to clean up and this is a no-op.
+ if (manager == null) {
+ return true
+ }
+
+ shuffleIdToManager.remove(shuffleId)
+ manager.unregisterShuffle(shuffleId)
+ }
+ }
+
+ override def shuffleBlockResolver: ShuffleBlockResolver = {
+ shuffleIdToManager.synchronized {
+ if (sortShuffleManager.nonEmpty) {
+ sortShuffleManager.get.shuffleBlockResolver
+ } else {
+ // don't need to support this for the streaming shuffle implementation
+ // since block manager is not used
+ throw new UnsupportedOperationException()
+ }
+ }
+ }
+
+ override def stop(): Unit = {
+ shuffleIdToManager.synchronized {
+ if (streamingShuffleManager.nonEmpty) {
+ streamingShuffleManager.get.stop()
+ }
+ if (sortShuffleManager.nonEmpty) {
+ sortShuffleManager.get.stop()
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
new file mode 100644
index 000000000000..f56d4f0fc4f8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import org.apache.spark.{ShuffleDependency, SparkException, SparkRuntimeException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.streaming.{DataMessage, StreamingShuffleMessage, StreamingShuffleMessageType, TerminationControlMessage}
+import org.apache.spark.shuffle._
+
+class StreamingShuffleHandle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C])
+ extends BaseShuffleHandle[K, V, C](shuffleId, dependency)
+
+object StreamingShuffleManager extends Logging {
+ // Exposed for testing
+ private[spark] val QUERY_ID_PROPERTY_KEY = "sql.streaming.queryId"
+ // The query id above is not set for batch queries, so for batch queries that use
+ // streaming shuffle we fall back to the SQL execution id to track errors
+ private val QUERY_EXECUTION_ID_PROPERTY_KEY = "spark.sql.execution.id"
+
+ def getQueryId(context: TaskContext): String = {
+ Option(context.getLocalProperty(QUERY_ID_PROPERTY_KEY))
+ .orElse(Option(context.getLocalProperty(QUERY_EXECUTION_ID_PROPERTY_KEY)))
+ .getOrElse(throw SparkException.internalError(
+ "Streaming shuffle requires the query id or SQL execution id local property to be set"))
+ }
+
+ /* Called from the reader side to get the writerId associated with a message */
+ def getWriterId(message: StreamingShuffleMessage): Int = {
+ message.messageType() match {
+ case StreamingShuffleMessageType.DATA_MESSAGE_UNSAFE_ROW =>
+ message.asInstanceOf[DataMessage].shuffleWriterId
+ case StreamingShuffleMessageType.TERMINATION_CONTROL_MESSAGE =>
+ message.asInstanceOf[TerminationControlMessage].shuffleWriterId
+ case _ =>
+ // Should not reach here
+ throw streamingShuffleUnexpectedMessageType(message.messageType())
+ }
+ }
+
+ def streamingShuffleIncorrectSequenceNumber(
+ messageType: StreamingShuffleMessageType,
+ writerId: Int,
+ readerId: Int,
+ expSeqNum: Long,
+ actSeqNum: Long): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER",
+ messageParameters = Map(
+ "messageType" -> messageType.toString,
+ "writerId" -> writerId.toString,
+ "readerId" -> readerId.toString,
+ "expSeqNum" -> expSeqNum.toString,
+ "actSeqNum" -> actSeqNum.toString))
+ }
+
+ def streamingShuffleUnexpectedMessageType(
+ messageType: StreamingShuffleMessageType): RuntimeException = {
+ new SparkRuntimeException(
+ errorClass = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+ messageParameters = Map("messageType" -> messageType.toString))
+ }
+}
+
+private[spark] class StreamingShuffleManager extends ShuffleManager with Logging {
+
+ logInfo(log"Using StreamingShuffleManager")
+
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new StreamingShuffleHandle(shuffleId, dependency)
+ }
+
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Long,
+ context: TaskContext,
+ metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+ // Implementation is added in a follow-up commit that introduces StreamingShuffleWriter.
+ throw new UnsupportedOperationException(
+ "StreamingShuffleManager.getWriter is not yet implemented")
+ }
+
+ /**
+ * For the streaming shuffle, the startMapIndex, endMapIndex, startPartition, and endPartition
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ // Implementation is added in a follow-up commit that introduces StreamingShuffleReader.
+ throw new UnsupportedOperationException(
+ "StreamingShuffleManager.getReader is not yet implemented")
+ }
+
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ // No manager-side state to release here: the driver's StreamingShuffleOutputTracker is
+ // unregistered in BlockManagerStorageEndpoint's RemoveShuffle handler, and per-task writer
+ // and reader resources are released via task completion listeners.
+ true
+ }
+
+ override def shuffleBlockResolver: ShuffleBlockResolver = {
+ // don't need to support this for the streaming shuffle implementation
+ // since block manager is not used
+ throw new UnsupportedOperationException()
+ }
+
+ override def stop(): Unit = {}
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
new file mode 100644
index 000000000000..fd0ac89abc79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.{LogEntry, Logging, LogKeys, MessageWithContext}
+
+trait TaskContextAwareLogging extends Logging {
+
+ def context: TaskContext
+
+ private val queryId: Option[String] = Option(context)
+ .flatMap(ctx => Option(ctx.getLocalProperty("sql.streaming.queryId")).map(_.take(5)))
+ .filter(_.nonEmpty)
+
+ @volatile private var shuffleId: Option[Int] = None
+
+ def setShuffleIdForLogging(shuffleId: Int): Unit = {
+ this.shuffleId = Some(shuffleId)
+ }
+
+ private def loadTaskId: Option[String] = {
+ Option(context)
+ .flatMap(ctx => Option(ctx.partitionId()))
+ .map(_.toString)
+ }
+
+ private def loadStageId: Option[String] = {
+ Option(context)
+ .flatMap(ctx => Option(ctx.stageId()))
+ .map(_.toString)
+ }
+
+ protected def formatMessage(
+ msg: => String,
+ taskId: Option[String] = loadTaskId,
+ stageId: Option[String] = loadStageId): String = {
+ val taskIdMsg = taskId.map(tid => s"[taskId = $tid] ").getOrElse("")
+ val stageIdMsg = stageId.map(sid => s"[stageId = $sid] ").getOrElse("")
+ val shuffleIdMsg = shuffleId.map(shid => s"[shuffleId = $shid] ").getOrElse("")
+ val queryIdMsg = queryId.map(qid => s"[queryId = $qid] ").getOrElse("")
+ s"$queryIdMsg$shuffleIdMsg$stageIdMsg$taskIdMsg$msg"
+ }
+
+ override protected def logInfo(msg: => String): Unit =
+ super.logInfo(formatMessage(msg))
+
+ override protected def logInfo(entry: LogEntry): Unit =
+ super.logInfo(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+ log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+ override protected def logWarning(msg: => String): Unit =
+ super.logWarning(formatMessage(msg))
+
+ override protected def logWarning(entry: LogEntry): Unit =
+ super.logWarning(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+ log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+ override protected def logDebug(msg: => String): Unit =
+ super.logDebug(formatMessage(msg))
+
+ override protected def logError(msg: => String): Unit =
+ super.logError(formatMessage(msg))
+
+ override protected def logError(entry: LogEntry): Unit =
+ super.logError(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+ log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+ override protected def logError(entry: LogEntry, throwable: Throwable): Unit =
+ super.logError(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+ log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry, throwable)
+
+ override protected def logError(msg: => String, throwable: Throwable): Unit =
+ super.logError(formatMessage(msg), throwable)
+
+ protected case class LogThrottler(logFn: String => Unit, interval: Duration) {
+ private var nextLogNanos = Long.MinValue
+ private var suppressed = 0
+
+ def apply(msg: => MessageWithContext): Unit = {
+ val now = System.nanoTime()
+ if (now >= nextLogNanos) {
+ val suffix = if (suppressed > 0) s" ($suppressed suppressed)" else ""
+ logFn(msg.message + suffix)
+ nextLogNanos = now + interval.toNanos
+ suppressed = 0
+ } else {
+ suppressed += 1
+ }
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
new file mode 100644
index 000000000000..9d9e4ce1c99a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.{isStreamingShuffleEnabled, STREAMING_SHUFFLE_ENABLED_PROPERTY}
+
+class MultiShuffleManagerSuite
+ extends SparkFunSuite
+ with LocalSparkContext
+ with Matchers {
+
+ test("isStreamingShuffleEnabled reflects the per-query property") {
+ val props = new Properties()
+ isStreamingShuffleEnabled(props) should be(false)
+
+ props.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+ isStreamingShuffleEnabled(props) should be(true)
+
+ props.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "false")
+ isStreamingShuffleEnabled(props) should be(false)
+ }
+
+ private def assertRoutesToStreaming(enabled: Boolean): Unit = {
+ withSpark(new SparkContext("local", "MultiShuffleManagerSuite", new SparkConf())) { sc =>
+ if (enabled) {
+ sc.setLocalProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+ }
+ val rdd = sc.parallelize(1 to 4).map(x => (x, x))
+ val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2))
+ val handle = new MultiShuffleManager(sc.conf).registerShuffle(7, dep)
+ assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]] == enabled)
+ }
+ }
+
+ test("registerShuffle routes to the streaming manager when enabled for the query") {
+ assertRoutesToStreaming(enabled = true)
+ }
+
+ test("registerShuffle routes to the sort manager when not enabled for the query") {
+ assertRoutesToStreaming(enabled = false)
+ }
+
+ test("SparkEnv initializes the streaming shuffle tracker when MultiShuffleManager is set") {
+ val conf = new SparkConf().set(SHUFFLE_MANAGER, classOf[MultiShuffleManager].getName)
+ withSpark(new SparkContext("local", "MultiShuffleManagerSuite", conf)) { _ =>
+ assert(SparkEnv.get.streamingShuffleOutputTracker.isDefined)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
new file mode 100644
index 000000000000..181d779e8bb5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import io.netty.buffer.Unpooled
+import org.mockito.Mockito.when
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.network.shuffle.streaming.{DataMessage, TerminationAckMessage, TerminationControlMessage}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.StreamingShuffleManager.{getQueryId, getWriterId, QUERY_ID_PROPERTY_KEY}
+
+class StreamingShuffleManagerSuite
+ extends SparkFunSuite
+ with LocalSparkContext
+ with Matchers
+ with MockitoSugar {
+
+ private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
+
+ // ---- getWriterId ----
+
+ test("getWriterId returns the writer id for a data message") {
+ val msg = new DataMessage(7, 3, 0, Unpooled.EMPTY_BUFFER, 0L)
+ getWriterId(msg) should be(7)
+ }
+
+ test("getWriterId returns the writer id for a termination control message") {
+ getWriterId(new TerminationControlMessage(5, 2)) should be(5)
+ }
+
+ test("getWriterId throws on an unexpected message type") {
+ val e = intercept[SparkRuntimeException] {
+ getWriterId(new TerminationAckMessage(1, 1))
+ }
+ checkError(
+ e,
+ condition = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+ parameters = Map("messageType" -> "TERMINATION_ACK_MESSAGE"))
+ }
+
+ // ---- getQueryId ----
+
+ test("getQueryId returns the streaming query id when set") {
+ val context = mock[TaskContext]
+ when(context.getLocalProperty(QUERY_ID_PROPERTY_KEY)).thenReturn("query-123")
+ getQueryId(context) should be("query-123")
+ }
+
+ test("getQueryId falls back to the SQL execution id for batch queries") {
+ val context = mock[TaskContext]
+ when(context.getLocalProperty(SQL_EXECUTION_ID_KEY)).thenReturn("42")
+ getQueryId(context) should be("42")
+ }
+
+ test("getQueryId throws when no query id property is set") {
+ val context = mock[TaskContext]
+ val e = intercept[SparkException] {
+ getQueryId(context)
+ }
+ checkError(
+ e,
+ condition = "INTERNAL_ERROR",
+ parameters = Map("message" ->
+ "Streaming shuffle requires the query id or SQL execution id local property to be set"))
+ }
+
+ // ---- registerShuffle ----
+
+ test("registerShuffle returns a StreamingShuffleHandle") {
+ withSpark(new SparkContext("local", "StreamingShuffleManagerSuite", new SparkConf())) { sc =>
+ val rdd = sc.parallelize(1 to 4).map(x => (x, x))
+ val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2))
+ val handle = new StreamingShuffleManager().registerShuffle(0, dep)
+ assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]])
+ }
+ }
+
+ // ---- SparkEnv tracker initialization gating ----
+
+ private def assertTrackerInitialized(shuffleManager: Option[String], expectPresent: Boolean):
+ Unit = {
+ val conf = new SparkConf()
+ shuffleManager.foreach(conf.set(SHUFFLE_MANAGER, _))
+ withSpark(new SparkContext("local", "StreamingShuffleManagerSuite", conf)) { _ =>
+ val tracker = SparkEnv.get.streamingShuffleOutputTracker
+ assert(tracker.isDefined == expectPresent)
+ // On the driver a present tracker is always the master.
+ if (expectPresent) {
+ assert(tracker.get.isInstanceOf[StreamingShuffleOutputTrackerMaster])
+ }
+ }
+ }
+
+ test("SparkEnv initializes the streaming shuffle tracker for StreamingShuffleManager") {
+ assertTrackerInitialized(Some(classOf[StreamingShuffleManager].getName), expectPresent = true)
+ }
+
+ test("SparkEnv does not initialize the tracker for a non-streaming (sort) manager") {
+ assertTrackerInitialized(Some(classOf[SortShuffleManager].getName), expectPresent = false)
+ }
+
+ test("SparkEnv does not initialize the tracker for the default manager") {
+ assertTrackerInitialized(None, expectPresent = false)
+ }
+}