Repository: spark Updated Branches: refs/heads/master 04614820e -> fd990a908
[SPARK-23873][SQL] Use accessors in interpreted LambdaVariable ## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #20981 from viirya/SPARK-23873. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd990a90 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd990a90 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd990a90 Branch: refs/heads/master Commit: fd990a908b94d1c90c4ca604604f35a13b453d44 Parents: 0461482 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Mon Apr 16 22:45:57 2018 +0200 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Mon Apr 16 22:45:57 2018 +0200 ---------------------------------------------------------------------- .../apache/spark/sql/catalyst/InternalRow.scala | 26 +++++++++++++- .../catalyst/expressions/BoundAttribute.scala | 22 +++--------- .../catalyst/expressions/objects/objects.scala | 8 ++++- .../expressions/ExpressionEvalHelper.scala | 4 ++- .../expressions/ObjectExpressionsSuite.scala | 38 ++++++++++++++++++-- 5 files changed, 75 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 2911064..274d75e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 
+19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} -import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -119,4 +119,28 @@ object InternalRow { case v: MapData => v.copy() case _ => value } + + /** + * Returns an accessor for an `InternalRow` field of the given data type. The returned + * accessor actually takes a `SpecializedGetters` input because it can be generalized to other + * classes that implement `SpecializedGetters` (e.g., `ArrayData`) too. + */ + def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType) + case _ => (input, ordinal) => input.get(ordinal, dataType) + } }
http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5021a56..4cc84b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Resolved once per instance so eval does not re-match on dataType for every row. // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { - if (input.isNullAt(ordinal)) { + if (nullable && input.isNullAt(ordinal)) { null } else { - dataType match { - case BooleanType => input.getBoolean(ordinal) - case ByteType => input.getByte(ordinal) - case ShortType => input.getShort(ordinal) - case IntegerType | DateType => input.getInt(ordinal) - case LongType | TimestampType => input.getLong(ordinal) - case FloatType => input.getFloat(ordinal) - case DoubleType => input.getDouble(ordinal) - case StringType => input.getUTF8String(ordinal) - case BinaryType => input.getBinary(ordinal) - case CalendarIntervalType => input.getInterval(ordinal) - case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => input.getStruct(ordinal, t.size) - case _: ArrayType => input.getArray(ordinal) - case _: MapType => input.getMap(ordinal) - case _ => input.get(ordinal, dataType) - } + accessor(input, ordinal) } }
http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 50e90ca..77802e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -560,11 +560,17 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + // Type-specialized getter resolved once; falls back to `InternalRow.get` for other types. // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - input.get(0, dataType) + if (nullable && input.isNullAt(0)) { + null + } else { + accessor(input, 0) + } } override def genCode(ctx: CodegenContext): ExprCode = { http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a5ecd1b..b4bf6d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -70,7 +70,9 @@ trait ExpressionEvalHelper
extends GeneratorDrivenPropertyChecks { * Check the equality between result of expression and expected value, it will handle * Array[Byte], Spread[Double], MapData and Row. */ - protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = { + protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + // NOTE(review): presumably maps a UDT to its sqlType so UDT results compare by SQL value. (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) http://git-wip-us.apache.org/repos/asf/spark/blob/fd990a90/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b1bc67d..b0188b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.objects._ import
org.apache.spark.sql.catalyst.util._ @@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null))) } } + + test("LambdaVariable should support interpreted execution") { + def genSchema(dt: DataType): Seq[StructType] = { + Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), + StructType(StructField("col_1", dt, nullable = true) :: Nil)) + } + + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val mapTypes = elementTypes.flatMap { elementType => + Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true)) + } + val structTypes = elementTypes.flatMap { elementType => + Seq(StructType(StructField("col1", elementType, nullable = false) :: Nil), + StructType(StructField("col1", elementType, nullable = true) :: Nil)) + } + + val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes + val random = new Random(100) + testTypes.foreach { dt => + genSchema(dt).foreach { schema => + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + val lambda = LambdaVariable("dummy", "dummyIsNull", schema(0).dataType, schema(0).nullable) + checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow) + } + } + } } class TestBean extends Serializable { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org