Repository: spark
Updated Branches:
  refs/heads/branch-1.6 af86c38db -> 927070d6d


[SPARK-11926][SQL] unify GetStructField and GetInternalRowField

Author: Wenchen Fan <[email protected]>

Closes #9909 from cloud-fan/get-struct.

(cherry picked from commit 19530da6903fa59b051eec69b9c17e231c68454b)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/927070d6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/927070d6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/927070d6

Branch: refs/heads/branch-1.6
Commit: 927070d6d75abcaac6d9676b1b9b556f21ec8536
Parents: af86c38
Author: Wenchen Fan <[email protected]>
Authored: Tue Nov 24 11:09:01 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Tue Nov 24 11:10:58 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  2 +-
 .../sql/catalyst/analysis/unresolved.scala      |  8 ++++----
 .../catalyst/encoders/ExpressionEncoder.scala   |  2 +-
 .../sql/catalyst/encoders/RowEncoder.scala      |  2 +-
 .../sql/catalyst/expressions/Expression.scala   |  2 +-
 .../expressions/complexTypeExtractors.scala     | 18 ++++++++---------
 .../catalyst/expressions/namedExpressions.scala |  4 ++--
 .../sql/catalyst/expressions/objects.scala      | 21 --------------------
 .../catalyst/expressions/ComplexTypeSuite.scala |  4 ++--
 9 files changed, 21 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 476bece..d133ad3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {
 
     /** Returns the current path with a field at ordinal extracted. */
     def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetInternalRowField(p, ordinal, dataType))
+      .map(p => GetStructField(p, ordinal))
       .getOrElse(BoundReference(ordinal, dataType, false))
 
     /** Returns the current path or `BoundReference`. */

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 6485bdf..1b2a8dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
     if (attribute.isDefined) {
       // This target resolved to an attribute in child. It must be a struct. Expand it.
       attribute.get.dataType match {
-        case s: StructType => {
-          s.fields.map( f => {
-            val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
+        case s: StructType => s.zipWithIndex.map {
+          case (f, i) =>
+            val extract = GetStructField(attribute.get, i)
             Alias(extract, target.get + "." + f.name)()
-          })
         }
+
         case _ => {
           throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
             target.get + "`")

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7bc9aed..0c10a56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -111,7 +111,7 @@ object ExpressionEncoder {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
-          case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
+          case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index fa553e7..67518f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -220,7 +220,7 @@ object RowEncoder {
         If(
           Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, externalDataTypeFor(f.dataType)),
-          constructorFor(GetInternalRowField(input, i, f.dataType)))
+          constructorFor(GetStructField(input, i)))
       }
       CreateExternalRow(convertedFields)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 540ed35..169435a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def prettyString: String = {
     transform {
-      case a: AttributeReference => PrettyAttribute(a.name)
+      case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index f871b73..10ce10a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -51,7 +51,7 @@ object ExtractValue {
       case (StructType(fields), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
-        GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+        GetStructField(child, ordinal, Some(fieldName))
 
       case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
@@ -97,18 +97,18 @@ object ExtractValue {
  * Returns the value of fields in the Struct `child`.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
- * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
+ *
+ * Note that we can pass in the field name directly to keep case preserving in `toString`.
+ * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
  */
-case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
   extends UnaryExpression {
 
-  override def dataType: DataType = child.dataType match {
-    case s: StructType => s(ordinal).dataType
-    // This is a hack to avoid breaking existing code until we remove the need for the struct field
-    case _ => field.dataType
-  }
+  private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+
+  override def dataType: DataType = field.dataType
   override def nullable: Boolean = child.nullable || field.nullable
-  override def toString: String = s"$child.${field.name}"
+  override def toString: String = s"$child.${name.getOrElse(field.name)}"
 
   protected override def nullSafeEval(input: Any): Any =
     input.asInstanceOf[InternalRow].get(ordinal, field.dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 00b7970..26b6aca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -273,7 +273,8 @@ case class AttributeReference(
  * A place holder used when printing expressions without debugging information such as the
  * expression id or the unresolved indicator.
  */
-case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
+case class PrettyAttribute(name: String, dataType: DataType = NullType)
+  extends Attribute with Unevaluable {
 
   override def toString: String = name
 
@@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
   override def qualifiers: Seq[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
   override def nullable: Boolean = throw new UnsupportedOperationException
-  override def dataType: DataType = NullType
 }
 
 object VirtualColumn {

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4a1f419..62d09f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
   }
 }
 
-case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
-  extends UnaryExpression {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, eval => {
-      s"""
-        if ($eval.isNullAt($ordinal)) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
-        }
-      """
-    })
-  }
-}
-
 /**
  * Serializes an input object using a generic serializer (Kryo or Java).
  * @param kryo if true, use Kryo. Otherwise, use Java.

http://git-wip-us.apache.org/repos/asf/spark/blob/927070d6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index e60990a..62fd472 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     def getStructField(expr: Expression, fieldName: String): GetStructField = {
       expr.dataType match {
         case StructType(fields) =>
-          val field = fields.find(_.name == fieldName).get
-          GetStructField(expr, field, fields.indexOf(field))
+          val index = fields.indexWhere(_.name == fieldName)
+          GetStructField(expr, index)
       }
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to