Repository: spark Updated Branches: refs/heads/branch-2.0 6eb8ec6f4 -> 655d88293
[SPARK-15471][SQL] ScalaReflection cleanup ## What changes were proposed in this pull request? 1. simplify the logic of serializing option type. 2. simplify the logic of serializing array type, and remove silentSchemaFor 3. remove some unnecessary code. ## How was this patch tested? existing tests Author: Wenchen Fan <[email protected]> Closes #13250 from cloud-fan/encoder. (cherry picked from commit 07c36a2f07fcf5da6fb395f830ebbfc10eb27dcc) Signed-off-by: Michael Armbrust <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/655d8829 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/655d8829 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/655d8829 Branch: refs/heads/branch-2.0 Commit: 655d88293ee554f631aabb355b8c24cc38e23332 Parents: 6eb8ec6 Author: Wenchen Fan <[email protected]> Authored: Mon May 23 11:13:27 2016 -0700 Committer: Michael Armbrust <[email protected]> Committed: Mon May 23 11:13:37 2016 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/ScalaReflection.scala | 105 ++++--------------- .../catalyst/expressions/objects/objects.scala | 4 +- 2 files changed, 21 insertions(+), 88 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/655d8829/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 36989a2..bdd40f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ 
import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** @@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) @@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } - val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -437,17 +435,17 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(externalDataType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - val 
clsName = getClassNameFromType(elementType) - val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) + dataTypeFor(elementType) match { + case dt: ObjectType => + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + MapObjects(serializerFor(_, elementType, newPath), input, dt) + + case dt => + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dt, schemaFor(elementType).nullable)) } } @@ -457,63 +455,10 @@ object ScalaReflection extends ScalaReflection { tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. 
- case other => - val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath - - val optionObjectType: DataType = other match { - // Special handling is required for arrays, as getClassFromType(<Array>) will fail - // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to - // the Java type "[I". - case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t) - case cls => ObjectType(getClassFromType(cls)) - } - val unwrapped = UnwrapOption(optionObjectType, inputObject) - - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - serializerFor(unwrapped, optType, newPath)) - } + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -704,18 +649,6 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) - /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. - * - * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return - * `NullType` silently instead. - */ - def silentSchemaFor(tpe: `Type`): Schema = try { - schemaFor(tpe) - } catch { - case _: UnsupportedOperationException => Schema(NullType, nullable = true) - } - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { tpe match { http://git-wip-us.apache.org/repos/asf/spark/blob/655d8829/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5e17f89..2f2323f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -289,8 +289,8 @@ case class UnwrapOption( ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = - ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? + ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
