Repository: spark
Updated Branches:
  refs/heads/master 04614820e -> fd990a908


[SPARK-23873][SQL] Use accessors in interpreted LambdaVariable

## What changes were proposed in this pull request?

Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access elements. We should use specialized accessors where possible.
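
As a rough sketch of the idea (illustrative names only, not the patch itself): resolve a type-specific getter once from the `DataType` and reuse it per row, instead of dispatching through the generic `InternalRow.get` on every access:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._

// Resolve a specialized getter once per data type (sketch only; the real patch
// covers all types and returns a (SpecializedGetters, Int) => Any).
def simpleAccessor(dt: DataType): (InternalRow, Int) => Any = dt match {
  case IntegerType => (row, ordinal) => row.getInt(ordinal)
  case LongType    => (row, ordinal) => row.getLong(ordinal)
  case _           => (row, ordinal) => row.get(ordinal, dt)  // generic fallback
}

// Reuse the resolved accessor for every row.
val getIntAt = simpleAccessor(IntegerType)
val row = new GenericInternalRow(Array[Any](42))
assert(getIntAt(row, 0) == 42)
```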

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #20981 from viirya/SPARK-23873.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd990a90
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd990a90
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd990a90

Branch: refs/heads/master
Commit: fd990a908b94d1c90c4ca604604f35a13b453d44
Parents: 0461482
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Apr 16 22:45:57 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Mon Apr 16 22:45:57 2018 +0200

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/InternalRow.scala | 26 +++++++++++++-
 .../catalyst/expressions/BoundAttribute.scala   | 22 +++---------
 .../catalyst/expressions/objects/objects.scala  |  8 ++++-
 .../expressions/ExpressionEvalHelper.scala      |  4 ++-
 .../expressions/ObjectExpressionsSuite.scala    | 38 ++++++++++++++++++--
 5 files changed, 75 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 2911064..274d75e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -119,4 +119,28 @@ object InternalRow {
     case v: MapData => v.copy()
     case _ => value
   }
+
+  /**
+   * Returns an accessor for an `InternalRow` with the given data type. The returned accessor
+   * actually takes a `SpecializedGetters` input because it can be generalized to other classes
+   * that implement `SpecializedGetters` (e.g., `ArrayData`) too.
+   */
+  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
+    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
+    case ByteType => (input, ordinal) => input.getByte(ordinal)
+    case ShortType => (input, ordinal) => input.getShort(ordinal)
+    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+    case FloatType => (input, ordinal) => input.getFloat(ordinal)
+    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
+    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
+    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
+    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
+    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
+    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
+    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
+    case _: MapType => (input, ordinal) => input.getMap(ordinal)
+    case u: UserDefinedType[_] => getAccessor(u.sqlType)
+    case _ => (input, ordinal) => input.get(ordinal, dataType)
+  }
 }
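
A minimal usage sketch of the `getAccessor` helper added above, assuming `GenericInternalRow` as a convenient concrete `SpecializedGetters`:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Resolve the accessor once for the column's type, then reuse it per row.
val strAccessor = InternalRow.getAccessor(StringType)
val row = new GenericInternalRow(Array[Any](UTF8String.fromString("spark")))
assert(strAccessor(row, 0) == UTF8String.fromString("spark"))
```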

http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 5021a56..4cc84b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 
   override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
 
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {
-    if (input.isNullAt(ordinal)) {
+    if (nullable && input.isNullAt(ordinal)) {
       null
     } else {
-      dataType match {
-        case BooleanType => input.getBoolean(ordinal)
-        case ByteType => input.getByte(ordinal)
-        case ShortType => input.getShort(ordinal)
-        case IntegerType | DateType => input.getInt(ordinal)
-        case LongType | TimestampType => input.getLong(ordinal)
-        case FloatType => input.getFloat(ordinal)
-        case DoubleType => input.getDouble(ordinal)
-        case StringType => input.getUTF8String(ordinal)
-        case BinaryType => input.getBinary(ordinal)
-        case CalendarIntervalType => input.getInterval(ordinal)
-        case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
-        case t: StructType => input.getStruct(ordinal, t.size)
-        case _: ArrayType => input.getArray(ordinal)
-        case _: MapType => input.getMap(ordinal)
-        case _ => input.get(ordinal, dataType)
-      }
+      accessor(input, ordinal)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 50e90ca..77802e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -560,11 +560,17 @@ case class LambdaVariable(
     dataType: DataType,
     nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
 
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+
  // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
   override def eval(input: InternalRow): Any = {
     assert(input.numFields == 1,
       "The input row of interpreted LambdaVariable should have only 1 field.")
-    input.get(0, dataType)
+    if (nullable && input.isNullAt(0)) {
+      null
+    } else {
+      accessor(input, 0)
+    }
   }
 
   override def genCode(ctx: CodegenContext): ExprCode = {

http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a5ecd1b..b4bf6d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
   * Check the equality between result of expression and expected value, it will handle
    * Array[Byte], Spread[Double], MapData and Row.
    */
-  protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
+  protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
+    val dataType = UserDefinedType.sqlType(exprDataType)
+
     (result, expected) match {
       case (result: Array[Byte], expected: Array[Byte]) =>
         java.util.Arrays.equals(result, expected)

http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index b1bc67d..b0188b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp}
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
+import scala.util.Random
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util._
@@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
     }
   }
+
+  test("LambdaVariable should support interpreted execution") {
+    def genSchema(dt: DataType): Seq[StructType] = {
+      Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
+        StructType(StructField("col_1", dt, nullable = true) :: Nil))
+    }
+
+    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
+      CalendarIntervalType, new ExamplePointUDT())
+    val arrayTypes = elementTypes.flatMap { elementType =>
+      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
+    }
+    val mapTypes = elementTypes.flatMap { elementType =>
+      Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true))
+    }
+    val structTypes = elementTypes.flatMap { elementType =>
+      Seq(StructType(StructField("col1", elementType, false) :: Nil),
+        StructType(StructField("col1", elementType, true) :: Nil))
+    }
+
+    val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes
+    val random = new Random(100)
+    testTypes.foreach { dt =>
+      genSchema(dt).map { schema =>
+        val row = RandomDataGenerator.randomRow(random, schema)
+        val rowConverter = RowEncoder(schema)
+        val internalRow = rowConverter.toRow(row)
+        val lambda = LambdaVariable("dummy", "dummyIsNull", schema(0).dataType, schema(0).nullable)
+        checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
+      }
+    }
+  }
 }
 
 class TestBean extends Serializable {

