hvanhovell commented on a change in pull request #23397: [SPARK-26495][SQL] 
Simplify the SelectedField extractor.
URL: https://github.com/apache/spark/pull/23397#discussion_r244544141
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SelectedField.scala
 ##########
 @@ -64,71 +63,51 @@ private[execution] object SelectedField {
     selectField(unaliased, None)
   }
 
-  private def selectField(expr: Expression, fieldOpt: Option[StructField]): 
Option[StructField] = {
+  /**
+   * Convert an expression into the parts of the schema (the field) it 
accesses.
+   */
+  private def selectField(expr: Expression, dataTypeOpt: Option[DataType]): 
Option[StructField] = {
     expr match {
-      // No children. Returns a StructField with the attribute name or None if 
fieldOpt is None.
-      case AttributeReference(name, dataType, nullable, metadata) =>
-        fieldOpt.map(field =>
-          StructField(name, wrapStructType(dataType, field), nullable, 
metadata))
-      // Handles case "expr0.field[n]", where "expr0" is of struct type and 
"expr0.field" is of
-      // array type.
-      case GetArrayItem(x @ GetStructFieldObject(child, field @ 
StructField(name,
-          dataType, nullable, metadata)), _) =>
-        val childField = fieldOpt.map(field => StructField(name,
-          wrapStructType(dataType, field), nullable, 
metadata)).getOrElse(field)
-        selectField(child, Some(childField))
-      // Handles case "expr0.field[n]", where "expr0.field" is of array type.
-      case GetArrayItem(child, _) =>
-        selectField(child, fieldOpt)
-      // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" 
are of array type.
-      case GetArrayStructFields(child: GetArrayStructFields,
-          field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
-        val childField = fieldOpt.map(field => StructField(name,
-            wrapStructType(dataType, field),
-            nullable, metadata)).orElse(Some(field))
-        selectField(child, childField)
-      // Handles case "expr0.field", where "expr0" is of array type.
-      case GetArrayStructFields(child,
-          field @ StructField(name, dataType, nullable, metadata), _, _, _) =>
-        val childField =
-          fieldOpt.map(field => StructField(name,
-            wrapStructType(dataType, field),
-            nullable, metadata)).orElse(Some(field))
-        selectField(child, childField)
-      // Handles case "expr0.field[key]", where "expr0" is of struct type and 
"expr0.field" is of
-      // map type.
-      case GetMapValue(x @ GetStructFieldObject(child, field @ 
StructField(name,
-          dataType,
-          nullable, metadata)), _) =>
-        val childField = fieldOpt.map(field => StructField(name,
-          wrapStructType(dataType, field),
-          nullable, metadata)).orElse(Some(field))
-        selectField(child, childField)
-      // Handles case "expr0.field[key]", where "expr0.field" is of map type.
+      case a: Attribute =>
+        dataTypeOpt.map { dt =>
+          StructField(a.name, dt, a.nullable)
+        }
+      case c: GetStructField =>
+        val field = c.childSchema(c.ordinal)
+        val newField = field.copy(dataType = 
dataTypeOpt.getOrElse(field.dataType))
+        selectField(c.child, Option(struct(newField)))
+      case GetArrayStructFields(child, field, _, _, containsNull) =>
+        val newFieldDataType = dataTypeOpt match {
+          case None =>
+            // GetArrayStructFields is the top level extractor. This means its 
result is
+            // not pruned and we need to use the element type of the array it is 
producing.
+            field.dataType
+          case Some(ArrayType(dataType, _)) =>
+            // GetArrayStructFields is part of a chain of extractors and its 
result is pruned
+            // by a parent expression. In this case we need to use the parent 
element type.
+            dataType
+          case Some(x) =>
+            // This should not happen.
+            throw new AnalysisException(s"DataType '$x' is not supported by 
GetArrayStructFields.")
+        }
+        val newField = StructField(field.name, newFieldDataType, 
field.nullable)
+        selectField(child, Option(ArrayType(struct(newField), containsNull)))
       case GetMapValue(child, _) =>
-        selectField(child, fieldOpt)
-      // Handles case "expr0.field", where expr0 is of struct type.
-      case GetStructFieldObject(child,
-        field @ StructField(name, dataType, nullable, metadata)) =>
-        val childField = fieldOpt.map(field => StructField(name,
-          wrapStructType(dataType, field),
-          nullable, metadata)).orElse(Some(field))
-        selectField(child, childField)
+        // GetMapValue does not select a field from a struct (i.e. prune the 
struct) so it can't be
+        // the top-level extractor. However it can be part of an extractor 
chain.
+        val MapType(keyType, _, valueContainsNull) = child.dataType
+        val opt = dataTypeOpt.map(dt => MapType(keyType, dt, 
valueContainsNull))
+        selectField(child, opt)
+      case GetArrayItem(child, _) =>
+        // GetArrayItem does not select a field from a struct (i.e. prune the 
struct) so it can't be
+        // the top-level extractor. However it can be part of an extractor 
chain.
+        val ArrayType(_, containsNull) = child.dataType
+        val opt = dataTypeOpt.map(dt => ArrayType(dt, containsNull))
+        selectField(child, opt)
       case _ =>
         None
     }
   }
 
-  // Constructs a composition of complex types with a StructType(Array(field)) 
at its core. Returns
-  // a StructType for a StructType, an ArrayType for an ArrayType and a 
MapType for a MapType.
-  private def wrapStructType(dataType: DataType, field: StructField): DataType 
= {
-    dataType match {
-      case _: StructType =>
-        StructType(Array(field))
-      case ArrayType(elementType, containsNull) =>
-        ArrayType(wrapStructType(elementType, field), containsNull)
-      case MapType(keyType, valueType, valueContainsNull) =>
-        MapType(keyType, wrapStructType(valueType, field), valueContainsNull)
-    }
-  }
+  private def struct(field: StructField): StructType = StructType(Array(field))
 
 Review comment:
   No, I just don't like repeating myself. I can move it into StructType if you 
have issues with it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to