cloud-fan commented on code in PR #49235: URL: https://github.com/apache/spark/pull/49235#discussion_r1891419666
########## sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala: ########## @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable.HashMap + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet, VariantPathParser} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +// A metadata class of a struct field. All struct fields in a struct must either all have this +// metadata, or all don't have it. +// We define a "variant struct" as: a special struct with its fields annotated with this metadata. 
+// It indicates that the struct should produce all requested fields of a variant type, and should be +// treated specially by the scan. +case class VariantMetadata( + // The `path` parameter of VariantGet. It has the same format as a JSON path, except that + // `[*]` is not supported. + path: String, + failOnError: Boolean, + timeZoneId: String) { + // Produce a metadata containing one key-value pair. The key is the special `METADATA_KEY`. + // The value contains three key-value pairs for `path`, `failOnError`, and `timeZoneId`. + def toMetadata: Metadata = + new MetadataBuilder().putMetadata( + VariantMetadata.METADATA_KEY, + new MetadataBuilder() + .putString(VariantMetadata.PATH_KEY, path) + .putBoolean(VariantMetadata.FAIL_ON_ERROR_KEY, failOnError) + .putString(VariantMetadata.TIME_ZONE_ID_KEY, timeZoneId) + .build() + ).build() + + // Parse `path` into path segments. If it cannot be parsed, throw + // `QueryExecutionErrors.invalidVariantGetPath`, naming `variant_get` or `try_variant_get` + // depending on `failOnError`. + def parsedPath(): Array[VariantPathParser.PathSegment] = { + VariantPathParser.parse(path).getOrElse { + val name = if (failOnError) "variant_get" else "try_variant_get" + throw QueryExecutionErrors.invalidVariantGetPath(path, name) + } + } +} + +// Companion holding the metadata key names written by `toMetadata` and read by `fromMetadata`, +// plus helpers to recognize variant structs. +object VariantMetadata { + val METADATA_KEY = "__VARIANT_METADATA_KEY" + val PATH_KEY = "path" + val FAIL_ON_ERROR_KEY = "failOnError" + val TIME_ZONE_ID_KEY = "timeZoneId" + + // True iff `s` has at least one field and every field's metadata contains `METADATA_KEY`. + def isVariantStruct(s: StructType): Boolean = + s.fields.length > 0 && s.fields.forall(_.metadata.contains(METADATA_KEY)) + + // Overload for arbitrary data types: any non-struct type is never a variant struct. + def isVariantStruct(t: DataType): Boolean = t match { + case s: StructType => isVariantStruct(s) + case _ => false + } + + // Parse the `VariantMetadata` from a metadata produced by `toMetadata`. + def fromMetadata(metadata: Metadata): VariantMetadata = { + val value = metadata.getMetadata(METADATA_KEY) + VariantMetadata( + value.getString(PATH_KEY), + value.getBoolean(FAIL_ON_ERROR_KEY), + value.getString(TIME_ZONE_ID_KEY) + ) + } +} + +// Represent a requested field of a variant that the scan should produce.
+// Each `RequestedVariantField` is corresponded to a variant path extraction in the plan. +case class RequestedVariantField(path: VariantMetadata, targetType: DataType) + +object RequestedVariantField { + def fullVariant: RequestedVariantField = + RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"), VariantType) + + def apply(v: VariantGet): RequestedVariantField = + RequestedVariantField( + VariantMetadata(v.path.eval().toString, v.failOnError, v.timeZoneId.get), v.dataType) + + def apply(c: Cast): RequestedVariantField = + RequestedVariantField( + VariantMetadata("$", c.evalMode != EvalMode.TRY, c.timeZoneId.get), c.dataType) +} + +// Extract a nested struct access path. Return the (root attribute id, a sequence of ordinals to +// access the field). For non-nested attribute access, the sequence is empty. +object StructPath { + def unapply(expr: Expression): Option[(ExprId, Seq[Int])] = expr match { + case GetStructField(StructPath(root, path), ordinal, _) => Some((root, path :+ ordinal)) + case a: Attribute => Some(a.exprId, Nil) + case _ => None + } +} + +// A collection of all eligible variants in a relation, which are in the root of the relation output +// schema, or only nested in struct types. +// The user should: +// 1. Call `addVariantFields` to add all eligible variants in a relation. +// 2. Call `collectRequestedFields` on all expressions depending on the relation. This process will +// add the requested fields of each variant and potentially remove non-eligible variants. See +// `collectRequestedFields` for details. +// 3. Call `rewriteType` to produce a new output schema for the relation. +// 4. Call `rewriteExpr` to rewrite the previously visited expressions by replacing variant +// extractions with struct accessed. +class VariantInRelation { + // First level key: root attribute id. + // Second level key: struct access paths to the variant type. + // Third level key: requested fields of a variant type. 
+ // Final value: the ordinal of a requested field in the final struct of requested fields. + val mapping = new HashMap[ExprId, HashMap[Seq[Int], HashMap[RequestedVariantField, Int]]] + + // Extract the SQL-struct path where the leaf is a variant. + object StructPathToVariant { + def unapply(expr: Expression): Option[HashMap[RequestedVariantField, Int]] = expr match { + case StructPath(attrId, path) => + mapping.get(attrId).flatMap(_.get(path)) + case _ => None + } + } + + // Find eligible variants recursively. `attrId` is the root attribute id. + // `path` is the current struct access path. `dataType` is the child data type after extracting + // `path` from the root attribute struct. + def addVariantFields(attrId: ExprId, dataType: DataType, defaultValue: Any, + path: Seq[Int]): Unit = { + dataType match { + // TODO(SHREDDING): non-null default value is not yet supported. + case _: VariantType if defaultValue == null => + mapping.getOrElseUpdate(attrId, new HashMap).put(path, new HashMap) + case s: StructType if !VariantMetadata.isVariantStruct(s) => + val row = defaultValue.asInstanceOf[InternalRow] + for ((field, idx) <- s.fields.zipWithIndex) { + val fieldDefault = if (row == null || row.isNullAt(idx)) { + null + } else { + row.get(idx, field.dataType) + } + addVariantFields(attrId, field.dataType, fieldDefault, path :+ idx) + } + case _ => + } + } + + def rewriteType(attrId: ExprId, dataType: DataType, path: Seq[Int]): DataType = { + dataType match { + case _: VariantType => + mapping.get(attrId).flatMap(_.get(path)) match { + case Some(fields) => + var requestedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) => + StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata) + } + // Avoid producing an empty struct of requested fields. This is intended to simplify the + // scan implementation, which may not be able to handle empty struct type. 
This happens + // if the variant is not used, or only used in `IsNotNull/IsNull` expressions. The value + // of the placeholder field doesn't matter, even if the scan source accidentally + // contains such a field. + if (requestedFields.isEmpty) { + val placeholder = VariantMetadata("$.__placeholder_field__", + failOnError = false, timeZoneId = "UTC") + requestedFields = Array(StructField("0", BooleanType, + metadata = placeholder.toMetadata)) + } + StructType(requestedFields) + case _ => dataType + } + case s: StructType if !VariantMetadata.isVariantStruct(s) => + val newFields = s.fields.zipWithIndex.map { case (field, idx) => + field.copy(dataType = rewriteType(attrId, field.dataType, path :+ idx)) + } + StructType(newFields) + case _ => dataType + } + } + + // Add a requested field to a variant column. + private def addField(map: HashMap[RequestedVariantField, Int], + field: RequestedVariantField): Unit = { Review Comment: ```suggestion private def addField( map: HashMap[RequestedVariantField, Int], field: RequestedVariantField): Unit = { ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
