This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d5c08fc  [SPARK-26555][SQL] make ScalaReflection subtype checking thread safe
d5c08fc is described below

commit d5c08fcaab141e13f95bd8d7a2c900aeccdf718d
Author: mwlon <[email protected]>
AuthorDate: Tue Mar 19 18:22:01 2019 +0800

    [SPARK-26555][SQL] make ScalaReflection subtype checking thread safe
    
    ## What changes were proposed in this pull request?
    
    Make ScalaReflection subtype checking thread-safe by guarding it with a lock.
    The `<:<` operator has a thread-safety bug in all current versions of Scala
    (2.11.12, 2.12.8, 2.13.0-M5); see https://github.com/scala/bug/issues/10766.
    
    ## How was this patch tested?
    
    Existing tests, plus a new unit test for the new `isSubtype` helper.
    
    Closes #24085 from mwlon/SPARK-26555.
    
    Authored-by: mwlon <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 234 ++++++++++++---------
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  |   6 +
 2 files changed, 136 insertions(+), 104 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d8d268a..fa8993e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -38,6 +38,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, 
UTF8String}
 trait DefinedByConstructorParams
 
 
+private[catalyst] object ScalaSubtypeLock
+
+
 /**
  * A default version of ScalaReflection that uses the runtime universe.
  */
@@ -66,19 +69,32 @@ object ScalaReflection extends ScalaReflection {
    */
   def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
 
+  /**
+   * Returns `tpe1 <:< tpe2`, synchronized on `ScalaSubtypeLock` to prevent concurrent use
+   * of the `<:<` operator, which is not thread safe in any current version of Scala
+   * (2.11.12, 2.12.8, 2.13.0-M5).
+   *
+   * See https://github.com/scala/bug/issues/10766
+   */
+  private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+    ScalaSubtypeLock.synchronized {
+      tpe1 <:< tpe2
+    }
+  }
+
   private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
     tpe.dealias match {
-      case t if t <:< definitions.NullTpe => NullType
-      case t if t <:< definitions.IntTpe => IntegerType
-      case t if t <:< definitions.LongTpe => LongType
-      case t if t <:< definitions.DoubleTpe => DoubleType
-      case t if t <:< definitions.FloatTpe => FloatType
-      case t if t <:< definitions.ShortTpe => ShortType
-      case t if t <:< definitions.ByteTpe => ByteType
-      case t if t <:< definitions.BooleanTpe => BooleanType
-      case t if t <:< localTypeOf[Array[Byte]] => BinaryType
-      case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
-      case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
+      case t if isSubtype(t, definitions.NullTpe) => NullType
+      case t if isSubtype(t, definitions.IntTpe) => IntegerType
+      case t if isSubtype(t, definitions.LongTpe) => LongType
+      case t if isSubtype(t, definitions.DoubleTpe) => DoubleType
+      case t if isSubtype(t, definitions.FloatTpe) => FloatType
+      case t if isSubtype(t, definitions.ShortTpe) => ShortType
+      case t if isSubtype(t, definitions.ByteTpe) => ByteType
+      case t if isSubtype(t, definitions.BooleanTpe) => BooleanType
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType
+      case t if isSubtype(t, localTypeOf[CalendarInterval]) => 
CalendarIntervalType
+      case t if isSubtype(t, localTypeOf[Decimal]) => 
DecimalType.SYSTEM_DEFAULT
       case _ =>
         val className = getClassNameFromType(tpe)
         className match {
@@ -101,13 +117,13 @@ object ScalaReflection extends ScalaReflection {
    */
   private def arrayClassFor(tpe: `Type`): ObjectType = 
cleanUpReflectionObjects {
     val cls = tpe.dealias match {
-      case t if t <:< definitions.IntTpe => classOf[Array[Int]]
-      case t if t <:< definitions.LongTpe => classOf[Array[Long]]
-      case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
-      case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
-      case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
-      case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
-      case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+      case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]]
+      case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]]
+      case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]]
+      case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]]
+      case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
+      case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
+      case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
       case other =>
         // There is probably a better way to do this, but I couldn't find it...
         val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
@@ -161,68 +177,68 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
 
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newTypePath = walkedTypePath.recordOption(className)
         WrapOption(deserializerFor(optType, path, newTypePath), 
dataTypeFor(optType))
 
-      case t if t <:< localTypeOf[java.lang.Integer] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Integer])
 
-      case t if t <:< localTypeOf[java.lang.Long] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Long])
 
-      case t if t <:< localTypeOf[java.lang.Double] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Double])
 
-      case t if t <:< localTypeOf[java.lang.Float] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Float])
 
-      case t if t <:< localTypeOf[java.lang.Short] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Short])
 
-      case t if t <:< localTypeOf[java.lang.Byte] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Byte])
 
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
         createDeserializerForTypesSupportValueOf(path,
           classOf[java.lang.Boolean])
 
-      case t if t <:< localTypeOf[java.time.LocalDate] =>
+      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
         createDeserializerForLocalDate(path)
 
-      case t if t <:< localTypeOf[java.sql.Date] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
         createDeserializerForSqlDate(path)
 
-      case t if t <:< localTypeOf[java.time.Instant] =>
+      case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
         createDeserializerForInstant(path)
 
-      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)
 
-      case t if t <:< localTypeOf[java.lang.String] =>
+      case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
         createDeserializerForJavaBigDecimal(path, returnNullable = false)
 
-      case t if t <:< localTypeOf[BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createDeserializerForScalaBigDecimal(path, returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         createDeserializerForJavaBigInteger(path, returnNullable = false)
 
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         createDeserializerForScalaBigInt(path)
 
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
@@ -242,13 +258,13 @@ object ScalaReflection extends ScalaReflection {
         val arrayCls = arrayClassFor(elementType)
 
         val methodName = elementType match {
-          case t if t <:< definitions.IntTpe => "toIntArray"
-          case t if t <:< definitions.LongTpe => "toLongArray"
-          case t if t <:< definitions.DoubleTpe => "toDoubleArray"
-          case t if t <:< definitions.FloatTpe => "toFloatArray"
-          case t if t <:< definitions.ShortTpe => "toShortArray"
-          case t if t <:< definitions.ByteTpe => "toByteArray"
-          case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+          case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
+          case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
+          case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
+          case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
+          case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
+          case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
+          case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
           // non-primitive
           case _ => "array"
         }
@@ -256,8 +272,8 @@ object ScalaReflection extends ScalaReflection {
 
       // We serialize a `Set` to Catalyst array. When we deserialize a 
Catalyst array
       // to a `Set`, if there are duplicated elements, the elements will be 
de-duplicated.
-      case t if t <:< localTypeOf[Seq[_]] ||
-          t <:< localTypeOf[scala.collection.Set[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) ||
+          isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
@@ -274,14 +290,14 @@ object ScalaReflection extends ScalaReflection {
 
         val companion = t.dealias.typeSymbol.companion.typeSignature
         val cls = companion.member(TermName("newBuilder")) match {
-          case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
-          case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
+          case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
+          case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]]) 
=>
             classOf[scala.collection.Set[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
         UnresolvedMapObjects(mapFunction, path, Some(cls))
 
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
         val classNameForKey = getClassNameFromType(keyType)
@@ -411,7 +427,7 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
 
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newPath = walkedTypePath.recordOption(className)
@@ -421,15 +437,15 @@ object ScalaReflection extends ScalaReflection {
       // Since List[_] also belongs to localTypeOf[Product], we put this case 
before
       // "case t if definedByConstructorParams(t)" to make sure it will match 
to the
       // case "localTypeOf[Seq[_]]"
-      case t if t <:< localTypeOf[Seq[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
         val keyClsName = getClassNameFromType(keyType)
         val valueClsName = getClassNameFromType(valueType)
@@ -448,7 +464,7 @@ object ScalaReflection extends ScalaReflection {
             serializerFor(_, valueType, valuePath, seenTypeSet))
         )
 
-      case t if t <:< localTypeOf[scala.collection.Set[_]] =>
+      case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
 
         // There's no corresponding Catalyst type for `Set`, we serialize a 
`Set` to Catalyst array.
@@ -461,35 +477,41 @@ object ScalaReflection extends ScalaReflection {
 
         toCatalystArray(newInput, elementType)
 
-      case t if t <:< localTypeOf[String] => 
createSerializerForString(inputObject)
+      case t if isSubtype(t, localTypeOf[String]) => 
createSerializerForString(inputObject)
 
-      case t if t <:< localTypeOf[java.time.Instant] => 
createSerializerForJavaInstant(inputObject)
+      case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
+        createSerializerForJavaInstant(inputObject)
 
-      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createSerializerForSqlTimestamp(inputObject)
 
-      case t if t <:< localTypeOf[java.time.LocalDate] =>
+      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
         createSerializerForJavaLocalDate(inputObject)
 
-      case t if t <:< localTypeOf[java.sql.Date] => 
createSerializerForSqlDate(inputObject)
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) => 
createSerializerForSqlDate(inputObject)
 
-      case t if t <:< localTypeOf[BigDecimal] => 
createSerializerForScalaBigDecimal(inputObject)
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
+        createSerializerForScalaBigDecimal(inputObject)
 
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
         createSerializerForJavaBigDecimal(inputObject)
 
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         createSerializerForJavaBigInteger(inputObject)
 
-      case t if t <:< localTypeOf[scala.math.BigInt] => 
createSerializerForScalaBigInt(inputObject)
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
+        createSerializerForScalaBigInt(inputObject)
 
-      case t if t <:< localTypeOf[java.lang.Integer] => 
createSerializerForInteger(inputObject)
-      case t if t <:< localTypeOf[java.lang.Long] => 
createSerializerForLong(inputObject)
-      case t if t <:< localTypeOf[java.lang.Double] => 
createSerializerForDouble(inputObject)
-      case t if t <:< localTypeOf[java.lang.Float] => 
createSerializerForFloat(inputObject)
-      case t if t <:< localTypeOf[java.lang.Short] => 
createSerializerForShort(inputObject)
-      case t if t <:< localTypeOf[java.lang.Byte] => 
createSerializerForByte(inputObject)
-      case t if t <:< localTypeOf[java.lang.Boolean] => 
createSerializerForBoolean(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
+        createSerializerForInteger(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) => 
createSerializerForLong(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
+        createSerializerForDouble(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) => 
createSerializerForFloat(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) => 
createSerializerForShort(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) => 
createSerializerForByte(inputObject)
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
+        createSerializerForBoolean(inputObject)
 
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= 
typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t)
@@ -540,7 +562,7 @@ object ScalaReflection extends ScalaReflection {
    */
   def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
     tpe.dealias match {
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         definedByConstructorParams(optType)
       case _ => false
@@ -606,7 +628,7 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       // this must be the first case, since all objects in scala are instances 
of Null, therefore
       // Null type would wrongly match the first of them, which is Option as 
of now
-      case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true)
+      case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable 
= true)
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= 
typeOf[SQLUserDefinedType]) =>
         val udt = 
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
           getConstructor().newInstance()
@@ -615,54 +637,58 @@ object ScalaReflection extends ScalaReflection {
         val udt = 
UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
           newInstance().asInstanceOf[UserDefinedType[_]]
         Schema(udt, nullable = true)
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType).dataType, nullable = true)
-      case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable 
= true)
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[Seq[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
         val Schema(valueDataType, valueNullable) = schemaFor(valueType)
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
-      case t if t <:< localTypeOf[Set[_]] =>
+      case t if isSubtype(t, localTypeOf[Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[String] => Schema(StringType, nullable = 
true)
-      case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, 
nullable = true)
-      case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, 
nullable = true)
-      case t if t <:< localTypeOf[java.time.LocalDate] => Schema(DateType, 
nullable = true)
-      case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable 
= true)
-      case t if t <:< localTypeOf[BigDecimal] => 
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
+        Schema(TimestampType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
+        Schema(TimestampType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => 
Schema(DateType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
+        Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         Schema(DecimalType.BigIntDecimal, nullable = true)
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         Schema(DecimalType.BigIntDecimal, nullable = true)
-      case t if t <:< localTypeOf[Decimal] => 
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable 
= true)
-      case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable 
= true)
-      case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, 
nullable = true)
-      case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = 
false)
-      case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
-      case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = 
false)
-      case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = 
false)
-      case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = 
false)
-      case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
-      case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = 
false)
+      case t if isSubtype(t, localTypeOf[Decimal]) =>
+        Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) => 
Schema(IntegerType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) => 
Schema(DoubleType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) => 
Schema(FloatType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) => 
Schema(ShortType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => 
Schema(BooleanType, nullable = true)
+      case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, 
nullable = false)
+      case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable 
= false)
+      case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, 
nullable = false)
+      case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, 
nullable = false)
+      case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, 
nullable = false)
+      case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable 
= false)
+      case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, 
nullable = false)
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
         Schema(StructType(
@@ -715,9 +741,9 @@ object ScalaReflection extends ScalaReflection {
   def definedByConstructorParams(tpe: Type): Boolean = 
cleanUpReflectionObjects {
     tpe.dealias match {
       // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a 
struct type.
-      case t if t <:< localTypeOf[Option[_]] => 
definedByConstructorParams(t.typeArgs.head)
-      case _ => tpe.dealias <:< localTypeOf[Product] ||
-        tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
+      case t if isSubtype(t, localTypeOf[Option[_]]) => 
definedByConstructorParams(t.typeArgs.head)
+      case _ => isSubtype(tpe.dealias, localTypeOf[Product]) ||
+        isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
     }
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 80824cc..e8df031 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -145,6 +145,12 @@ class ScalaReflectionSuite extends SparkFunSuite {
   private def deserializerFor[T: TypeTag]: Expression =
     deserializerForType(ScalaReflection.localTypeOf[T])
 
+  test("isSubtype") {
+    assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[_]]))
+    assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[Int]]))
+    assert(!isSubtype(localTypeOf[Option[_]], localTypeOf[Option[Int]]))
+  }
+
   test("SQLUserDefinedType annotation on Scala structure") {
     val schema = schemaFor[TestingUDT.NestedStruct]
     assert(schema === Schema(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to