This is an automated email from the ASF dual-hosted git repository.

joshrosen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new ba7f61e  [SPARK-26555][SQL][BRANCH-2.4] make ScalaReflection subtype checking thread safe
ba7f61e is described below

commit ba7f61e25d58aa379f94a23b03503a25574529bc
Author: mwlon <mlonca...@hmc.edu>
AuthorDate: Wed Jun 19 19:03:35 2019 -0700

    [SPARK-26555][SQL][BRANCH-2.4] make ScalaReflection subtype checking thread safe
    
    This is a Spark 2.4.x backport of #24085. The original description follows below:
    
    ## What changes were proposed in this pull request?
    
    Make ScalaReflection subtype checking thread safe by guarding it with a lock. The `<:<` (subtype) operator is not thread safe in any current version of Scala, i.e. 2.11.12, 2.12.8 and 2.13.0-M5 (https://github.com/scala/bug/issues/10766).
    
    ## How was this patch tested?
    
    Existing tests, plus a new test for the new `isSubtype` helper.
    
    Closes #24913 from JoshRosen/joshrosen/SPARK-26555-branch-2.4-backport.
    
    Authored-by: mwlon <mlonca...@hmc.edu>
    Signed-off-by: Josh Rosen <rosenvi...@gmail.com>
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 216 +++++++++++----------
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  |   6 +
 2 files changed, 124 insertions(+), 98 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c27180e..1b186bf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -40,6 +40,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, 
UTF8String}
 trait DefinedByConstructorParams
 
 
+private[catalyst] object ScalaSubtypeLock
+
+
 /**
  * A default version of ScalaReflection that uses the runtime universe.
  */
@@ -68,19 +71,32 @@ object ScalaReflection extends ScalaReflection {
    */
   def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
 
+  /**
+   * Synchronize to prevent concurrent usage of `<:<` operator.
+   * This operator is not thread safe in any current version of scala; i.e.
+   * (2.11.12, 2.12.8, 2.13.0-M5).
+   *
+   * See https://github.com/scala/bug/issues/10766
+   */
+  private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+    ScalaSubtypeLock.synchronized {
+      tpe1 <:< tpe2
+    }
+  }
+
   private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
     tpe.dealias match {
-      case t if t <:< definitions.NullTpe => NullType
-      case t if t <:< definitions.IntTpe => IntegerType
-      case t if t <:< definitions.LongTpe => LongType
-      case t if t <:< definitions.DoubleTpe => DoubleType
-      case t if t <:< definitions.FloatTpe => FloatType
-      case t if t <:< definitions.ShortTpe => ShortType
-      case t if t <:< definitions.ByteTpe => ByteType
-      case t if t <:< definitions.BooleanTpe => BooleanType
-      case t if t <:< localTypeOf[Array[Byte]] => BinaryType
-      case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
-      case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
+      case t if isSubtype(t, definitions.NullTpe) => NullType
+      case t if isSubtype(t, definitions.IntTpe) => IntegerType
+      case t if isSubtype(t, definitions.LongTpe) => LongType
+      case t if isSubtype(t, definitions.DoubleTpe) => DoubleType
+      case t if isSubtype(t, definitions.FloatTpe) => FloatType
+      case t if isSubtype(t, definitions.ShortTpe) => ShortType
+      case t if isSubtype(t, definitions.ByteTpe) => ByteType
+      case t if isSubtype(t, definitions.BooleanTpe) => BooleanType
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType
+      case t if isSubtype(t, localTypeOf[CalendarInterval]) => 
CalendarIntervalType
+      case t if isSubtype(t, localTypeOf[Decimal]) => 
DecimalType.SYSTEM_DEFAULT
       case _ =>
         val className = getClassNameFromType(tpe)
         className match {
@@ -103,13 +119,13 @@ object ScalaReflection extends ScalaReflection {
    */
   private def arrayClassFor(tpe: `Type`): ObjectType = 
cleanUpReflectionObjects {
     val cls = tpe.dealias match {
-      case t if t <:< definitions.IntTpe => classOf[Array[Int]]
-      case t if t <:< definitions.LongTpe => classOf[Array[Long]]
-      case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
-      case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
-      case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
-      case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
-      case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+      case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]]
+      case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]]
+      case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]]
+      case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]]
+      case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
+      case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
+      case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
       case other =>
         // There is probably a better way to do this, but I couldn't find it...
         val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
@@ -210,48 +226,48 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newTypePath = s"""- option value class: "$className"""" +: 
walkedTypePath
         WrapOption(deserializerFor(optType, path, newTypePath), 
dataTypeFor(optType))
 
-      case t if t <:< localTypeOf[java.lang.Integer] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
         val boxedType = classOf[java.lang.Integer]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Long] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
         val boxedType = classOf[java.lang.Long]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Double] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
         val boxedType = classOf[java.lang.Double]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Float] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
         val boxedType = classOf[java.lang.Float]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Short] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
         val boxedType = classOf[java.lang.Short]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Byte] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
         val boxedType = classOf[java.lang.Byte]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
         val boxedType = classOf[java.lang.Boolean]
         val objectType = ObjectType(boxedType)
         StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.sql.Date] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
         StaticInvoke(
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Date]),
@@ -259,7 +275,7 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         StaticInvoke(
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Timestamp]),
@@ -267,25 +283,25 @@ object ScalaReflection extends ScalaReflection {
           getPath :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.String] =>
+      case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         Invoke(getPath, "toString", ObjectType(classOf[String]), 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
         Invoke(getPath, "toJavaBigDecimal", 
ObjectType(classOf[java.math.BigDecimal]),
           returnNullable = false)
 
-      case t if t <:< localTypeOf[BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), 
returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         Invoke(getPath, "toJavaBigInteger", 
ObjectType(classOf[java.math.BigInteger]),
           returnNullable = false)
 
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         Invoke(getPath, "toScalaBigInt", 
ObjectType(classOf[scala.math.BigInt]),
           returnNullable = false)
 
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
@@ -309,13 +325,13 @@ object ScalaReflection extends ScalaReflection {
           Invoke(arrayData, "array", arrayCls, returnNullable = false)
         } else {
           val primitiveMethod = elementType match {
-            case t if t <:< definitions.IntTpe => "toIntArray"
-            case t if t <:< definitions.LongTpe => "toLongArray"
-            case t if t <:< definitions.DoubleTpe => "toDoubleArray"
-            case t if t <:< definitions.FloatTpe => "toFloatArray"
-            case t if t <:< definitions.ShortTpe => "toShortArray"
-            case t if t <:< definitions.ByteTpe => "toByteArray"
-            case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+            case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
+            case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
+            case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
+            case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
+            case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
+            case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
+            case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
             case other => throw new IllegalStateException("expect primitive 
array element type " +
               "but got " + other)
           }
@@ -324,8 +340,8 @@ object ScalaReflection extends ScalaReflection {
 
       // We serialize a `Set` to Catalyst array. When we deserialize a 
Catalyst array
       // to a `Set`, if there are duplicated elements, the elements will be 
de-duplicated.
-      case t if t <:< localTypeOf[Seq[_]] ||
-          t <:< localTypeOf[scala.collection.Set[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) ||
+          isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, elementNullable) = schemaFor(elementType)
         val className = getClassNameFromType(elementType)
@@ -344,14 +360,14 @@ object ScalaReflection extends ScalaReflection {
 
         val companion = t.dealias.typeSymbol.companion.typeSignature
         val cls = companion.member(TermName("newBuilder")) match {
-          case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
-          case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
+          case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
+          case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]]) 
=>
             classOf[scala.collection.Set[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
         UnresolvedMapObjects(mapFunction, getPath, Some(cls))
 
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
@@ -486,7 +502,7 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
 
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newPath = s"""- option value class: "$className"""" +: 
walkedTypePath
@@ -496,15 +512,15 @@ object ScalaReflection extends ScalaReflection {
       // Since List[_] also belongs to localTypeOf[Product], we put this case 
before
       // "case t if definedByConstructorParams(t)" to make sure it will match 
to the
       // case "localTypeOf[Seq[_]]"
-      case t if t <:< localTypeOf[Seq[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         toCatalystArray(inputObject, elementType)
 
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
         val keyClsName = getClassNameFromType(keyType)
         val valueClsName = getClassNameFromType(valueType)
@@ -520,7 +536,7 @@ object ScalaReflection extends ScalaReflection {
           serializerFor(_, valueType, valuePath, seenTypeSet),
           valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
 
-      case t if t <:< localTypeOf[scala.collection.Set[_]] =>
+      case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
 
         // There's no corresponding Catalyst type for `Set`, we serialize a 
`Set` to Catalyst array.
@@ -533,7 +549,7 @@ object ScalaReflection extends ScalaReflection {
 
         toCatalystArray(newInput, elementType)
 
-      case t if t <:< localTypeOf[String] =>
+      case t if isSubtype(t, localTypeOf[String]) =>
         StaticInvoke(
           classOf[UTF8String],
           StringType,
@@ -541,7 +557,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         StaticInvoke(
           DateTimeUtils.getClass,
           TimestampType,
@@ -549,7 +565,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.sql.Date] =>
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
         StaticInvoke(
           DateTimeUtils.getClass,
           DateType,
@@ -557,7 +573,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.SYSTEM_DEFAULT,
@@ -565,7 +581,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.SYSTEM_DEFAULT,
@@ -573,7 +589,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.BigIntDecimal,
@@ -581,7 +597,7 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         StaticInvoke(
           Decimal.getClass,
           DecimalType.BigIntDecimal,
@@ -589,19 +605,19 @@ object ScalaReflection extends ScalaReflection {
           inputObject :: Nil,
           returnNullable = false)
 
-      case t if t <:< localTypeOf[java.lang.Integer] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
         Invoke(inputObject, "intValue", IntegerType)
-      case t if t <:< localTypeOf[java.lang.Long] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
         Invoke(inputObject, "longValue", LongType)
-      case t if t <:< localTypeOf[java.lang.Double] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
         Invoke(inputObject, "doubleValue", DoubleType)
-      case t if t <:< localTypeOf[java.lang.Float] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
         Invoke(inputObject, "floatValue", FloatType)
-      case t if t <:< localTypeOf[java.lang.Short] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
         Invoke(inputObject, "shortValue", ShortType)
-      case t if t <:< localTypeOf[java.lang.Byte] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
         Invoke(inputObject, "byteValue", ByteType)
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
         Invoke(inputObject, "booleanValue", BooleanType)
 
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= 
typeOf[SQLUserDefinedType]) =>
@@ -658,7 +674,7 @@ object ScalaReflection extends ScalaReflection {
    */
   def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
     tpe.dealias match {
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         definedByConstructorParams(optType)
       case _ => false
@@ -724,7 +740,7 @@ object ScalaReflection extends ScalaReflection {
     tpe.dealias match {
       // this must be the first case, since all objects in scala are instances 
of Null, therefore
       // Null type would wrongly match the first of them, which is Option as 
of now
-      case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true)
+      case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable 
= true)
       case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= 
typeOf[SQLUserDefinedType]) =>
         val udt = 
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
         Schema(udt, nullable = true)
@@ -732,52 +748,56 @@ object ScalaReflection extends ScalaReflection {
         val udt = 
UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
           .asInstanceOf[UserDefinedType[_]]
         Schema(udt, nullable = true)
-      case t if t <:< localTypeOf[Option[_]] =>
+      case t if isSubtype(t, localTypeOf[Option[_]]) =>
         val TypeRef(_, _, Seq(optType)) = t
         Schema(schemaFor(optType).dataType, nullable = true)
-      case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable 
= true)
-      case t if t <:< localTypeOf[Array[_]] =>
+      case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[Array[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[Seq[_]] =>
+      case t if isSubtype(t, localTypeOf[Seq[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[Map[_, _]] =>
+      case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
         val Schema(valueDataType, valueNullable) = schemaFor(valueType)
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
-      case t if t <:< localTypeOf[Set[_]] =>
+      case t if isSubtype(t, localTypeOf[Set[_]]) =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)
         Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-      case t if t <:< localTypeOf[String] => Schema(StringType, nullable = 
true)
-      case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, 
nullable = true)
-      case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable 
= true)
-      case t if t <:< localTypeOf[BigDecimal] => 
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+      case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
+        Schema(TimestampType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
+        Schema(DateType, nullable = true)
+      case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.math.BigInteger] =>
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
+        Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         Schema(DecimalType.BigIntDecimal, nullable = true)
-      case t if t <:< localTypeOf[scala.math.BigInt] =>
+      case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         Schema(DecimalType.BigIntDecimal, nullable = true)
-      case t if t <:< localTypeOf[Decimal] => 
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
-      case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable 
= true)
-      case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, 
nullable = true)
-      case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable 
= true)
-      case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, 
nullable = true)
-      case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = 
false)
-      case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
-      case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = 
false)
-      case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = 
false)
-      case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = 
false)
-      case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
-      case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = 
false)
+      case t if isSubtype(t, localTypeOf[Decimal]) =>
+        Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Integer]) => 
Schema(IntegerType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Double]) => 
Schema(DoubleType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Float]) => 
Schema(FloatType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Short]) => 
Schema(ShortType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, 
nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => 
Schema(BooleanType, nullable = true)
+      case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, 
nullable = false)
+      case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable 
= false)
+      case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, 
nullable = false)
+      case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, 
nullable = false)
+      case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, 
nullable = false)
+      case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable 
= false)
+      case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, 
nullable = false)
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
         Schema(StructType(
@@ -805,9 +825,9 @@ object ScalaReflection extends ScalaReflection {
   def definedByConstructorParams(tpe: Type): Boolean = 
cleanUpReflectionObjects {
     tpe.dealias match {
       // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a 
struct type.
-      case t if t <:< localTypeOf[Option[_]] => 
definedByConstructorParams(t.typeArgs.head)
-      case _ => tpe.dealias <:< localTypeOf[Product] ||
-        tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
+      case t if isSubtype(t, localTypeOf[Option[_]]) => 
definedByConstructorParams(t.typeArgs.head)
+      case _ => isSubtype(tpe.dealias, localTypeOf[Product]) ||
+        isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
     }
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f9ee948..38b8e7b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -112,6 +112,12 @@ object TestingUDT {
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
 
+  test("isSubtype") {
+    assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[_]]))
+    assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[Int]]))
+    assert(!isSubtype(localTypeOf[Option[_]], localTypeOf[Option[Int]]))
+  }
+
   test("SQLUserDefinedType annotation on Scala structure") {
     val schema = schemaFor[TestingUDT.NestedStruct]
     assert(schema === Schema(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to