This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 78592a070c27 [SPARK-50615][SQL] Push variant into scan
78592a070c27 is described below
commit 78592a070c279b2aa181c5a7f84adff8a2fc0e74
Author: Chenhao Li <[email protected]>
AuthorDate: Fri Dec 20 13:03:33 2024 +0800
[SPARK-50615][SQL] Push variant into scan
### What changes were proposed in this pull request?
It adds an optimizer rule that pushes variants into the scan by rewriting the
variant type with a struct type producing all requested fields and rewriting
the variant extraction expressions as struct accesses. This will be the
foundation of the variant shredding reader. The rule must be disabled at this
point because the scan is not yet able to recognize the special struct.
### Why are the changes needed?
It is necessary for the performance of reading from shredded variant. With
this rule (and the reader implemented), the scan only needs to fetch the
necessary shredded columns required by the plan.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49235 from chenhao-db/PushVariantIntoScan.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 9 +
.../spark/sql/execution/SparkOptimizer.scala | 5 +-
.../datasources/PushVariantIntoScan.scala | 340 +++++++++++++++++++++
.../datasources/PushVariantIntoScanSuite.scala | 178 +++++++++++
4 files changed, 530 insertions(+), 2 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 31282d43bbce..306058fb3681 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4635,6 +4635,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ // Internal flag, disabled by default: the scan side cannot yet recognize the struct that
+ // `PushVariantIntoScan` substitutes for a variant column, so the rule must stay off for now.
+ val PUSH_VARIANT_INTO_SCAN =
+ buildConf("spark.sql.variant.pushVariantIntoScan")
+ .internal()
+ .doc("When true, replace variant type in the scan schema with a struct
containing " +
+ "requested fields.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
.internal()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 6ceb363b41ae..a51870cfd7fd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions,
SchemaPruning, V1Writes}
+import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions,
PushVariantIntoScan, SchemaPruning, V1Writes}
import
org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning,
OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering,
V2ScanRelationPushDown, V2Writes}
import
org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters,
PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import
org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate,
ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -43,7 +43,8 @@ class SparkOptimizer(
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
- PruneFileSourcePartitions)
+ PruneFileSourcePartitions,
+ PushVariantIntoScan)
override def preCBORules: Seq[Rule[LogicalPlan]] =
Seq(OptimizeMetadataOnlyDeleteFromTable)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
new file mode 100644
index 000000000000..83d219c28983
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet,
VariantPathParser}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+// A metadata class of a struct field. Within a struct, either every field carries this metadata or
+// none does.
+// We define a "variant struct" as a special struct whose fields are annotated with this metadata.
+// It indicates that the struct should produce all requested fields of a variant type, and should be
+// treated specially by the scan.
+case class VariantMetadata(
+    // The `path` parameter of VariantGet. It has the same format as a JSON path, except that
+    // `[*]` is not supported.
+    path: String,
+    failOnError: Boolean,
+    timeZoneId: String) {
+  // Produce a metadata containing one key-value pair whose key is the special `METADATA_KEY`.
+  // Its value is a nested metadata holding `path`, `failOnError`, and `timeZoneId`.
+  def toMetadata: Metadata = {
+    val payload = new MetadataBuilder()
+      .putString(VariantMetadata.PATH_KEY, path)
+      .putBoolean(VariantMetadata.FAIL_ON_ERROR_KEY, failOnError)
+      .putString(VariantMetadata.TIME_ZONE_ID_KEY, timeZoneId)
+      .build()
+    new MetadataBuilder().putMetadata(VariantMetadata.METADATA_KEY, payload).build()
+  }
+
+  // Parse `path` into its segments. An unparsable path raises the invalid-path error, reported
+  // under `variant_get` or `try_variant_get` depending on `failOnError`.
+  def parsedPath(): Array[VariantPathParser.PathSegment] =
+    VariantPathParser.parse(path) match {
+      case Some(segments) => segments
+      case None =>
+        val functionName = if (failOnError) "variant_get" else "try_variant_get"
+        throw QueryExecutionErrors.invalidVariantGetPath(path, functionName)
+    }
+}
+
+object VariantMetadata {
+  val METADATA_KEY = "__VARIANT_METADATA_KEY"
+  val PATH_KEY = "path"
+  val FAIL_ON_ERROR_KEY = "failOnError"
+  val TIME_ZONE_ID_KEY = "timeZoneId"
+
+  // A struct is a "variant struct" when it has at least one field and every field carries
+  // `METADATA_KEY`.
+  def isVariantStruct(s: StructType): Boolean =
+    s.fields.nonEmpty && s.fields.forall { field => field.metadata.contains(METADATA_KEY) }
+
+  // Same check for an arbitrary data type; anything that is not a struct is never a variant struct.
+  def isVariantStruct(t: DataType): Boolean =
+    PartialFunction.cond(t) { case struct: StructType => isVariantStruct(struct) }
+
+  // Rebuild a `VariantMetadata` from the nested metadata written by `toMetadata`.
+  def fromMetadata(metadata: Metadata): VariantMetadata = {
+    val payload = metadata.getMetadata(METADATA_KEY)
+    val path = payload.getString(PATH_KEY)
+    val failOnError = payload.getBoolean(FAIL_ON_ERROR_KEY)
+    val timeZoneId = payload.getString(TIME_ZONE_ID_KEY)
+    VariantMetadata(path, failOnError, timeZoneId)
+  }
+}
+
+// Represents a requested field of a variant that the scan should produce.
+// Each `RequestedVariantField` corresponds to a variant path extraction in the plan.
+case class RequestedVariantField(path: VariantMetadata, targetType: DataType)
+
+object RequestedVariantField {
+  // The request for the whole variant value: root path `$`, read back as `VariantType`.
+  def fullVariant: RequestedVariantField = {
+    val rootPath = VariantMetadata(path = "$", failOnError = true, timeZoneId = "UTC")
+    new RequestedVariantField(rootPath, VariantType)
+  }
+
+  // The request made by a `VariantGet`: its evaluated path, error mode, time zone and result type.
+  // NOTE(review): `timeZoneId.get` assumes the analyzer already resolved the time zone.
+  def apply(v: VariantGet): RequestedVariantField = {
+    val metadata = VariantMetadata(
+      path = v.path.eval().toString,
+      failOnError = v.failOnError,
+      timeZoneId = v.timeZoneId.get)
+    new RequestedVariantField(metadata, v.dataType)
+  }
+
+  // The request made by casting a variant: the root path `$` read as the cast's target type.
+  // Only `TRY` casts tolerate errors.
+  def apply(c: Cast): RequestedVariantField = {
+    val failOnError = c.evalMode != EvalMode.TRY
+    new RequestedVariantField(VariantMetadata("$", failOnError, c.timeZoneId.get), c.dataType)
+  }
+}
+
+// Extract a nested struct access path. Returns (root attribute id, sequence of ordinals used to
+// access the field). For a plain attribute reference, the ordinal sequence is empty. Any other
+// expression (e.g. an array element access) does not match.
+object StructPath {
+  def unapply(expr: Expression): Option[(ExprId, Seq[Int])] = expr match {
+    case GetStructField(StructPath(root, path), ordinal, _) => Some((root, path :+ ordinal))
+    // Build the tuple explicitly instead of relying on deprecated argument auto-tupling.
+    case a: Attribute => Some((a.exprId, Nil))
+    case _ => None
+  }
+}
+
+// A collection of all eligible variants in a relation, i.e. variants at the root of the relation
+// output schema or nested only inside struct types.
+// The user should:
+// 1. Call `addVariantFields` to add all eligible variants in a relation.
+// 2. Call `collectRequestedFields` on all expressions depending on the relation. This process will
+//    add the requested fields of each variant and potentially remove non-eligible variants. See
+//    `collectRequestedFields` for details.
+// 3. Call `rewriteType` to produce a new output schema for the relation.
+// 4. Call `rewriteExpr` to rewrite the previously visited expressions by replacing variant
+//    extractions with struct accesses.
+class VariantInRelation {
+ // First level key: root attribute id.
+ // Second level key: struct access paths to the variant type.
+ // Third level key: requested fields of a variant type.
+ // Final value: the ordinal of a requested field in the final struct of
requested fields.
+ // Entries are removed by `collectRequestedFields` when a variant turns out to be non-eligible.
+ val mapping = new HashMap[ExprId, HashMap[Seq[Int],
HashMap[RequestedVariantField, Int]]]
+
+ // Extract the SQL-struct path where the leaf is a variant.
+ // Matches only variants still registered in `mapping`, yielding their (mutable) map of
+ // requested fields so callers can add to it or look ordinals up in it.
+ object StructPathToVariant {
+ def unapply(expr: Expression): Option[HashMap[RequestedVariantField, Int]]
= expr match {
+ case StructPath(attrId, path) =>
+ mapping.get(attrId).flatMap(_.get(path))
+ case _ => None
+ }
+ }
+
+ // Find eligible variants recursively. `attrId` is the root attribute id.
+ // `path` is the current struct access path. `dataType` is the child data
type after extracting
+ // `path` from the root attribute struct.
+ // `defaultValue` is the existence default at `path` (null if none). A variant with a non-null
+ // default is skipped, and so is any struct that is already a variant struct.
+ def addVariantFields(
+ attrId: ExprId,
+ dataType: DataType,
+ defaultValue: Any,
+ path: Seq[Int]): Unit = {
+ dataType match {
+ // TODO(SHREDDING): non-null default value is not yet supported.
+ case _: VariantType if defaultValue == null =>
+ mapping.getOrElseUpdate(attrId, new HashMap).put(path, new HashMap)
+ case s: StructType if !VariantMetadata.isVariantStruct(s) =>
+ val row = defaultValue.asInstanceOf[InternalRow]
+ for ((field, idx) <- s.fields.zipWithIndex) {
+ val fieldDefault = if (row == null || row.isNullAt(idx)) {
+ null
+ } else {
+ row.get(idx, field.dataType)
+ }
+ addVariantFields(attrId, field.dataType, fieldDefault, path :+ idx)
+ }
+ case _ =>
+ }
+ }
+
+ // Compute the rewritten type of `dataType`, found at `path` under root attribute `attrId`.
+ // A variant still registered in `mapping` becomes a struct of its requested fields ordered by
+ // ordinal (field names are the ordinals). A non-variant struct is rewritten field by field.
+ // Any other type, including an unregistered variant, is returned unchanged.
+ def rewriteType(attrId: ExprId, dataType: DataType, path: Seq[Int]):
DataType = {
+ dataType match {
+ case _: VariantType =>
+ mapping.get(attrId).flatMap(_.get(path)) match {
+ case Some(fields) =>
+ var requestedFields = fields.toArray.sortBy(_._2).map { case
(field, ordinal) =>
+ StructField(ordinal.toString, field.targetType, metadata =
field.path.toMetadata)
+ }
+ // Avoid producing an empty struct of requested fields. This is
intended to simplify the
+ // scan implementation, which may not be able to handle empty
struct type. This happens
+ // if the variant is not used, or only used in `IsNotNull/IsNull`
expressions. The value
+ // of the placeholder field doesn't matter, even if the scan
source accidentally
+ // contains such a field.
+ if (requestedFields.isEmpty) {
+ val placeholder = VariantMetadata("$.__placeholder_field__",
+ failOnError = false, timeZoneId = "UTC")
+ requestedFields = Array(StructField("0", BooleanType,
+ metadata = placeholder.toMetadata))
+ }
+ StructType(requestedFields)
+ case _ => dataType
+ }
+ case s: StructType if !VariantMetadata.isVariantStruct(s) =>
+ val newFields = s.fields.zipWithIndex.map { case (field, idx) =>
+ field.copy(dataType = rewriteType(attrId, field.dataType, path :+
idx))
+ }
+ StructType(newFields)
+ case _ => dataType
+ }
+ }
+
+ // Add a requested field to a variant column.
+ // A field that was already requested keeps its ordinal; a new field gets the next ordinal
+ // (the current size of `map`).
+ private def addField(
+ map: HashMap[RequestedVariantField, Int],
+ field: RequestedVariantField): Unit = {
+ val idx = map.size
+ map.getOrElseUpdate(field, idx)
+ }
+
+ // Update `mapping` with any access to a variant. Add the requested fields
of each variant and
+ // potentially remove non-eligible variants.
+ // If a struct containing a variant is directly used, this variant is not
eligible for push down.
+ // This is because we need to replace the variant type with a struct
producing all requested
+ // fields, which also changes the struct type containing it, and it is
difficult to reconstruct
+ // the original struct value. This is not a big loss, because we need the
full variant anyway.
+ def collectRequestedFields(expr: Expression): Unit = expr match {
+ case v@VariantGet(StructPathToVariant(fields), _, _, _, _) =>
+ addField(fields, RequestedVariantField(v))
+ case c@Cast(StructPathToVariant(fields), _, _, _) => addField(fields,
RequestedVariantField(c))
+ // `IsNotNull`/`IsNull` over a struct path request no field; `rewriteExpr` only swaps the
+ // attribute inside them.
+ case IsNotNull(StructPath(_, _)) | IsNull(StructPath(_, _)) =>
+ // Any other direct use of a path: a registered variant needs its full value, otherwise every
+ // variant nested under `path` becomes non-eligible.
+ case StructPath(attrId, path) =>
+ mapping.get(attrId) match {
+ case Some(variants) =>
+ variants.get(path) match {
+ case Some(fields) =>
+ addField(fields, RequestedVariantField.fullVariant)
+ case _ =>
+ // Remove non-eligible variants.
+ variants.filterInPlace { case (key, _) => !key.startsWith(path) }
+ }
+ case _ =>
+ }
+ case _ => expr.children.foreach(collectRequestedFields)
+ }
+
+ // Rewrite `expr`, which must already have been visited by `collectRequestedFields`.
+ // Variant extractions become `GetStructField` accesses into the rewritten struct, using the
+ // ordinals recorded in `mapping`. Attributes are swapped through `attributeMap`, which is
+ // keyed by the old `ExprId`; attributes missing from it are kept as-is.
+ def rewriteExpr(
+ expr: Expression,
+ attributeMap: Map[ExprId, AttributeReference]): Expression = {
+ def rewriteAttribute(expr: Expression): Expression = expr.transformDown {
+ case a: Attribute => attributeMap.getOrElse(a.exprId, a)
+ }
+
+ // Rewrite patterns should be consistent with visit patterns in
`collectRequestedFields`.
+ expr.transformDown {
+ case g@VariantGet(v@StructPathToVariant(fields), _, _, _, _) =>
+ // Rewrite the attribute in advance, rather than depending on the last
branch to rewrite it.
+ // We need to avoid the `v@StructPathToVariant(fields)` branch to
rewrite the child again.
+ GetStructField(rewriteAttribute(v), fields(RequestedVariantField(g)))
+ case c@Cast(v@StructPathToVariant(fields), _, _, _) =>
+ GetStructField(rewriteAttribute(v), fields(RequestedVariantField(c)))
+ case i@IsNotNull(StructPath(_, _)) => rewriteAttribute(i)
+ case i@IsNull(StructPath(_, _)) => rewriteAttribute(i)
+ case v@StructPathToVariant(fields) =>
+ GetStructField(rewriteAttribute(v),
fields(RequestedVariantField.fullVariant))
+ case a: Attribute => attributeMap.getOrElse(a.exprId, a)
+ }
+ }
+}
+
+// Push variant into scan by rewriting the variant type with a struct type producing all requested
+// fields and rewriting the variant extraction expressions by struct accesses.
+// For example, for an input plan:
+// - Project [v:a::int, v:b::string, v]
+//   - Filter [v:a::int = 1]
+//     - Relation [v: variant]
+// Rewrite it as:
+// - Project [v.0, v.1, v.2]
+//   - Filter [v.0 = 1]
+//     - Relation [v: struct<0: int, 1: string, 2: variant>]
+// The struct fields are annotated with `VariantMetadata` to indicate the extraction path.
+// Only data columns (`HadoopFsRelation.dataSchema`) are candidates for the rewrite. Partition
+// columns are left untouched and are never copied into the rewritten `dataSchema`.
+object PushVariantIntoScan extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // A correlated subquery will be rewritten into join later, and will go through this rule
+    // eventually.
+    case s: Subquery if s.correlated => plan
+    case _ if !SQLConf.get.getConf(SQLConf.PUSH_VARIANT_INTO_SCAN) => plan
+    case _ => plan.transformDown {
+      case p@PhysicalOperation(projectList, filters,
+          relation @ LogicalRelationWithTable(
+            hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
+        rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
+    }
+  }
+
+  // Rewrite `Project(projectList, Filter(filters, relation))` so that every eligible variant data
+  // column is read as a struct of its requested fields. Returns `originalPlan` unchanged when no
+  // variant column is eligible.
+  private def rewritePlan(
+      originalPlan: LogicalPlan,
+      projectList: Seq[NamedExpression],
+      filters: Seq[Expression],
+      relation: LogicalRelation,
+      hadoopFsRelation: HadoopFsRelation): LogicalPlan = {
+    val variants = new VariantInRelation
+    val dataSchema = hadoopFsRelation.dataSchema
+    // Match each data column to its output attribute by name, as `FileSourceStrategy` does.
+    // Zipping `hadoopFsRelation.schema` by position is avoided because that schema also contains
+    // the partition columns, which must not end up in the rewritten `dataSchema`. The field's own
+    // data type drives the rewrite.
+    val dataAttributes = relation.resolve(
+      dataSchema, hadoopFsRelation.sparkSession.sessionState.analyzer.resolver)
+    val dataFieldsWithAttributes = dataSchema.fields.zip(dataAttributes)
+    val defaultValues = ResolveDefaultColumns.existenceDefaultValues(dataSchema)
+    for (((f, attr), defaultValue) <- dataFieldsWithAttributes.zip(defaultValues)) {
+      variants.addVariantFields(attr.exprId, f.dataType, defaultValue, Nil)
+    }
+    if (variants.mapping.isEmpty) return originalPlan
+
+    projectList.foreach(variants.collectRequestedFields)
+    filters.foreach(variants.collectRequestedFields)
+    // `collectRequestedFields` may have removed all variant columns.
+    if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+    // Old attribute id -> replacement attribute, only for data columns whose type is rewritten.
+    // The qualifier is carried over so qualified references keep resolving against the new attr.
+    val attributeMap: Map[ExprId, AttributeReference] = dataFieldsWithAttributes.collect {
+      case (f, attr) if variants.mapping.get(attr.exprId).exists(_.nonEmpty) =>
+        val newType = variants.rewriteType(attr.exprId, f.dataType, Nil)
+        attr.exprId ->
+          AttributeReference(f.name, newType, f.nullable, f.metadata)(qualifier = attr.qualifier)
+    }.toMap
+
+    // Only the data schema changes; the relation's partition schema is kept as is.
+    val newDataSchema = StructType(dataFieldsWithAttributes.map { case (f, attr) =>
+      attributeMap.get(attr.exprId).fold(f)(newAttr => f.copy(dataType = newAttr.dataType))
+    })
+    val newOutput = relation.output.map(a => attributeMap.getOrElse(a.exprId, a))
+    val newHadoopFsRelation = hadoopFsRelation.copy(dataSchema = newDataSchema)(
+      hadoopFsRelation.sparkSession)
+    val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput)
+
+    val withFilter = if (filters.nonEmpty) {
+      Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
+    } else {
+      newRelation
+    }
+    val newProjectList = projectList.map { e =>
+      variants.rewriteExpr(e, attributeMap) match {
+        case n: NamedExpression => n
+        // This is when the variant column is directly selected. We replace the attribute reference
+        // with a struct access, which is not a `NamedExpression` that `Project` requires. We wrap
+        // it with an `Alias`.
+        case rewritten => Alias(rewritten, e.name)(e.exprId, e.qualifier)
+      }
+    }
+    Project(newProjectList, withFilter)
+  }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
new file mode 100644
index 000000000000..2a866dcd66f0
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.variant._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class PushVariantIntoScanSuite extends SharedSparkSession {
+ // Enable the rule for every test in this suite; the conf defaults to false.
+ override def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
+
+ private def localTimeZone = spark.sessionState.conf.sessionLocalTimeZone
+
+ // Return a `StructField` with the expected `VariantMetadata`.
+ private def field(ordinal: Int, dataType: DataType, path: String,
+ failOnError: Boolean = true, timeZone: String =
localTimeZone): StructField =
+ StructField(ordinal.toString, dataType,
+ metadata = VariantMetadata(path, failOnError, timeZone).toMetadata)
+
+ // Validate an `Alias` expression has the expected name and child.
+ private def checkAlias(expr: Expression, expectedName: String, expected:
Expression): Unit = {
+ expr match {
+ case Alias(child, name) =>
+ assert(name == expectedName)
+ assert(child == expected)
+ case _ => fail()
+ }
+ }
+
+ // Register one test per file format (currently only Parquet); each body runs inside
+ // `withTable("T")`.
+ private def testOnFormats(fn: String => Unit): Unit = {
+ for (format <- Seq("PARQUET")) {
+ test("test - " + format) {
+ withTable("T") {
+ fn(format)
+ }
+ }
+ }
+ }
+
+ testOnFormats { format =>
+ sql("create table T (v variant, vs struct<v1 variant, v2 variant, i int>,
" +
+ "va array<variant>, vd variant default parse_json('1')) " +
+ s"using $format")
+
+ // Each distinct request gets its own ordinal: variant_get -> 0, the directly selected full
+ // variant -> 1 (UTC time zone), the cast -> 2.
+ sql("select variant_get(v, '$.a', 'int') as a, v, cast(v as struct<b
float>) as v from T")
+ .queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val v = output(0)
+ checkAlias(projectList(0), "a", GetStructField(v, 0))
+ checkAlias(projectList(1), "v", GetStructField(v, 1))
+ checkAlias(projectList(2), "v", GetStructField(v, 2))
+ assert(v.dataType == StructType(Array(
+ field(0, IntegerType, "$.a"),
+ field(1, VariantType, "$", timeZone = "UTC"),
+ field(2, StructType(Array(StructField("b", FloatType))), "$"))))
+ case _ => fail()
+ }
+
+ // A variant used only in `isnotnull` requests no field, so the struct holds one placeholder.
+ sql("select 1 from T where isnotnull(v)")
+ .queryExecution.optimizedPlan match {
+ case Project(projectList, Filter(condition, l: LogicalRelation)) =>
+ val output = l.output
+ val v = output(0)
+ checkAlias(projectList(0), "1", Literal(1))
+ assert(condition == IsNotNull(v))
+ assert(v.dataType == StructType(Array(
+ field(0, BooleanType, "$.__placeholder_field__", failOnError =
false, timeZone = "UTC"))))
+ case _ => fail()
+ }
+
+ // The same extraction in the projection and the filter shares ordinal 0; try_variant_get is
+ // recorded with failOnError = false. The expected condition also contains `IsNotNull(v)`,
+ // presumably inferred by the optimizer.
+ sql("select variant_get(v, '$.a', 'int') + 1 as a, try_variant_get(v,
'$.b', 'string') as b " +
+ "from T where variant_get(v, '$.a', 'int') =
1").queryExecution.optimizedPlan match {
+ case Project(projectList, Filter(condition, l: LogicalRelation)) =>
+ val output = l.output
+ val v = output(0)
+ checkAlias(projectList(0), "a", Add(GetStructField(v, 0), Literal(1)))
+ checkAlias(projectList(1), "b", GetStructField(v, 1))
+ assert(condition == And(IsNotNull(v), EqualTo(GetStructField(v, 0),
Literal(1))))
+ assert(v.dataType == StructType(Array(
+ field(0, IntegerType, "$.a"),
+ field(1, StringType, "$.b", failOnError = false))))
+ case _ => fail()
+ }
+
+ // Variants nested in a struct are rewritten independently; sibling non-variant fields stay.
+ sql("select variant_get(vs.v1, '$.a', 'int') as a, variant_get(vs.v1,
'$.b', 'int') as b, " +
+ "variant_get(vs.v2, '$.a', 'int') as a, vs.i from
T").queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val vs = output(1)
+ val v1 = GetStructField(vs, 0, Some("v1"))
+ val v2 = GetStructField(vs, 1, Some("v2"))
+ checkAlias(projectList(0), "a", GetStructField(v1, 0))
+ checkAlias(projectList(1), "b", GetStructField(v1, 1))
+ checkAlias(projectList(2), "a", GetStructField(v2, 0))
+ checkAlias(projectList(3), "i", GetStructField(vs, 2, Some("i")))
+ assert(vs.dataType == StructType(Array(
+ StructField("v1", StructType(Array(
+ field(0, IntegerType, "$.a"), field(1, IntegerType, "$.b")))),
+ StructField("v2", StructType(Array(field(0, IntegerType, "$.a")))),
+ StructField("i", IntegerType))))
+ case _ => fail()
+ }
+
+ // Expected un-rewritten `variant_get(child, '$.a')` for the no-push-down cases below.
+ def variantGet(child: Expression): Expression = VariantGet(
+ child,
+ path = Literal("$.a"),
+ targetType = VariantType,
+ failOnError = true,
+ timeZoneId = Some(localTimeZone))
+
+ // No push down if the struct containing variant is used.
+ sql("select vs, variant_get(vs.v1, '$.a') as a from
T").queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val vs = output(1)
+ assert(projectList(0) == vs)
+ checkAlias(projectList(1), "a", variantGet(GetStructField(vs, 0,
Some("v1"))))
+ assert(vs.dataType == StructType(Array(
+ StructField("v1", VariantType),
+ StructField("v2", VariantType),
+ StructField("i", IntegerType))))
+ case _ => fail()
+ }
+
+ // No push down for variant in array.
+ sql("select variant_get(va[0], '$.a') as a from
T").queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val va = output(2)
+ checkAlias(projectList(0), "a", variantGet(GetArrayItem(va,
Literal(0))))
+ assert(va.dataType == ArrayType(VariantType))
+ case _ => fail()
+ }
+
+ // No push down if variant has default value.
+ sql("select variant_get(vd, '$.a') as a from
T").queryExecution.optimizedPlan match {
+ case Project(projectList, l: LogicalRelation) =>
+ val output = l.output
+ val vd = output(3)
+ checkAlias(projectList(0), "a", variantGet(vd))
+ assert(vd.dataType == VariantType)
+ case _ => fail()
+ }
+ }
+
+ // The rule only matches Parquet-backed relations, so a JSON table keeps its variant column.
+ test("No push down for JSON") {
+ withTable("T") {
+ sql("create table T (v variant) using JSON")
+ sql("select variant_get(v, '$.a') from T").queryExecution.optimizedPlan
match {
+ case Project(_, l: LogicalRelation) =>
+ val output = l.output
+ assert(output(0).dataType == VariantType)
+ case _ => fail()
+ }
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]