This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.2 by this push:
new 2f411c1c2d9c [SPARK-56132][SS] Call pruneColumns on V2 streaming to
fix metadata reading issue
2f411c1c2d9c is described below
commit 2f411c1c2d9c283bdc8f090528e3252bf01a87df
Author: Zikang Han <[email protected]>
AuthorDate: Thu May 28 13:19:21 2026 -0700
[SPARK-56132][SS] Call pruneColumns on V2 streaming to fix metadata reading
issue
### What changes were proposed in this pull request?
In `MicroBatchExecution.logicalPlan` (and in the equivalent scan-building code
in `ContinuousExecution`), before calling `build()` on the V2 streaming scan
builder, call
`SupportsPushDownRequiredColumns.pruneColumns(output.toStructType)` if the
builder supports it. `output` is the analyzed relation output, which already
includes any metadata columns the query references (added by the
`AddMetadataColumns` rule).
### Why are the changes needed?
1. **Metadata column reads in V2 streaming crash with
`ArrayIndexOutOfBoundsException`.**
When a query selects a metadata column (e.g. `_metadata.row_id`) from a
V2 streaming
source that implements both `SupportsMetadataColumns` and
`SupportsPushDownRequiredColumns`,
the analyzed plan expects the metadata column in the scan output, but
`Scan.readSchema()`
does not include it. Spark tries to read a column at an index the scan
never produced.
2. **Root cause: `pruneColumns` is never called in streaming.**
In batch, `V2ScanRelationPushDown` calls
`SupportsPushDownRequiredColumns.pruneColumns`
with the required schema (which includes metadata columns resolved by
`AddMetadataColumns`)
before `build()`. In `MicroBatchExecution.logicalPlan`, the scan is
built directly with
`table.newScanBuilder(options).build()` — no pushdown of any kind is
applied (a
`// TODO: operator pushdown` comment marks this). Connectors that use
`pruneColumns` to
configure `readSchema()` — including whether to produce metadata columns
— are never
informed of what the query needs.
3. **This change fixes metadata column reads only, not column pruning.**
We call `pruneColumns(output.toStructType)` where `output` is the full
analyzed relation
output — all data columns plus any metadata columns added by
`AddMetadataColumns`. This
communicates required metadata columns to the scan builder so they
appear in `readSchema()`,
but does not prune data columns. Full column pruning in streaming, along
with filter and
aggregate pushdown, is deferred to the existing TODO.
### Does this PR introduce _any_ user-facing change?
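Yes. Streaming queries that read metadata columns from a V2 source implementing
both `SupportsMetadataColumns` and `SupportsPushDownRequiredColumns` no longer
fail with `ArrayIndexOutOfBoundsException`.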
### How was this patch tested?
Added tests in `DataStreamTableAPISuite` and `MetadataColumnSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Yes
Closes #56133 from zikangh/stack/prunecolumns-streaming.
Authored-by: Zikang Han <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 2adc69aca3de7415d011fbe4b3dedaa3a987e8d5)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/connector/catalog/InMemoryBaseTable.scala | 42 ++++++++++---
.../streaming/continuous/ContinuousExecution.scala | 11 +++-
.../streaming/runtime/MicroBatchExecution.scala | 11 +++-
.../spark/sql/connector/MetadataColumnSuite.scala | 26 ++++++++
.../streaming/test/DataStreamTableAPISuite.scala | 73 +++++++++++++++++++++-
5 files changed, 150 insertions(+), 13 deletions(-)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index f582f3e408cb..06997662fd8b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -545,7 +545,8 @@ abstract class InMemoryBaseTable(
override def json(): String = rowCount.toString
}
- class InMemoryMicroBatchStream extends MicroBatchStream {
+ class InMemoryMicroBatchStream(readSchema: StructType, tableSchema:
StructType)
+ extends MicroBatchStream {
override def initialOffset(): Offset = new InMemoryTableOffset(0)
override def latestOffset(): Offset =
new InMemoryTableOffset(InMemoryBaseTable.this.rows.size.toLong)
@@ -554,14 +555,13 @@ abstract class InMemoryBaseTable(
val e = end.asInstanceOf[InMemoryTableOffset].rowCount.toInt
Array(InMemoryMicroBatchPartition(InMemoryBaseTable.this.rows.slice(s,
e)))
}
- override def createReaderFactory(): PartitionReaderFactory = { partition =>
- val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows
- new PartitionReader[InternalRow] {
- private var idx = -1
- override def next(): Boolean = { idx += 1; idx < rows.size }
- override def get(): InternalRow = rows(idx)
- override def close(): Unit = {}
+ override def createReaderFactory(): PartitionReaderFactory = {
+ val metadataColNames = new mutable.ArrayBuffer[String]()
+ readSchema.foreach {
+ case MetadataStructFieldWithLogicalName(_, name) => metadataColNames
+= name
+ case _ =>
}
+ new InMemoryMicroBatchReaderFactory(metadataColNames.toArray)
}
override def deserializeOffset(json: String): Offset = new
InMemoryTableOffset(json.toLong)
override def commit(end: Offset): Unit = {}
@@ -655,7 +655,7 @@ abstract class InMemoryBaseTable(
}
override def toMicroBatchStream(checkpointLocation: String):
MicroBatchStream =
- new InMemoryMicroBatchStream
+ new InMemoryMicroBatchStream(readSchema, tableSchema)
}
case class InMemoryBatchScan(
@@ -954,6 +954,30 @@ class BufferedRows(val key: Seq[Any], val schema:
StructType)
def clear(): Unit = rows.clear()
}
+private class InMemoryMicroBatchReaderFactory(
+ metaNames: Array[String]) extends PartitionReaderFactory with Serializable
{
+ override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
+ val rows = partition.asInstanceOf[InMemoryMicroBatchPartition].rows
+ new PartitionReader[InternalRow] {
+ private var idx = -1
+ override def next(): Boolean = { idx += 1; idx < rows.size }
+ override def get(): InternalRow = {
+ val rawRow = rows(idx)
+ if (metaNames.isEmpty) rawRow
+ else {
+ val metaRow = new GenericInternalRow(metaNames.map {
+ case "index" => idx.asInstanceOf[Any]
+ case "_partition" => UTF8String.fromString("").asInstanceOf[Any]
+ case _ => null
+ })
+ new JoinedRow(rawRow, metaRow)
+ }
+ }
+ override def close(): Unit = {}
+ }
+ }
+}
+
object BufferedRows {
def apply(key: Seq[Any], schema: Array[Column]): BufferedRows = {
new BufferedRows(key, CatalogV2Util.v2ColumnsToStructType(schema))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 14cd06038b5a..4c7a8437a46f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -33,6 +33,7 @@ import
org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite,
TableCapability}
import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream,
PartitionOffset, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering,
Write}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
@@ -92,7 +93,15 @@ class ContinuousExecution(
log"from DataSourceV2 named '${MDC(STREAMING_DATA_SOURCE_NAME,
sourceName)}' " +
log"${MDC(STREAMING_DATA_SOURCE_DESCRIPTION, dsStr)}")
// TODO: operator pushdown.
- val scan = table.newScanBuilder(options).build()
+ // Passes the full output schema (not a pruned subset) so that
connectors
+ // implementing SupportsMetadataColumns can include metadata columns
in readSchema().
+ val scanBuilder = table.newScanBuilder(options)
+ scanBuilder match {
+ case r: SupportsPushDownRequiredColumns =>
+ r.pruneColumns(output.toStructType)
+ case _ =>
+ }
+ val scan = scanBuilder.build()
val stream = scan.toContinuousStream(metadataPath)
val relation = StreamingDataSourceV2Relation(
table, output, catalog, identifier, options, metadataPath)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index 726586ac72e6..84f0373ca5d4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite,
TableCapability, TransactionalCatalogPlugin}
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset
=> OffsetV2, ReadLimit, SparkDataStream, SupportsAdmissionControl,
SupportsRealTimeMode, SupportsTriggerAvailableNow}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -250,7 +251,15 @@ class MicroBatchExecution(
log"from DataSourceV2 named
'${MDC(LogKeys.STREAMING_DATA_SOURCE_NAME, srcName)}' " +
log"${MDC(LogKeys.STREAMING_DATA_SOURCE_DESCRIPTION, dsStr)}")
// TODO: operator pushdown.
- val scan = table.newScanBuilder(options).build()
+ // Passes the full output schema (not a pruned subset) so that
connectors
+ // implementing SupportsMetadataColumns can include metadata
columns in readSchema().
+ val scanBuilder = table.newScanBuilder(options)
+ scanBuilder match {
+ case r: SupportsPushDownRequiredColumns =>
+ r.pruneColumns(output.toStructType)
+ case _ =>
+ }
+ val scan = scanBuilder.build()
val stream = scan.toMicroBatchStream(metadataPath)
val relation = StreamingDataSourceV2Relation(
table,
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index fe338175ec88..77e3818aafe8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -376,6 +376,32 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
}
}
+ test("SPARK-56132: streaming read of metadata columns from V2 source") {
+ withTable(tbl) {
+ prepareTable()
+ withTempDir { checkpointDir =>
+ // "index" is a metadata column (not in the table schema); "id" and
"data" are data columns.
+ val df = spark.readStream.table(tbl).select("id", "data", "index")
+ val q = df.writeStream
+ .format("memory")
+ .queryName("result_56132")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ q.processAllAvailable()
+ val result = spark.table("result_56132")
+ // Verify data columns arrive correctly and index (metadata) is
non-null.
+ checkAnswer(result.select("id", "data").orderBy("id"),
+ Seq(Row(1, "a"), Row(2, "b"), Row(3, "c")))
+ assert(result.select("index").collect().forall(!_.isNullAt(0)),
+ "index metadata column should be non-null in streaming output")
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+
test("SPARK-43123: Metadata column related field metadata should not be
leaked to catalogs") {
withTable(tbl, "testcat.target") {
prepareTable()
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index f10d1cdab0d5..3930beec084d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -30,7 +30,8 @@ import
org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.{FakeV2Provider,
FakeV2ProviderWithCustomSchema, InMemoryTableSessionCatalog}
import org.apache.spark.sql.connector.catalog.{Column, Identifier,
InMemoryTable, InMemoryTableCatalog, MetadataColumn, SupportsMetadataColumns,
SupportsRead, Table, TableCapability, TableInfo, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.expressions.{ClusterByTransform,
FieldReference, Transform}
-import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
MemoryStreamScanBuilder, StreamingQueryWrapper}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.SQLConf
@@ -564,6 +565,42 @@ class DataStreamTableAPISuite extends StreamTest with
BeforeAndAfter {
}
}
+ test("SPARK-56132: pruneColumns called on SupportsPushDownRequiredColumns " +
+ "V2 streaming scan builder") {
+ val tblName = "teststream.table_name"
+ withTable(tblName) {
+ spark.sql(s"CREATE TABLE $tblName (data int) USING foo")
+ val stream = MemoryStream[Int]
+ val testCatalog =
spark.sessionState.catalogManager.catalog("teststream").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ .asInstanceOf[InMemoryStreamTable]
+ table.setStream(stream)
+
+ // Wrap the table's scan builder so we can record pruneColumns calls.
+ val recorded = new PrunedSchemaRecorder
+ table.scanBuilderWrapper = Some(inner => new
RecordingPruneScanBuilder(inner, recorded))
+
+ withTempDir { checkpointDir =>
+ val q = spark.readStream.table(tblName)
+ .select("value", "_seq")
+ .writeStream.format("noop")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ // logicalPlan is initialized lazily when the query thread starts;
wait for it.
+ eventually(timeout(streamingTimeout)) {
+ assert(recorded.called,
+ "pruneColumns should have been called on the streaming scan
builder")
+ }
+ assert(recorded.schema.fieldNames.toSet === Set("value", "_seq"),
+ s"Expected pruneColumns to receive {value, _seq}, got
${recorded.schema}")
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+
private def checkForStreamTable(dir: Option[File], tableName: String): Unit
= {
val memory = MemoryStream[Int]
val dsw = memory.toDS().writeStream.format("parquet")
@@ -683,6 +720,7 @@ class InMemoryStreamTable(override val name: String)
with SupportsRead
with SupportsMetadataColumns {
var stream: MemoryStream[Int] = _
+ var scanBuilderWrapper: Option[MemoryStreamScanBuilder => ScanBuilder] = None
def setStream(inputData: MemoryStream[Int]): Unit = stream = inputData
@@ -693,7 +731,8 @@ class InMemoryStreamTable(override val name: String)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder
= {
- new MemoryStreamScanBuilder(stream)
+ val inner = new MemoryStreamScanBuilder(stream)
+ scanBuilderWrapper.map(_(inner)).getOrElse(inner)
}
private object SeqColumn extends MetadataColumn {
@@ -705,6 +744,36 @@ class InMemoryStreamTable(override val name: String)
override val metadataColumns: Array[MetadataColumn] = Array(SeqColumn)
}
+class PrunedSchemaRecorder {
+ @volatile var called = false
+ @volatile var schema: StructType = new StructType()
+}
+
+class RecordingPruneScanBuilder(inner: MemoryStreamScanBuilder, recorder:
PrunedSchemaRecorder)
+ extends ScanBuilder
+ with SupportsPushDownRequiredColumns {
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ recorder.called = true
+ recorder.schema = requiredSchema
+ }
+
+ override def build(): Scan = {
+ val innerScan = inner.build()
+ val prunedSchema = recorder.schema
+ // Return a scan whose readSchema() reflects the pruned schema so the
streaming plan
+ // and scan agree on output columns. Without the fix, pruneColumns is
never called and
+ // readSchema() defaults to the full table schema, causing
ArrayIndexOutOfBoundsException
+ // when metadata columns are in the plan output but absent from the scan
output.
+ new Scan {
+ override def readSchema(): StructType =
+ if (recorder.called) prunedSchema else innerScan.readSchema()
+ override def toMicroBatchStream(checkpointLocation: String):
MicroBatchStream =
+ innerScan.toMicroBatchStream(checkpointLocation)
+ }
+ }
+}
+
class NonStreamV2Table(override val name: String)
extends Table with SupportsRead with V2TableWithV1Fallback {
override def schema(): StructType = StructType(Nil)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]