This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new c27faeb283ce [SPARK-53998][TESTS] Add addition E2E tests for RTM
c27faeb283ce is described below
commit c27faeb283cea190185b5d86e415024575fa72a9
Author: Jerry Peng <[email protected]>
AuthorDate: Wed Dec 10 09:10:40 2025 -0800
[SPARK-53998][TESTS] Add additional E2E tests for RTM
### What changes were proposed in this pull request?
Add additional end-to-end (E2E) tests for real-time mode (RTM).
### Why are the changes needed?
To improve test coverage of RTM functionality.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
N/A. Only tests are added
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52870 from jerrypeng/SPARK-53998-2.
Authored-by: Jerry Peng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 7df7dad29a45f447ed57ba190ecedd3f0feaec17)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/streaming/StreamRealTimeModeE2ESuite.scala | 393 +++++++++++++++++++++
.../streaming/StreamRealTimeModeSuiteBase.scala | 6 +
2 files changed, 399 insertions(+)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeE2ESuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeE2ESuite.scala
new file mode 100644
index 000000000000..3615edc75cb2
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeE2ESuite.scala
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
+import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream
+import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+
+/**
+ * End-to-end tests for streaming queries running in real-time mode (RTM).
+ *
+ * Tests push rows through `LowLatencyMemoryStream` sources and check the rows a query emits
+ * while a batch is in progress. They then advance a `GlobalSingletonManualClock` to end the
+ * batch and check that the query has moved on to the next batch.
+ */
+class StreamRealTimeModeE2ESuite extends StreamRealTimeModeE2ESuiteBase {
+
+  import testImplicits._
+
+  // NOTE(review): local[15] presumably leaves headroom over the base suite's local[10] so that
+  // every task of a batch (up to 10 source partitions with a union) can run concurrently.
+  override protected def createSparkSession =
+    new TestSparkSession(
+      new SparkContext(
+        "local[15]",
+        "streaming-rtm-e2e-context",
+        sparkConf.set("spark.sql.shuffle.partitions", "5")
+      )
+    )
+
+  /** Rows added per key in each batch by `addBatchData`. */
+  private val numRowsPerKey = 10
+
+  /** Number of batches driven by the foreach / mapPartitions tests. */
+  private val numBatches = 3
+
+  /**
+   * Returns a snapshot of the values recorded in `ResultsCollector` under `key`.
+   *
+   * Fails if nothing has been recorded under `key` yet, so callers poll it inside `eventually`.
+   */
+  private def collected(key: String): List[String] =
+    ResultsCollector.get(key).toArray(new Array[String](0)).toList
+
+  /**
+   * Adds one batch of input, one `addData` call per row.
+   *
+   * Writes `numRowsPerKey` rows of `(key, 1)` for keys a, b and c to `read`. When `withUnion`
+   * is set, it also writes `numRowsPerKey` rows of `(key, 2)` for keys d, e and f to `read1`.
+   *
+   * @return the `"key-value"` strings the query is expected to emit for this batch
+   */
+  private def addBatchData(
+      read: LowLatencyMemoryStream[(String, Int)],
+      read1: LowLatencyMemoryStream[(String, Int)],
+      withUnion: Boolean): List[String] = {
+    val expected = mutable.ListBuffer[String]()
+    for (key <- List("a", "b", "c"); _ <- 1 to numRowsPerKey) {
+      read.addData((key, 1))
+      expected += s"$key-1"
+    }
+    if (withUnion) {
+      for (key <- List("d", "e", "f"); _ <- 1 to numRowsPerKey) {
+        read1.addData((key, 2))
+        expected += s"$key-2"
+      }
+    }
+    expected.toList
+  }
+
+  /**
+   * Advances `clock` by one batch duration, then waits until two conditions hold:
+   *
+   *  - the query's latest execution context reports `expectedBatchId`;
+   *  - the last progress shows that the first source read one batch of `addBatchData` input.
+   */
+  private def advanceAndAwaitNextBatch(
+      query: StreamingQuery,
+      clock: GlobalSingletonManualClock,
+      expectedBatchId: Long): Unit = {
+    clock.advance(defaultTrigger.batchDurationMs)
+    eventually(timeout(60.seconds)) {
+      query
+        .asInstanceOf[StreamingQueryWrapper]
+        .streamingQuery
+        .getLatestExecutionContext()
+        .batchId should be(expectedBatchId)
+      // The first source receives 3 keys x numRowsPerKey rows per batch.
+      query.lastProgress.sources(0).numInputRows should be(numRowsPerKey * 3)
+    }
+  }
+
+  /**
+   * Runs a query with a `ForeachWriter` sink, optionally over the union of two sources.
+   *
+   * After each batch it verifies two things:
+   *
+   *  - all rows reached the sink before the clock was advanced;
+   *  - once the batch ends, the rows recorded per (epoch, partition) in `close()` equal exactly
+   *    that batch's input.
+   */
+  private def runForeachTest(withUnion: Boolean): Unit = {
+    withTempDir { checkpointDir =>
+      val clock = new GlobalSingletonManualClock()
+      LowLatencyClock.setClock(clock)
+      val uniqueSinkName = if (withUnion) sinkName + "-union" else sinkName
+
+      val read = LowLatencyMemoryStream[(String, Int)](5)
+      val read1 = LowLatencyMemoryStream[(String, Int)](5)
+      val dataframe = if (withUnion) read.toDF().union(read1.toDF()) else read.toDF()
+
+      // Stop the query before withTempDir deletes the directory holding its checkpoint.
+      var query: StreamingQuery = null
+      try {
+        query = dataframe
+          .select(col("_1").as("key"), col("_2").as("value"))
+          .select(
+            concat(
+              col("key").cast("STRING"),
+              lit("-"),
+              col("value").cast("STRING")
+            ).as("output")
+          )
+          .writeStream
+          .outputMode(OutputMode.Update())
+          .foreach(new ForeachWriter[Row] {
+            // ResultsCollector key for this task's (epoch, partition), set in open().
+            private var batchPartitionId: String = null
+            // Rows processed by this task; copied under batchPartitionId in close().
+            private val processedThisBatch = new ConcurrentLinkedQueue[String]()
+
+            override def open(partitionId: Long, epochId: Long): Boolean = {
+              ResultsCollector
+                .computeIfAbsent(uniqueSinkName, _ => new ConcurrentLinkedQueue[String]())
+              batchPartitionId = s"$uniqueSinkName-$epochId-$partitionId"
+              // Each (epoch, partition) pair is expected to be opened only once.
+              assert(
+                !ResultsCollector.containsKey(batchPartitionId),
+                s"should NOT contain batchPartitionId ${batchPartitionId}"
+              )
+              ResultsCollector.put(batchPartitionId, new ConcurrentLinkedQueue[String]())
+              true
+            }
+
+            override def process(value: Row): Unit = {
+              val v = value.getAs[String]("output")
+              ResultsCollector.get(uniqueSinkName).add(v)
+              processedThisBatch.add(v)
+            }
+
+            override def close(errorOrNull: Throwable): Unit = {
+              assert(
+                ResultsCollector.containsKey(batchPartitionId),
+                s"should contain batchPartitionId ${batchPartitionId}"
+              )
+              ResultsCollector.get(batchPartitionId).addAll(processedThisBatch)
+              processedThisBatch.clear()
+            }
+          })
+          // Use the full path: File.getName would yield a relative directory outside the
+          // temp dir that is never cleaned up.
+          .option("checkpointLocation", checkpointDir.getCanonicalPath)
+          .queryName("foreach")
+          // The trigger's batch duration is irrelevant here: batches are ended by advancing
+          // the manual clock.
+          .trigger(defaultTrigger)
+          .start()
+
+        // NOTE(review): `5` is presumably the partition count of each LowLatencyMemoryStream,
+        // so a union runs 10 partitions per batch.
+        val numPartitions = if (withUnion) 10 else 5
+        val expectedResults = mutable.ListBuffer[String]()
+        for (batchId <- 0 until numBatches) {
+          val batchExpected = addBatchData(read, read1, withUnion)
+          expectedResults ++= batchExpected
+
+          // Every row should reach the sink before the clock is advanced.
+          eventually(timeout(60.seconds)) {
+            collected(uniqueSinkName).sorted should equal(expectedResults.sorted)
+          }
+
+          advanceAndAwaitNextBatch(query, clock, batchId + 1)
+
+          // Rows recorded in close() for this epoch must be exactly this batch's input.
+          eventually(timeout(60.seconds)) {
+            val committedResults = (0 until numPartitions).flatMap { partition =>
+              collected(s"$uniqueSinkName-$batchId-$partition")
+            }
+            committedResults.sorted should equal(batchExpected.sorted)
+          }
+        }
+      } finally {
+        if (query != null) query.stop()
+      }
+    }
+  }
+
+  /**
+   * Runs a query whose rows pass through `mapPartitions`, optionally over a union of two sources.
+   *
+   * After each batch it checks both the rows recorded inside `mapPartitions` and the rows
+   * recorded by the sink started by `runStreamingQuery`.
+   */
+  private def runMapPartitionsTest(withUnion: Boolean): Unit = {
+    val clock = new GlobalSingletonManualClock()
+    LowLatencyClock.setClock(clock)
+    val uniqueSinkName =
+      if (withUnion) sinkName + "mapPartitions-union" else sinkName + "mapPartitions"
+
+    val read = LowLatencyMemoryStream[(String, Int)](5)
+    val read1 = LowLatencyMemoryStream[(String, Int)](5)
+    val dataframe = if (withUnion) read.toDF().union(read1.toDF()) else read.toDF()
+
+    val df = dataframe
+      .select(col("_1").as("key"), col("_2").as("value"))
+      .select(
+        concat(
+          col("key").cast("STRING"),
+          lit("-"),
+          col("value").cast("STRING")
+        ).as("output")
+      )
+      .as[String]
+      .mapPartitions { rows =>
+        rows.map { row =>
+          ResultsCollector
+            .computeIfAbsent(uniqueSinkName, _ => new ConcurrentLinkedQueue[String]())
+            .add(row)
+          row
+        }
+      }
+      .toDF()
+
+    var query: StreamingQuery = null
+    try {
+      query = runStreamingQuery(sinkName, df)
+
+      val expectedResults = mutable.ListBuffer[String]()
+      for (batchId <- 0 until numBatches) {
+        expectedResults ++= addBatchData(read, read1, withUnion)
+
+        // Rows recorded inside mapPartitions.
+        eventually(timeout(60.seconds)) {
+          collected(uniqueSinkName).sorted should equal(expectedResults.sorted)
+        }
+
+        // Rows recorded by the sink started by runStreamingQuery.
+        eventually(timeout(60.seconds)) {
+          collected(sinkName).sorted should equal(expectedResults.sorted)
+        }
+
+        advanceAndAwaitNextBatch(query, clock, batchId + 1)
+      }
+    } finally {
+      if (query != null) query.stop()
+    }
+  }
+
+  test("foreach") {
+    runForeachTest(withUnion = false)
+  }
+
+  test("union - foreach") {
+    runForeachTest(withUnion = true)
+  }
+
+  test("mapPartitions") {
+    runMapPartitionsTest(withUnion = false)
+  }
+
+  test("union - mapPartitions") {
+    runMapPartitionsTest(withUnion = true)
+  }
+
+  test("scala stateless UDF") {
+    val plusOne = (id: Int) => id + 1
+    // Named so it does not shadow functions.udf.
+    val plusOneUdf = spark.udf.register("myUDF", plusOne)
+    val (read, clock) = createMemoryStream()
+
+    val df = read
+      .toDF()
+      .select(col("_1").as("key"), plusOneUdf(col("_2")).as("value_plus_1"))
+      .select(
+        concat(col("key"), lit("-"), col("value_plus_1").cast("STRING")).as("output"))
+
+    var query: StreamingQuery = null
+    try {
+      query = runStreamingQuery("scala_udf", df)
+      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-${value + 1}"))
+    } finally {
+      if (query != null) query.stop()
+    }
+  }
+
+  test("stream static join") {
+    val (read, clock) = createMemoryStream()
+    val staticDf = spark
+      .range(1, 31, 1, 10)
+      .selectExpr("id AS join_key", "id AS join_value")
+      // This produces a HashAggregateExec, which should not be blocked by the allowList
+      // because it belongs to the static (batch) side of the join.
+      .groupBy("join_key")
+      .agg(max($"join_value").as("join_value"))
+
+    val df = read
+      .toDF()
+      .select(col("_1").as("key"), col("_2").as("value"))
+      .join(staticDf, col("value") === col("join_key"))
+      .select(
+        concat(col("key"), lit("-"), col("value"), lit("-"), col("join_value")).as("output"))
+
+    var query: StreamingQuery = null
+    try {
+      query = runStreamingQuery("stream_static_join", df)
+      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value-$value"))
+    } finally {
+      if (query != null) query.stop()
+    }
+  }
+
+  test("to_json and from_json round-trip") {
+    val (read, clock) = createMemoryStream()
+    val schema = new StructType().add("key", StringType).add("value", IntegerType)
+
+    val df = read
+      .toDF()
+      .select(struct(col("_1").as("key"), col("_2").as("value")).as("json"))
+      .select(from_json(to_json(col("json")), schema).as("json"))
+      .select(concat(col("json.key"), lit("-"), col("json.value")).as("output"))
+
+    var query: StreamingQuery = null
+    try {
+      query = runStreamingQuery("json_roundtrip", df)
+      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value"))
+    } finally {
+      if (query != null) query.stop()
+    }
+  }
+
+  test("generateExec passthrough") {
+    val (read, clock) = createMemoryStream()
+
+    val df = read
+      .toDF()
+      .select(col("_1").as("key"), col("_2").as("value"))
+      .withColumn("value_array", array(col("value"), -col("value")))
+    df.createOrReplaceTempView("tempView")
+    // Each input row explodes into two output rows: value and -value.
+    val explodeDF = spark
+      .sql("select key, explode(value_array) as exploded_value from tempView")
+      .select(concat(col("key"), lit("-"), col("exploded_value").cast("STRING")).as("output"))
+
+    var query: StreamingQuery = null
+    try {
+      query = runStreamingQuery("generateExec_passthrough", explodeDF)
+      processBatches(
+        query,
+        read,
+        clock,
+        10,
+        3,
+        (key, value) => Array(s"$key-$value", s"$key--$value")
+      )
+    } finally {
+      if (query != null) query.stop()
+    }
+  }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
index 9199580f6587..7ec5d8e51f09 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala
@@ -32,6 +32,7 @@ import
org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.SystemClock
/**
* Base class for tests that require real-time mode.
@@ -45,6 +46,11 @@ trait StreamRealTimeModeSuiteBase extends StreamTest with
Matchers {
defaultTrigger.batchDurationMs)
}
+  /**
+   * Resets `LowLatencyClock` to a `SystemClock` after the parent setup has run.
+   *
+   * Some suites install a `GlobalSingletonManualClock` through `LowLatencyClock.setClock` and
+   * do not restore it. Without this reset, a manual clock left behind by an earlier suite could
+   * presumably still be in effect here.
+   */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val wallClock = new SystemClock()
+    LowLatencyClock.setClock(wallClock)
+  }
+
override protected def createSparkSession = new TestSparkSession(
new SparkContext(
"local[10]", // Ensure enough number of cores to ensure concurrent
schedule of all tasks.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]