This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a35c9f36da63 [SPARK-53805][SQL] Push Variant into DSv2 scan
a35c9f36da63 is described below
commit a35c9f36da63202d68dde70f5ad3058ba7357715
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Oct 10 12:15:59 2025 -0700
[SPARK-53805][SQL] Push Variant into DSv2 scan
### What changes were proposed in this pull request?
Extend the `PushVariantIntoScan` rule to also handle `DataSourceV2Relation`, and move it before `V2ScanRelationPushDown` in the optimizer so the rewritten variant struct schema is visible to DSv2 pushdown.
### Why are the changes needed?
With this change, a DSv2 scan only needs to fetch the shredded
columns required by the plan.
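For example (an illustrative sketch mirroring the new `VariantV2ReadSuite`; the `testcat` catalog, namespace, and table are assumptions, not part of this commit):

  // Assumes a SparkSession `spark` with testcat registered as an
  // InMemoryTableCatalog and SQLConf.PUSH_VARIANT_INTO_SCAN enabled.
  // Requesting a single path lets PushVariantIntoScan rewrite `v` to a
  // one-field struct (tagged with VariantMetadata for '$.username'), which
  // V2ScanRelationPushDown then prunes into the scan, so only that shredded
  // column is read instead of the full variant binary.
  spark.sql("CREATE TABLE testcat.ns.t (id bigint, v variant) USING parquet")
  val df = spark.sql(
    "SELECT id, variant_get(v, '$.username', 'string') AS username FROM testcat.ns.t")
  // The scan's readSchema() now reports `v` as struct, not variant.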
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New tests in `VariantV2ReadSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52522 from huaxingao/variant-v2-pushdown.
Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Huaxin Gao <[email protected]>
---
.../spark/sql/execution/SparkOptimizer.scala | 4 +-
.../datasources/PushVariantIntoScan.scala | 109 ++++++++++++---
.../datasources/v2/VariantV2ReadSuite.scala | 148 +++++++++++++++++++++
3 files changed, 242 insertions(+), 19 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 8edb59f49282..9699d8a2563f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -40,11 +40,11 @@ class SparkOptimizer(
SchemaPruning,
GroupBasedRowLevelOperationScanPlanning,
V1Writes,
+ PushVariantIntoScan,
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
- PruneFileSourcePartitions,
- PushVariantIntoScan)
+ PruneFileSourcePartitions)
override def preCBORules: Seq[Rule[LogicalPlan]] =
Seq(OptimizeMetadataOnlyDeleteFromTable)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 5960cf8c38ce..6ce53e3367c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -279,6 +280,8 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _),
_)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
+ case p@PhysicalOperation(projectList, filters, relation: DataSourceV2Relation) =>
+ rewriteV2RelationPlan(p, projectList, filters, relation)
}
}
@@ -288,23 +291,91 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
filters: Seq[Expression],
relation: LogicalRelation,
hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
- val variants = new VariantInRelation
-
val schemaAttributes = relation.resolve(hadoopFsRelation.dataSchema,
hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
- val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
- schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
- for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
- variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+
+ // Collect variant fields from the relation output
+ val variants = collectAndRewriteVariants(schemaAttributes)
+ if (variants.mapping.isEmpty) return originalPlan
+
+ // Collect requested fields from projections and filters
+ projectList.foreach(variants.collectRequestedFields)
+ filters.foreach(variants.collectRequestedFields)
+ // `collectRequestedFields` may have removed all variant columns.
+ if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+ // Build attribute map with rewritten types
+ val attributeMap = buildAttributeMap(schemaAttributes, variants)
+
+ // Build new schema with variant types replaced by struct types
+ val newFields = schemaAttributes.map { a =>
+ val dataType = attributeMap(a.exprId).dataType
+ StructField(a.name, dataType, a.nullable, a.metadata)
}
+ // Update relation output attributes with new types
+ val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+ // Update HadoopFsRelation's data schema so the file source reads the struct columns
+ val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
+ hadoopFsRelation.sparkSession)
+ val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+
+ // Build filter and project with rewritten expressions
+ buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+ }
+
+ private def rewriteV2RelationPlan(
+ originalPlan: LogicalPlan,
+ projectList: Seq[NamedExpression],
+ filters: Seq[Expression],
+ relation: DataSourceV2Relation): LogicalPlan = {
+
+ // Collect variant fields from the relation output
+ val variants = collectAndRewriteVariants(relation.output)
if (variants.mapping.isEmpty) return originalPlan
+ // Collect requested fields from projections and filters
projectList.foreach(variants.collectRequestedFields)
filters.foreach(variants.collectRequestedFields)
// `collectRequestedFields` may have removed all variant columns.
if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
- val attributeMap = schemaAttributes.map { a =>
+ // Build attribute map with rewritten types
+ val attributeMap = buildAttributeMap(relation.output, variants)
+
+ // Update relation output attributes with new types
+ // Note: DSv2 doesn't need to update the schema in the relation itself. The schema will be
+ // communicated to the data source later via V2ScanRelationPushDown.pruneColumns() API.
+ val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+ val newRelation = relation.copy(output = newOutput.toIndexedSeq)
+
+ // Build filter and project with rewritten expressions
+ buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
+ }
+
+ /**
+ * Collect variant fields and return initialized VariantInRelation.
+ */
+ private def collectAndRewriteVariants(
+ schemaAttributes: Seq[Attribute]): VariantInRelation = {
+ val variants = new VariantInRelation
+ val defaultValues = ResolveDefaultColumns.existenceDefaultValues(StructType(
+ schemaAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+ }
+
+ variants
+ }
+
+ /**
+ * Build attribute map with rewritten variant types.
+ */
+ private def buildAttributeMap(
+ schemaAttributes: Seq[Attribute],
+ variants: VariantInRelation): Map[ExprId, AttributeReference] = {
+ schemaAttributes.map { a =>
if (variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
@@ -316,21 +387,24 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
(a.exprId, a.asInstanceOf[AttributeReference])
}
}.toMap
- val newFields = schemaAttributes.map { a =>
- val dataType = attributeMap(a.exprId).dataType
- StructField(a.name, dataType, a.nullable, a.metadata)
- }
- val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+ }
- val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = StructType(newFields))(
- hadoopFsRelation.sparkSession)
- val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
+ /**
+ * Build the final Project(Filter(relation)) plan with rewritten expressions.
+ */
+ private def buildFilterAndProject(
+ relation: LogicalPlan,
+ projectList: Seq[NamedExpression],
+ filters: Seq[Expression],
+ variants: VariantInRelation,
+ attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
val withFilter = if (filters.nonEmpty) {
- Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
+ Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
} else {
- newRelation
+ relation
}
+
val newProjectList = projectList.map { e =>
val rewritten = variants.rewriteExpr(e, attributeMap)
rewritten match {
@@ -341,6 +415,7 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
}
}
+
Project(newProjectList, withFilter)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
new file mode 100644
index 000000000000..a6521dfe76da
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/VariantV2ReadSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.datasources.VariantMetadata
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType, VariantType}
+
+class VariantV2ReadSuite extends QueryTest with SharedSparkSession {
+
+ private val testCatalogClass = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"
+
+ private def withV2Catalog(f: => Unit): Unit = {
+ withSQLConf(
+ SQLConf.DEFAULT_CATALOG.key -> "testcat",
+ s"spark.sql.catalog.testcat" -> testCatalogClass,
+ SQLConf.USE_V1_SOURCE_LIST.key -> "",
+ SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "true",
+ SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true") {
+ f
+ }
+ }
+
+ test("DSV2: push variant_get fields") {
+ withV2Catalog {
+ sql("DROP TABLE IF EXISTS testcat.ns.users")
+ sql(
+ """CREATE TABLE testcat.ns.users (
+ | id bigint,
+ | name string,
+ | v variant,
+ | vd variant default parse_json('1')
+ |) USING parquet""".stripMargin)
+
+ val out = sql(
+ """
+ |SELECT
+ | id,
+ | variant_get(v, '$.username', 'string') as username,
+ | variant_get(v, '$.age', 'int') as age
+ |FROM testcat.ns.users
+ |WHERE variant_get(v, '$.status', 'string') = 'active'
+ |""".stripMargin)
+
+ checkAnswer(out, Seq.empty)
+
+ // Verify variant column rewrite
+ val optimized = out.queryExecution.optimizedPlan
+ val relOutput = optimized.collectFirst {
+ case s: DataSourceV2ScanRelation => s.output
+ }.getOrElse(fail("Expected DSv2 relation in optimized plan"))
+
+ val vAttr = relOutput.find(_.name == "v").getOrElse(fail("Missing 'v' column"))
+ vAttr.dataType match {
+ case s: StructType =>
+ assert(s.fields.length == 3,
+ s"Expected 3 fields (username, age, status), got
${s.fields.length}")
+ assert(s.fields.forall(_.metadata.contains(VariantMetadata.METADATA_KEY)),
+ "All fields should have VariantMetadata")
+
+ val paths = s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).toSet
+ assert(paths == Set("$.username", "$.age", "$.status"),
+ s"Expected username, age, status paths, got: $paths")
+
+ val fieldTypes = s.fields.map(_.dataType).toSet
+ assert(fieldTypes.contains(StringType), "Expected StringType for string fields")
+ assert(fieldTypes.contains(IntegerType), "Expected IntegerType for age")
+
+ case other =>
+ fail(s"Expected StructType for 'v', got: $other")
+ }
+
+ // Verify variant with default value is NOT rewritten
+ relOutput.find(_.name == "vd").foreach { vdAttr =>
+ assert(vdAttr.dataType == VariantType,
+ "Variant column with default value should not be rewritten")
+ }
+ }
+ }
+
+ test("DSV2: nested column pruning for variant struct") {
+ withV2Catalog {
+ sql("DROP TABLE IF EXISTS testcat.ns.users2")
+ sql(
+ """CREATE TABLE testcat.ns.users2 (
+ | id bigint,
+ | name string,
+ | v variant
+ |) USING parquet""".stripMargin)
+
+ val out = sql(
+ """
+ |SELECT id, variant_get(v, '$.username', 'string') as username
+ |FROM testcat.ns.users2
+ |""".stripMargin)
+
+ checkAnswer(out, Seq.empty)
+
+ val scan = out.queryExecution.executedPlan.collectFirst {
+ case b: BatchScanExec => b.scan
+ }.getOrElse(fail("Expected BatchScanExec in physical plan"))
+
+ val readSchema = scan.readSchema()
+
+ // Verify 'v' field exists and is a struct
+ val vField = readSchema.fields.find(_.name == "v").getOrElse(
+ fail("Expected 'v' field in read schema")
+ )
+
+ vField.dataType match {
+ case s: StructType =>
+ assert(s.fields.length == 1,
+ "Expected only 1 field ($.username) in pruned schema, got " +
s.fields.length + ": " +
+ s.fields.map(f => VariantMetadata.fromMetadata(f.metadata).path).mkString(", "))
+
+ val field = s.fields(0)
+ assert(field.metadata.contains(VariantMetadata.METADATA_KEY),
+ "Field should have VariantMetadata")
+
+ val metadata = VariantMetadata.fromMetadata(field.metadata)
+ assert(metadata.path == "$.username",
+ "Expected path '$.username', got '" + metadata.path + "'")
+ assert(field.dataType == StringType,
+ s"Expected StringType, got ${field.dataType}")
+
+ case other =>
+ fail(s"Expected StructType for 'v' after rewrite and pruning, got:
$other")
+ }
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]