This is an automated email from the ASF dual-hosted git repository. joshrosen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new ba7f61e [SPARK-26555][SQL][BRANCH-2.4] make ScalaReflection subtype checking thread safe ba7f61e is described below commit ba7f61e25d58aa379f94a23b03503a25574529bc Author: mwlon <mlonca...@hmc.edu> AuthorDate: Wed Jun 19 19:03:35 2019 -0700 [SPARK-26555][SQL][BRANCH-2.4] make ScalaReflection subtype checking thread safe This is a Spark 2.4.x backport of #24085. Original description follows below: ## What changes were proposed in this pull request? Make ScalaReflection subtype checking thread safe by adding a lock. There is a thread safety bug in the <:< operator in all versions of scala (https://github.com/scala/bug/issues/10766). ## How was this patch tested? Existing tests and a new one for the new subtype checking function. Closes #24913 from JoshRosen/joshrosen/SPARK-26555-branch-2.4-backport. Authored-by: mwlon <mlonca...@hmc.edu> Signed-off-by: Josh Rosen <rosenvi...@gmail.com> --- .../spark/sql/catalyst/ScalaReflection.scala | 216 +++++++++++---------- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 6 + 2 files changed, 124 insertions(+), 98 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c27180e..1b186bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -40,6 +40,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} trait DefinedByConstructorParams +private[catalyst] object ScalaSubtypeLock + + /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -68,19 +71,32 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + /** + * Synchronize to prevent concurrent usage of `<:<` operator. + * This operator is not thread safe in any current version of scala; i.e. + * (2.11.12, 2.12.8, 2.13.0-M5). + * + * See https://github.com/scala/bug/issues/10766 + */ + private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { + ScalaSubtypeLock.synchronized { + tpe1 <:< tpe2 + } + } + private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { - case t if t <:< definitions.NullTpe => NullType - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType - case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT + case t if isSubtype(t, definitions.NullTpe) => NullType + case t if isSubtype(t, definitions.IntTpe) => IntegerType + case t if isSubtype(t, definitions.LongTpe) => LongType + case t if isSubtype(t, definitions.DoubleTpe) => DoubleType + case t if isSubtype(t, definitions.FloatTpe) => FloatType + case t if isSubtype(t, definitions.ShortTpe) => ShortType + case t if isSubtype(t, definitions.ByteTpe) => ByteType + case t if isSubtype(t, definitions.BooleanTpe) => BooleanType + case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType + case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType + case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) className match { @@ -103,13 +119,13 @@ object ScalaReflection extends ScalaReflection { */ private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { val cls = tpe.dealias match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]] + case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]] + case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]] + case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]] + case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]] + case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]] + case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]] case other => // There is probably a better way to do this, but I couldn't find it... val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls @@ -210,48 +226,48 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) - case t if t <:< localTypeOf[java.lang.Integer] => + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Long] => + case t if isSubtype(t, localTypeOf[java.lang.Long]) => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Double] => + case t if isSubtype(t, localTypeOf[java.lang.Double]) => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Float] => + case t if isSubtype(t, localTypeOf[java.lang.Float]) => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Short] => + case t if isSubtype(t, localTypeOf[java.lang.Short]) => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Byte] => + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Boolean] => + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.sql.Date] => + case t if isSubtype(t, localTypeOf[java.sql.Date]) => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), @@ -259,7 +275,7 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.sql.Timestamp] => + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), @@ -267,25 +283,25 @@ object ScalaReflection extends ScalaReflection { getPath :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.String] => + case t if isSubtype(t, localTypeOf[java.lang.String]) => Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]), returnNullable = false) - case t if t <:< localTypeOf[BigDecimal] => + case t if isSubtype(t, localTypeOf[BigDecimal]) => Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]), returnNullable = false) - case t if t <:< localTypeOf[scala.math.BigInt] => + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]), returnNullable = false) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -309,13 +325,13 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => "toIntArray" - case t if t <:< definitions.LongTpe => "toLongArray" - case t if t <:< definitions.DoubleTpe => "toDoubleArray" - case t if t <:< definitions.FloatTpe => "toFloatArray" - case t if t <:< definitions.ShortTpe => "toShortArray" - case t if t <:< definitions.ByteTpe => "toByteArray" - case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case t if isSubtype(t, definitions.IntTpe) => "toIntArray" + case t if isSubtype(t, definitions.LongTpe) => "toLongArray" + case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray" + case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray" + case t if isSubtype(t, definitions.ShortTpe) => "toShortArray" + case t if isSubtype(t, definitions.ByteTpe) => "toByteArray" + case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray" case other => throw new IllegalStateException("expect primitive array element type " + "but got " + other) } @@ -324,8 +340,8 @@ object ScalaReflection extends ScalaReflection { // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. - case t if t <:< localTypeOf[Seq[_]] || - t <:< localTypeOf[scala.collection.Set[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) || + isSubtype(t, localTypeOf[scala.collection.Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -344,14 +360,14 @@ object ScalaReflection extends ScalaReflection { val companion = t.dealias.typeSymbol.companion.typeSignature val cls = companion.member(TermName("newBuilder")) match { - case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]] - case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] => + case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]] + case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => // TODO: add walked type path for map val TypeRef(_, _, Seq(keyType, valueType)) = t @@ -486,7 +502,7 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newPath = s"""- option value class: "$className"""" +: walkedTypePath @@ -496,15 +512,15 @@ object ScalaReflection extends ScalaReflection { // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) @@ -520,7 +536,7 @@ object ScalaReflection extends ScalaReflection { serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) - case t if t <:< localTypeOf[scala.collection.Set[_]] => + case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. @@ -533,7 +549,7 @@ object ScalaReflection extends ScalaReflection { toCatalystArray(newInput, elementType) - case t if t <:< localTypeOf[String] => + case t if isSubtype(t, localTypeOf[String]) => StaticInvoke( classOf[UTF8String], StringType, @@ -541,7 +557,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.sql.Timestamp] => + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => StaticInvoke( DateTimeUtils.getClass, TimestampType, @@ -549,7 +565,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.sql.Date] => + case t if isSubtype(t, localTypeOf[java.sql.Date]) => StaticInvoke( DateTimeUtils.getClass, DateType, @@ -557,7 +573,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[BigDecimal] => + case t if isSubtype(t, localTypeOf[BigDecimal]) => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, @@ -565,7 +581,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => StaticInvoke( Decimal.getClass, DecimalType.SYSTEM_DEFAULT, @@ -573,7 +589,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, @@ -581,7 +597,7 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[scala.math.BigInt] => + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => StaticInvoke( Decimal.getClass, DecimalType.BigIntDecimal, @@ -589,19 +605,19 @@ object ScalaReflection extends ScalaReflection { inputObject :: Nil, returnNullable = false) - case t if t <:< localTypeOf[java.lang.Integer] => + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => + case t if isSubtype(t, localTypeOf[java.lang.Long]) => Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => + case t if isSubtype(t, localTypeOf[java.lang.Double]) => Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => + case t if isSubtype(t, localTypeOf[java.lang.Float]) => Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => + case t if isSubtype(t, localTypeOf[java.lang.Short]) => Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Invoke(inputObject, "booleanValue", BooleanType) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => @@ -658,7 +674,7 @@ object ScalaReflection extends ScalaReflection { */ def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects { tpe.dealias match { - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t definedByConstructorParams(optType) case _ => false @@ -724,7 +740,7 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { // this must be the first case, since all objects in scala are instances of Null, therefore // Null type would wrongly match the first of them, which is Option as of now - case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) + case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) @@ -732,52 +748,56 @@ object ScalaReflection extends ScalaReflection { val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() .asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, nullable = true) + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Seq[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< localTypeOf[Set[_]] => + case t if isSubtype(t, localTypeOf[Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) - case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, nullable = true) + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => + Schema(TimestampType, nullable = true) + case t if isSubtype(t, localTypeOf[java.sql.Date]) => + Schema(DateType, nullable = true) + case t if isSubtype(t, localTypeOf[BigDecimal]) => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => Schema(DecimalType.BigIntDecimal, nullable = true) - case t if t <:< localTypeOf[scala.math.BigInt] => + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => Schema(DecimalType.BigIntDecimal, nullable = true) - case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if isSubtype(t, localTypeOf[Decimal]) => + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Double]) => Schema(DoubleType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true) + case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false) + case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false) + case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false) + case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false) + case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false) + case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false) + case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) Schema(StructType( @@ -805,9 +825,9 @@ object ScalaReflection extends ScalaReflection { def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { tpe.dealias match { // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. - case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) - case _ => tpe.dealias <:< localTypeOf[Product] || - tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head) + case _ => isSubtype(tpe.dealias, localTypeOf[Product]) || + isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams]) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index f9ee948..38b8e7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -112,6 +112,12 @@ object TestingUDT { class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + test("isSubtype") { + assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[_]])) + assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[Int]])) + assert(!isSubtype(localTypeOf[Option[_]], localTypeOf[Option[Int]])) + } + test("SQLUserDefinedType annotation on Scala structure") { val schema = schemaFor[TestingUDT.NestedStruct] assert(schema === Schema( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org