This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 34f6066  [SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference
34f6066 is described below

commit 34f606678a90e860711a5f9f9618cf00788c9eb0
Author:     Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com>
AuthorDate: Mon Mar 4 10:45:48 2019 +0800

    [SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference

    ## What changes were proposed in this pull request?

    This patch refactors the `serializerFor` methods of `ScalaReflection` and
    `JavaTypeInference` to be consistent with the `deserializerFor` refactoring
    done in #23854. It also extracts the logic for recording the walked type
    path into a new `WalkedTypePath` class, since that logic was duplicated
    across `serializerFor` and `deserializerFor` in both `ScalaReflection`
    and `JavaTypeInference`.

    ## How was this patch tested?

    Existing tests.

    Closes #23908 from HeartSaVioR/SPARK-27001.

    Authored-by: Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/DeserializerBuildHelper.scala     |  32 ++-
 .../spark/sql/catalyst/JavaTypeInference.scala     | 143 +++++---------
 .../spark/sql/catalyst/ScalaReflection.scala       | 220 +++++++--------------
 .../spark/sql/catalyst/SerializerBuildHelper.scala | 198 +++++++++++++++++++
 .../apache/spark/sql/catalyst/WalkedTypePath.scala |  57 ++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   2 +-
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |   2 +-
 .../spark/sql/catalyst/encoders/RowEncoder.scala   |   2 +-
 .../spark/sql/catalyst/expressions/Cast.scala      |   4 +-
 .../catalyst/expressions/CodeGenerationSuite.scala |   2 +-
 .../expressions/NullExpressionsSuite.scala         |   2 +-
 11 files changed, 394 insertions(+), 270 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index d75d3ca..e55c25c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -29,7 +29,7 @@ object DeserializerBuildHelper {
       path: Expression,
       part: String,
       dataType: DataType,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
     upCastToExpectedType(newPath, dataType, walkedTypePath)
   }
@@ -39,40 +39,30 @@ object DeserializerBuildHelper {
       path: Expression,
       ordinal: Int,
       dataType: DataType,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     val newPath = GetStructField(path, ordinal)
     upCastToExpectedType(newPath, dataType, walkedTypePath)
   }
 
-  def deserializerForWithNullSafety(
-      expr: Expression,
-      dataType: DataType,
-      nullable: Boolean,
-      walkedTypePath: Seq[String],
-      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
-    val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
-    expressionWithNullSafety(newExpr, nullable, walkedTypePath)
-  }
-
   def deserializerForWithNullSafetyAndUpcast(
       expr: Expression,
       dataType: DataType,
       nullable: Boolean,
-      walkedTypePath: Seq[String],
-      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+      walkedTypePath: WalkedTypePath,
+      funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
     val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
-    deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
-      funcForCreatingNewExpr)
+    expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
+      nullable, walkedTypePath)
   }
 
-  private def expressionWithNullSafety(
+  def expressionWithNullSafety(
       expr: Expression,
       nullable: Boolean,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     if (nullable) {
       expr
     } else {
-      AssertNotNull(expr, walkedTypePath)
+      AssertNotNull(expr, walkedTypePath.getPaths)
     }
   }
 
@@ -167,10 +157,10 @@ object DeserializerBuildHelper {
   private def upCastToExpectedType(
       expr: Expression,
       expected: DataType,
-      walkedTypePath: Seq[String]): Expression = expected match {
+      walkedTypePath: WalkedTypePath): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
     case _: MapType => expr
-    case _ => UpCast(expr, expected, walkedTypePath)
+    case _ => UpCast(expr, expected, walkedTypePath.getPaths)
   }
 }
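The change above threads a `WalkedTypePath` object through the deserializer builders instead of a raw `Seq[String]`, and the now-public `expressionWithNullSafety` becomes the single place where non-nullable results get wrapped in `AssertNotNull`. A minimal, self-contained sketch of that null-safety idea — the types below are toy stand-ins for the Catalyst classes, not the real API:

```scala
// Toy model of the null-safety wrapper; Expr, NotNullCheck and TypePath are
// illustrative stand-ins for Catalyst's Expression, AssertNotNull and WalkedTypePath.
sealed trait Expr
case class Column(name: String) extends Expr
case class NotNullCheck(child: Expr, walkedPaths: Seq[String]) extends Expr

case class TypePath(paths: Seq[String] = Nil) {
  // New entries are prepended, mirroring WalkedTypePath.newInstance below.
  def record(entry: String): TypePath = TypePath(entry +: paths)
}

def expressionWithNullSafety(expr: Expr, nullable: Boolean, path: TypePath): Expr =
  if (nullable) expr else NotNullCheck(expr, path.paths)

val path = TypePath()
  .record("""- root class: "Person"""")
  .record("""- field (name: "age")""")

// Non-nullable fields get a runtime null check; nullable ones pass through untouched.
assert(expressionWithNullSafety(Column("age"), nullable = false, path)
  == NotNullCheck(Column("age"), path.paths))
assert(expressionWithNullSafety(Column("name"), nullable = true, path) == Column("name"))
```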
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 87b2ae8..933a6db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,12 +27,12 @@ import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
 
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -195,7 +195,7 @@
    */
   def deserializerFor(beanClass: Class[_]): Expression = {
     val typeToken = TypeToken.of(beanClass)
-    val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
+    val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
     val (dataType, nullable) = inferDataType(typeToken)
 
     // Assumes we are deserializing the first column of a row.
@@ -208,7 +208,7 @@
   private def deserializerFor(
       typeToken: TypeToken[_],
       path: Expression,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
 
@@ -244,8 +244,7 @@
       case c if c.isArray =>
         val elementType = c.getComponentType
-        val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
-          walkedTypePath
+        val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
         val (dataType, elementNullable) = inferDataType(elementType)
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
@@ -274,8 +273,7 @@
       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
-        val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
-          walkedTypePath
+        val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
         val (dataType, elementNullable) = inferDataType(et)
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
@@ -291,8 +289,8 @@
       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
-          s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
+        val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
+          valueType.getType.getTypeName)
 
         val keyData =
           Invoke(
@@ -328,15 +326,12 @@
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
           val (dataType, nullable) = inferDataType(fieldType)
-          val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
-            s""", name: "$fieldName")""") +: walkedTypePath
-          val setter = deserializerForWithNullSafety(
-            path,
-            dataType,
+          val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
+          val setter = expressionWithNullSafety(
+            deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
+              newTypePath),
             nullable = nullable,
-            newTypePath,
-            (expr, typePath) => deserializerFor(fieldType,
-              addToPath(expr, fieldName, dataType, typePath), typePath))
+            newTypePath)
           p.getWriteMethod.getName -> setter
         }.toMap
 
@@ -367,12 +362,10 @@
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
       val (dataType, nullable) = inferDataType(elementType)
       if (ScalaReflection.isNativeType(dataType)) {
-        NewInstance(
-          classOf[GenericArrayData],
-          input :: Nil,
-          dataType = ArrayType(dataType, nullable))
+        createSerializerForGenericArray(input, dataType, nullable = nullable)
       } else {
-        MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
+        createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
+          serializerFor(_, elementType))
       }
     }
 
@@ -380,60 +373,26 @@
       inputObject
     } else {
       typeToken.getRawType match {
-        case c if c == classOf[String] =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.sql.Timestamp] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            TimestampType,
-            "fromJavaTimestamp",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.time.LocalDate] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "localDateToDays",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.sql.Date] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "fromJavaDate",
-            inputObject :: Nil,
-            returnNullable = false)
+        case c if c == classOf[String] => createSerializerForString(inputObject)
+
+        case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
+
+        case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
+
+        case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
+
+        case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
 
         case c if c == classOf[java.math.BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.lang.Boolean] =>
-          Invoke(inputObject, "booleanValue", BooleanType)
-        case c if c == classOf[java.lang.Byte] =>
-          Invoke(inputObject, "byteValue", ByteType)
-        case c if c == classOf[java.lang.Short] =>
-          Invoke(inputObject, "shortValue", ShortType)
-        case c if c == classOf[java.lang.Integer] =>
-          Invoke(inputObject, "intValue", IntegerType)
-        case c if c == classOf[java.lang.Long] =>
-          Invoke(inputObject, "longValue", LongType)
-        case c if c == classOf[java.lang.Float] =>
-          Invoke(inputObject, "floatValue", FloatType)
-        case c if c == classOf[java.lang.Double] =>
-          Invoke(inputObject, "doubleValue", DoubleType)
+          createSerializerForJavaBigDecimal(inputObject)
+
+        case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
+        case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
+        case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
+        case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
+        case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
+        case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
+        case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)
 
         case _ if typeToken.isArray =>
           toCatalystArray(inputObject, typeToken.getComponentType)
@@ -444,38 +403,34 @@
 
         case _ if mapType.isAssignableFrom(typeToken) =>
           val (keyType, valueType) = mapKeyValueType(typeToken)
 
-          ExternalMapToCatalyst(
+          createSerializerForMap(
             inputObject,
-            ObjectType(keyType.getRawType),
-            serializerFor(_, keyType),
-            keyNullable = true,
-            ObjectType(valueType.getRawType),
-            serializerFor(_, valueType),
-            valueNullable = true
+            MapElementInformation(
+              ObjectType(keyType.getRawType),
+              nullable = true,
+              serializerFor(_, keyType)),
+            MapElementInformation(
+              ObjectType(valueType.getRawType),
+              nullable = true,
+              serializerFor(_, valueType))
           )
 
         case other if other.isEnum =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
-            returnNullable = false)
+          createSerializerForString(
+            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))
 
         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)
-          val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
+          val fields = properties.map { p =>
             val fieldName = p.getName
             val fieldType = typeToken.method(p.getReadMethod).getReturnType
             val fieldValue = Invoke(
               inputObject,
               p.getReadMethod.getName,
               inferExternalType(fieldType.getRawType))
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
-          })
-
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+            (fieldName, serializerFor(fieldValue, fieldType))
+          }
+          createSerializerForObject(inputObject, fields)
       }
     }
   }
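The bean path above is exercised through `Encoders.bean`, the public entry point into `JavaTypeInference.serializerFor`/`deserializerFor`. A small illustration — the `Person` bean is hypothetical, and this assumes Spark is on the classpath:

```scala
import org.apache.spark.sql.Encoders

// Hypothetical Java-style bean; JavaTypeInference discovers fields through
// its getter/setter pairs.
class Person extends Serializable {
  private var name: String = _
  private var age: java.lang.Integer = _
  def getName: String = name
  def setName(v: String): Unit = { name = v }
  def getAge: java.lang.Integer = age
  def setAge(v: java.lang.Integer): Unit = { age = v }
}

object BeanEncoderDemo extends App {
  // The encoder's serializer/deserializer are the expressions built above;
  // the schema shows what serializerFor derived from the bean properties.
  val encoder = Encoders.bean(classOf[Person])
  println(encoder.schema.treeString)
}
```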
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bbddd33..5b3109a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -21,10 +21,11 @@ import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.{Expression, _}
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -136,7 +137,7 @@ object ScalaReflection extends ScalaReflection {
    */
   def deserializerForType(tpe: `Type`): Expression = {
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
+    val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
@@ -156,14 +157,14 @@ object ScalaReflection extends ScalaReflection {
   private def deserializerFor(
       tpe: `Type`,
       path: Expression,
-      walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
+      walkedTypePath: WalkedTypePath): Expression = cleanUpReflectionObjects {
     tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
 
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
-        val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
+        val newTypePath = walkedTypePath.recordOption(className)
         WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
@@ -225,7 +226,7 @@ object ScalaReflection extends ScalaReflection {
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
-        val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
+        val newTypePath = walkedTypePath.recordArray(className)
 
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
@@ -260,7 +261,7 @@
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
-        val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
+        val newTypePath = walkedTypePath.recordArray(className)
 
         val mapFunction: Expression => Expression = element => {
           deserializerForWithNullSafetyAndUpcast(
@@ -286,8 +287,7 @@
         val classNameForKey = getClassNameFromType(keyType)
         val classNameForValue = getClassNameFromType(valueType)
 
-        val newTypePath = (s"""- map key class: "${classNameForKey}"""" +
-          s""", value class: "${classNameForValue}"""") +: walkedTypePath
+        val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)
 
         UnresolvedCatalystToExternalMap(
           path,
@@ -322,28 +322,24 @@
         val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
           val Schema(dataType, nullable) = schemaFor(fieldType)
           val clsName = getClassNameFromType(fieldType)
-          val newTypePath = (s"""- field (class: "$clsName", """ +
-            s"""name: "$fieldName")""") +: walkedTypePath
+          val newTypePath = walkedTypePath.recordField(clsName, fieldName)
 
-          // For tuples, we grab the inner fields by ordinal instead of name.
-          deserializerForWithNullSafety(
-            path,
-            dataType,
+          // For tuples, we grab the inner fields by ordinal instead of name.
+          val newPath = if (cls.getName startsWith "scala.Tuple") {
+            deserializerFor(
+              fieldType,
+              addToPathOrdinal(path, i, dataType, newTypePath),
+              newTypePath)
+          } else {
+            deserializerFor(
+              fieldType,
+              addToPath(path, fieldName, dataType, newTypePath),
+              newTypePath)
+          }
+          expressionWithNullSafety(
+            newPath,
             nullable = nullable,
-            newTypePath,
-            (expr, typePath) => {
-              if (cls.getName startsWith "scala.Tuple") {
-                deserializerFor(
-                  fieldType,
-                  addToPathOrdinal(expr, i, dataType, typePath),
-                  newTypePath)
-              } else {
-                deserializerFor(
-                  fieldType,
-                  addToPath(expr, fieldName, dataType, typePath),
-                  newTypePath)
-              }
-            })
+            newTypePath)
         }
 
         val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
@@ -371,7 +367,7 @@ object ScalaReflection extends ScalaReflection {
    */
   def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
+    val walkedTypePath = new WalkedTypePath().recordRoot(clsName)
 
     // The input object to `ExpressionEncoder` is located at the first column of a row.
     val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
@@ -387,38 +383,28 @@
   private def serializerFor(
       inputObject: Expression,
       tpe: `Type`,
-      walkedTypePath: Seq[String],
+      walkedTypePath: WalkedTypePath,
       seenTypeSet: Set[`Type`] = Set.empty): Expression = cleanUpReflectionObjects {
 
     def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
       dataTypeFor(elementType) match {
         case dt: ObjectType =>
           val clsName = getClassNameFromType(elementType)
-          val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
-          MapObjects(serializerFor(_, elementType, newPath, seenTypeSet), input, dt)
+          val newPath = walkedTypePath.recordArray(clsName)
+          createSerializerForMapObjects(input, dt,
+            serializerFor(_, elementType, newPath, seenTypeSet))
 
         case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
                    FloatType | DoubleType) =>
           val cls = input.dataType.asInstanceOf[ObjectType].cls
           if (cls.isArray && cls.getComponentType.isPrimitive) {
-            StaticInvoke(
-              classOf[UnsafeArrayData],
-              ArrayType(dt, false),
-              "fromPrimitiveArray",
-              input :: Nil,
-              returnNullable = false)
+            createSerializerForPrimitiveArray(input, dt)
           } else {
-            NewInstance(
-              classOf[GenericArrayData],
-              input :: Nil,
-              dataType = ArrayType(dt, schemaFor(elementType).nullable))
+            createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable)
           }
 
         case dt =>
-          NewInstance(
-            classOf[GenericArrayData],
-            input :: Nil,
-            dataType = ArrayType(dt, schemaFor(elementType).nullable))
+          createSerializerForGenericArray(input, dt, nullable = schemaFor(elementType).nullable)
       }
     }
 
@@ -428,7 +414,7 @@
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
-        val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+        val newPath = walkedTypePath.recordOption(className)
         val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
         serializerFor(unwrapped, optType, newPath, seenTypeSet)
 
@@ -447,17 +433,20 @@
         val TypeRef(_, _, Seq(keyType, valueType)) = t
         val keyClsName = getClassNameFromType(keyType)
         val valueClsName = getClassNameFromType(valueType)
-        val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
-        val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+        val keyPath = walkedTypePath.recordKeyForMap(keyClsName)
+        val valuePath = walkedTypePath.recordValueForMap(valueClsName)
 
-        ExternalMapToCatalyst(
+        createSerializerForMap(
           inputObject,
-          dataTypeFor(keyType),
-          serializerFor(_, keyType, keyPath, seenTypeSet),
-          keyNullable = !keyType.typeSymbol.asClass.isPrimitive,
-          dataTypeFor(valueType),
-          serializerFor(_, valueType, valuePath, seenTypeSet),
-          valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
+          MapElementInformation(
+            dataTypeFor(keyType),
+            nullable = !keyType.typeSymbol.asClass.isPrimitive,
+            serializerFor(_, keyType, keyPath, seenTypeSet)),
+          MapElementInformation(
+            dataTypeFor(valueType),
+            nullable = !valueType.typeSymbol.asClass.isPrimitive,
+            serializerFor(_, valueType, valuePath, seenTypeSet))
+        )
 
       case t if t <:< localTypeOf[scala.collection.Set[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
@@ -472,110 +461,47 @@
 
         toCatalystArray(newInput, elementType)
 
-      case t if t <:< localTypeOf[String] =>
-        StaticInvoke(
-          classOf[UTF8String],
-          StringType,
-          "fromString",
-          inputObject :: Nil,
-          returnNullable = false)
+      case t if t <:< localTypeOf[String] => createSerializerForString(inputObject)
 
-      case t if t <:< localTypeOf[java.time.Instant] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          TimestampType,
-          "instantToMicros",
-          inputObject :: Nil,
-          returnNullable = false)
+      case t if t <:< localTypeOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          TimestampType,
-          "fromJavaTimestamp",
-          inputObject :: Nil,
-          returnNullable = false)
+        createSerializerForSqlTimestamp(inputObject)
 
       case t if t <:< localTypeOf[java.time.LocalDate] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          DateType,
-          "localDateToDays",
-          inputObject :: Nil,
-          returnNullable = false)
+        createSerializerForJavaLocalDate(inputObject)
 
-      case t if t <:< localTypeOf[java.sql.Date] =>
-        StaticInvoke(
-          DateTimeUtils.getClass,
-          DateType,
-          "fromJavaDate",
-          inputObject :: Nil,
-          returnNullable = false)
+      case t if t <:< localTypeOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
 
-      case t if t <:< localTypeOf[BigDecimal] =>
-        StaticInvoke(
-          Decimal.getClass,
-          DecimalType.SYSTEM_DEFAULT,
-          "apply",
-          inputObject :: Nil,
-          returnNullable = false)
+      case t if t <:< localTypeOf[BigDecimal] => createSerializerForScalaBigDecimal(inputObject)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        StaticInvoke(
-          Decimal.getClass,
-          DecimalType.SYSTEM_DEFAULT,
-          "apply",
-          inputObject :: Nil,
-          returnNullable = false)
+        createSerializerForJavaBigDecimal(inputObject)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        StaticInvoke(
-          Decimal.getClass,
-          DecimalType.BigIntDecimal,
-          "apply",
-          inputObject :: Nil,
-          returnNullable = false)
+        createSerializerForJavaBigInteger(inputObject)
 
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
-        StaticInvoke(
-          Decimal.getClass,
-          DecimalType.BigIntDecimal,
-          "apply",
-          inputObject :: Nil,
-          returnNullable = false)
+      case t if t <:< localTypeOf[scala.math.BigInt] => createSerializerForScalaBigInt(inputObject)
 
-      case t if t <:< localTypeOf[java.lang.Integer] =>
-        Invoke(inputObject, "intValue", IntegerType)
-      case t if t <:< localTypeOf[java.lang.Long] =>
-        Invoke(inputObject, "longValue", LongType)
-      case t if t <:< localTypeOf[java.lang.Double] =>
-        Invoke(inputObject, "doubleValue", DoubleType)
-      case t if t <:< localTypeOf[java.lang.Float] =>
-        Invoke(inputObject, "floatValue", FloatType)
-      case t if t <:< localTypeOf[java.lang.Short] =>
-        Invoke(inputObject, "shortValue", ShortType)
-      case t if t <:< localTypeOf[java.lang.Byte] =>
-        Invoke(inputObject, "byteValue", ByteType)
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
-        Invoke(inputObject, "booleanValue", BooleanType)
+      case t if t <:< localTypeOf[java.lang.Integer] => createSerializerForInteger(inputObject)
+      case t if t <:< localTypeOf[java.lang.Long] => createSerializerForLong(inputObject)
+      case t if t <:< localTypeOf[java.lang.Double] => createSerializerForDouble(inputObject)
+      case t if t <:< localTypeOf[java.lang.Float] => createSerializerForFloat(inputObject)
+      case t if t <:< localTypeOf[java.lang.Short] => createSerializerForShort(inputObject)
+      case t if t <:< localTypeOf[java.lang.Byte] => createSerializerForByte(inputObject)
+      case t if t <:< localTypeOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
 
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t)
          .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "serialize", udt, inputObject :: Nil)
+        val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+        createSerializerForUserDefinedType(inputObject, udt, udtClass)
 
       case t if UDTRegistration.exists(getClassNameFromType(t)) =>
         val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
           newInstance().asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        Invoke(obj, "serialize", udt, inputObject :: Nil)
+        val udtClass = udt.getClass
+        createSerializerForUserDefinedType(inputObject, udt, udtClass)
 
       case t if definedByConstructorParams(t) =>
         if (seenTypeSet.contains(t)) {
@@ -584,10 +510,10 @@
         }
 
         val params = getConstructorParameters(t)
-        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+        val fields = params.map { case (fieldName, fieldType) =>
           if (javaKeywords.contains(fieldName)) {
             throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
-              "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
+              "cannot be used as field name\n" + walkedTypePath)
           }
 
           // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
@@ -597,16 +523,14 @@
           val fieldValue = Invoke(
             KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
             returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
           val clsName = getClassNameFromType(fieldType)
-          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          expressions.Literal(fieldName) ::
-            serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil
-        })
-        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+          val newPath = walkedTypePath.recordField(clsName, fieldName)
+          (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
+        }
+        createSerializerForObject(inputObject, fields)
 
-      case other =>
+      case _ =>
         throw new UnsupportedOperationException(
-          s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
+          s"No Encoder found for $tpe\n" + walkedTypePath)
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
new file mode 100644
index 0000000..e035c4b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object SerializerBuildHelper {
+
+  def createSerializerForBoolean(inputObject: Expression): Expression = {
+    Invoke(inputObject, "booleanValue", BooleanType)
+  }
+
+  def createSerializerForByte(inputObject: Expression): Expression = {
+    Invoke(inputObject, "byteValue", ByteType)
+  }
+
+  def createSerializerForShort(inputObject: Expression): Expression = {
+    Invoke(inputObject, "shortValue", ShortType)
+  }
+
+  def createSerializerForInteger(inputObject: Expression): Expression = {
+    Invoke(inputObject, "intValue", IntegerType)
+  }
+
+  def createSerializerForLong(inputObject: Expression): Expression = {
+    Invoke(inputObject, "longValue", LongType)
+  }
+
+  def createSerializerForFloat(inputObject: Expression): Expression = {
+    Invoke(inputObject, "floatValue", FloatType)
+  }
+
+  def createSerializerForDouble(inputObject: Expression): Expression = {
+    Invoke(inputObject, "doubleValue", DoubleType)
+  }
+
+  def createSerializerForString(inputObject: Expression): Expression = {
+    StaticInvoke(
+      classOf[UTF8String],
+      StringType,
+      "fromString",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForJavaInstant(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      TimestampType,
+      "instantToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      TimestampType,
+      "fromJavaTimestamp",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForJavaLocalDate(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      DateType,
+      "localDateToDays",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForSqlDate(inputObject: Expression): Expression = {
+    StaticInvoke(
+      DateTimeUtils.getClass,
+      DateType,
+      "fromJavaDate",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
+    StaticInvoke(
+      Decimal.getClass,
+      DecimalType.SYSTEM_DEFAULT,
+      "apply",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
+    createSerializerForJavaBigDecimal(inputObject)
+  }
+
+  def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
+    StaticInvoke(
+      Decimal.getClass,
+      DecimalType.BigIntDecimal,
+      "apply",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
+    createSerializerForJavaBigInteger(inputObject)
+  }
+
+  def createSerializerForPrimitiveArray(
+      inputObject: Expression,
+      dataType: DataType): Expression = {
+    StaticInvoke(
+      classOf[UnsafeArrayData],
+      ArrayType(dataType, false),
+      "fromPrimitiveArray",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForGenericArray(
+      inputObject: Expression,
+      dataType: DataType,
+      nullable: Boolean): Expression = {
+    NewInstance(
+      classOf[GenericArrayData],
+      inputObject :: Nil,
+      dataType = ArrayType(dataType, nullable))
+  }
+
+  def createSerializerForMapObjects(
+      inputObject: Expression,
+      dataType: ObjectType,
+      funcForNewExpr: Expression => Expression): Expression = {
+    MapObjects(funcForNewExpr, inputObject, dataType)
+  }
+
+  case class MapElementInformation(
+      dataType: DataType,
+      nullable: Boolean,
+      funcForNewExpr: Expression => Expression)
+
+  def createSerializerForMap(
+      inputObject: Expression,
+      keyInformation: MapElementInformation,
+      valueInformation: MapElementInformation): Expression = {
+    ExternalMapToCatalyst(
+      inputObject,
+      keyInformation.dataType,
+      keyInformation.funcForNewExpr,
+      keyNullable = keyInformation.nullable,
+      valueInformation.dataType,
+      valueInformation.funcForNewExpr,
+      valueNullable = valueInformation.nullable
+    )
+  }
+
+  private def argumentsForFieldSerializer(
+      fieldName: String,
+      serializerForFieldValue: Expression): Seq[Expression] = {
+    expressions.Literal(fieldName) :: serializerForFieldValue :: Nil
+  }
+
+  def createSerializerForObject(
+      inputObject: Expression,
+      fields: Seq[(String, Expression)]): Expression = {
+    val nonNullOutput = CreateNamedStruct(fields.flatMap { case(fieldName, fieldExpr) =>
+      argumentsForFieldSerializer(fieldName, fieldExpr)
+    })
+    val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+    expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+  }
+
+  def createSerializerForUserDefinedType(
+      inputObject: Expression,
+      udt: UserDefinedType[_],
+      udtClass: Class[_]): Expression = {
+    val obj = NewInstance(udtClass, Nil, dataType = ObjectType(udtClass))
+    Invoke(obj, "serialize", udt, inputObject :: Nil)
+  }
+}
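These helpers are internal to Catalyst, but they can be poked at directly to see what they produce. A small hedged sketch — internal API, subject to change, and `HelperDemo` is just an illustrative name:

```scala
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.{ObjectType, StringType}

object HelperDemo extends App {
  // A BoundReference standing in for "the external String object at ordinal 0".
  val input = BoundReference(0, ObjectType(classOf[String]), nullable = true)

  // Produces a StaticInvoke of UTF8String.fromString, as defined above.
  val serializer = createSerializerForString(input)
  assert(serializer.dataType == StringType)
}
```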
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
new file mode 100644
index 0000000..cdb55b8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+/**
+ * This class records the paths the serializer and deserializer walk through to reach the
+ * current path. Note that new entries are prepended to the recorded paths, so the paths
+ * are kept in reverse order.
+ */
+case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) extends Serializable {
+  def recordRoot(className: String): WalkedTypePath =
+    newInstance(s"""- root class: "$className"""")
+
+  def recordOption(className: String): WalkedTypePath =
+    newInstance(s"""- option value class: "$className"""")
+
+  def recordArray(elementClassName: String): WalkedTypePath =
+    newInstance(s"""- array element class: "$elementClassName"""")
+
+  def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = {
+    newInstance(s"""- map key class: "$keyClassName"""" +
+      s""", value class: "$valueClassName"""")
+  }
+
+  def recordKeyForMap(keyClassName: String): WalkedTypePath =
+    newInstance(s"""- map key class: "$keyClassName"""")
+
+  def recordValueForMap(valueClassName: String): WalkedTypePath =
+    newInstance(s"""- map value class: "$valueClassName"""")
+
+  def recordField(className: String, fieldName: String): WalkedTypePath =
+    newInstance(s"""- field (class: "$className", name: "$fieldName")""")
+
+  override def toString: String = {
+    walkedPaths.mkString("\n")
+  }
+
+  def getPaths: Seq[String] = walkedPaths
+
+  private def newInstance(newRecord: String): WalkedTypePath =
+    WalkedTypePath(newRecord +: walkedPaths)
+}
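Because each `record*` call prepends an entry and returns a fresh instance, the innermost step prints first — the shape the encoder error messages rely on. For instance, with the class exactly as added above:

```scala
import org.apache.spark.sql.catalyst.WalkedTypePath

object WalkedTypePathDemo extends App {
  val path = new WalkedTypePath()
    .recordRoot("Person")
    .recordField("scala.Option", "nickname")
    .recordOption("java.lang.String")

  println(path)
  // - option value class: "java.lang.String"
  // - field (class: "scala.Option", name: "nickname")
  // - root class: "Person"
}
```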
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 42904c5..ab9cedc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2348,7 +2348,7 @@ class Analyzer(
           } else {
             // always add an UpCast. it will be removed in the optimizer if it is unnecessary.
             Some(Alias(
-              UpCast(queryExpr, tableAttr.dataType, Seq()), tableAttr.name
+              UpCast(queryExpr, tableAttr.dataType), tableAttr.name
             )(
               explicitMetadata = Option(tableAttr.metadata)
             ))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index da5c1fd..abffda7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection, WalkedTypePath}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 68a603b..97709bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -155,7 +155,7 @@ object RowEncoder {
         element => {
           val value = serializerFor(ValidateExternalType(element, et), et)
           if (!containsNull) {
-            AssertNotNull(value, Seq.empty)
+            AssertNotNull(value)
           } else {
             value
           }
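The shortened call sites here and in the tests below rely on a default argument; assuming `AssertNotNull`'s `walkedTypePath` parameter now defaults to `Nil` (which is what the updated callers imply — the defining diff is not part of this patch), the two spellings are equivalent:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull

object DefaultArgDemo extends App {
  // Equivalent once walkedTypePath defaults to Nil; AssertNotNull is a case
  // class, so structural equality holds between the two constructions.
  assert(AssertNotNull(Literal(1)) == AssertNotNull(Literal(1), Nil))
}
```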
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d591c58..84087ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
 import java.util.concurrent.TimeUnit._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, WalkedTypePath}
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -1378,7 +1378,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
  * Cast the child expression to the target data type, but will throw error if the cast might
  * truncate, e.g. long -> int, timestamp -> date.
  */
-case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String])
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil)
   extends UnaryExpression with Unevaluable {
   override lazy val resolved = false
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index baa1b3b..7d49866 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -338,7 +338,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("should not apply common subexpression elimination on conditional expressions") {
     val row = InternalRow(null)
     val bound = BoundReference(0, IntegerType, true)
-    val assertNotNull = AssertNotNull(bound, Nil)
+    val assertNotNull = AssertNotNull(bound)
     val expr = If(IsNull(bound), Literal(1), Add(assertNotNull, assertNotNull))
     val projection = GenerateUnsafeProjection.generate(
       Seq(expr), subexpressionEliminationEnabled = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index b7ce367..49fd59c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -53,7 +53,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("AssertNotNUll") {
     val ex = intercept[RuntimeException] {
-      evaluateWithoutCodegen(AssertNotNull(Literal(null), Seq.empty[String]))
+      evaluateWithoutCodegen(AssertNotNull(Literal(null)))
     }.getMessage
     assert(ex.contains("Null value appeared in non-nullable field"))
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org