cloud-fan commented on code in PR #49235:
URL: https://github.com/apache/spark/pull/49235#discussion_r1891418567


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala:
##########
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet, 
VariantPathParser}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+// A metadata class of a struct field. Either all fields of a struct must have this metadata,
+// or none of them may have it.
+// We define a "variant struct" as a special struct whose fields are annotated with this metadata.
+// It indicates that the struct should produce all requested fields of a variant type, and should
+// be treated specially by the scan.
+case class VariantMetadata(
+    // The `path` parameter of VariantGet. It has the same format as a JSON path, except that
+    // `[*]` is not supported.
+    path: String,
+    failOnError: Boolean,
+    timeZoneId: String) {
+  // Produce a metadata containing one key-value pair. The key is the special `METADATA_KEY`.
+  // The value contains three key-value pairs for `path`, `failOnError`, and `timeZoneId`.
+  def toMetadata: Metadata =
+    new MetadataBuilder().putMetadata(
+      VariantMetadata.METADATA_KEY,
+      new MetadataBuilder()
+        .putString(VariantMetadata.PATH_KEY, path)
+        .putBoolean(VariantMetadata.FAIL_ON_ERROR_KEY, failOnError)
+        .putString(VariantMetadata.TIME_ZONE_ID_KEY, timeZoneId)
+        .build()
+    ).build()
+
+  // Parse `path` into its segments. If `VariantPathParser.parse` rejects it, throw the
+  // `invalidVariantGetPath` error, naming `variant_get` or `try_variant_get` depending on
+  // `failOnError`.
+  def parsedPath(): Array[VariantPathParser.PathSegment] = {
+    VariantPathParser.parse(path).getOrElse {
+      val name = if (failOnError) "variant_get" else "try_variant_get"
+      throw QueryExecutionErrors.invalidVariantGetPath(path, name)
+    }
+  }
+}
+
+object VariantMetadata {
+  // Keys of the nested metadata written by `VariantMetadata.toMetadata`.
+  val METADATA_KEY: String = "__VARIANT_METADATA_KEY"
+  val PATH_KEY: String = "path"
+  val FAIL_ON_ERROR_KEY: String = "failOnError"
+  val TIME_ZONE_ID_KEY: String = "timeZoneId"
+
+  // True when the struct has at least one field and every field carries `METADATA_KEY`.
+  def isVariantStruct(s: StructType): Boolean =
+    s.fields.nonEmpty && s.fields.forall(_.metadata.contains(METADATA_KEY))
+
+  // Overload for an arbitrary data type: anything other than a `StructType` is never a
+  // variant struct.
+  def isVariantStruct(t: DataType): Boolean = t match {
+    case struct: StructType => isVariantStruct(struct)
+    case _ => false
+  }
+
+  // Inverse of `VariantMetadata.toMetadata`: read the nested metadata stored under
+  // `METADATA_KEY` and rebuild the case class from its three entries.
+  def fromMetadata(metadata: Metadata): VariantMetadata = {
+    val inner = metadata.getMetadata(METADATA_KEY)
+    VariantMetadata(
+      path = inner.getString(PATH_KEY),
+      failOnError = inner.getBoolean(FAIL_ON_ERROR_KEY),
+      timeZoneId = inner.getString(TIME_ZONE_ID_KEY))
+  }
+}
+
+// Represents a requested field of a variant that the scan should produce.
+// Each `RequestedVariantField` corresponds to a variant path extraction in the plan.
+case class RequestedVariantField(path: VariantMetadata, targetType: DataType)
+
+object RequestedVariantField {
+  // A request for the whole variant value: root path `$` with target type `VariantType`.
+  def fullVariant: RequestedVariantField =
+    RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"), VariantType)
+
+  // Build a request from a `VariantGet`. `v.path` is evaluated without an input row
+  // (assumes it is foldable). `.get` requires the expression's time zone to be resolved already.
+  def apply(v: VariantGet): RequestedVariantField =
+    RequestedVariantField(
+      VariantMetadata(v.path.eval().toString, v.failOnError, v.timeZoneId.get), v.dataType)
+
+  // Build a root-path (`$`) request from a `Cast`. A TRY cast maps to `failOnError = false`;
+  // as above, `.get` requires the time zone to be resolved already.
+  def apply(c: Cast): RequestedVariantField =
+    RequestedVariantField(
+      VariantMetadata("$", c.evalMode != EvalMode.TRY, c.timeZoneId.get), c.dataType)
+}
+
+// Extract a nested struct access path. Return the (root attribute id, a sequence of ordinals to
+// access the field). For non-nested attribute access, the sequence is empty.
+// `GetStructField` layers are peeled recursively, so the ordinals are listed from the root
+// attribute outwards.
+object StructPath {
+  def unapply(expr: Expression): Option[(ExprId, Seq[Int])] = expr match {
+    case GetStructField(StructPath(root, path), ordinal, _) => Some((root, path :+ ordinal))
+    case a: Attribute => Some((a.exprId, Nil))
+    case _ => None
+  }
+}
+
+// A collection of all eligible variants in a relation: variants that are at the root of the
+// relation's output schema, or nested only inside struct types.
+// Callers should:
+// 1. Call `addVariantFields` to add all eligible variants in a relation.
+// 2. Call `collectRequestedFields` on all expressions depending on the relation. This process
+//    adds the requested fields of each variant and may remove non-eligible variants. See
+//    `collectRequestedFields` for details.
+// 3. Call `rewriteType` to produce a new output schema for the relation.
+// 4. Call `rewriteExpr` to rewrite the previously visited expressions by replacing variant
+//    extractions with struct accesses.
+class VariantInRelation {
+  // First level key: root attribute id.
+  // Second level key: struct access paths to the variant type.
+  // Third level key: requested fields of a variant type.
+  // Final value: the ordinal of a requested field in the final struct of 
requested fields.
+  val mapping = new HashMap[ExprId, HashMap[Seq[Int], 
HashMap[RequestedVariantField, Int]]]
+
+  // Extract the SQL-struct path where the leaf is a variant.
+  object StructPathToVariant {
+    def unapply(expr: Expression): Option[HashMap[RequestedVariantField, Int]] 
= expr match {
+      case StructPath(attrId, path) =>
+        mapping.get(attrId).flatMap(_.get(path))
+      case _ => None
+    }
+  }
+
+  // Find eligible variants recursively. `attrId` is the root attribute id.
+  // `path` is the current struct access path. `dataType` is the child data 
type after extracting
+  // `path` from the root attribute struct.
+  def addVariantFields(attrId: ExprId, dataType: DataType, defaultValue: Any,
+                       path: Seq[Int]): Unit = {

Review Comment:
   ```suggestion
     def addVariantFields(
         attrId: ExprId,
         dataType: DataType,
         defaultValue: Any,
         path: Seq[Int]): Unit = {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to