This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d5c08fc [SPARK-26555][SQL] make ScalaReflection subtype checking
thread safe
d5c08fc is described below
commit d5c08fcaab141e13f95bd8d7a2c900aeccdf718d
Author: mwlon <[email protected]>
AuthorDate: Tue Mar 19 18:22:01 2019 +0800
[SPARK-26555][SQL] make ScalaReflection subtype checking thread safe
## What changes were proposed in this pull request?
Make ScalaReflection subtype checking thread safe by guarding it with a lock.
The `<:<` operator has a thread-safety bug in all current versions of Scala
(https://github.com/scala/bug/issues/10766).
## How was this patch tested?
Existing tests, plus a new test for the new `isSubtype` helper.
Closes #24085 from mwlon/SPARK-26555.
Authored-by: mwlon <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/ScalaReflection.scala | 234 ++++++++++++---------
.../spark/sql/catalyst/ScalaReflectionSuite.scala | 6 +
2 files changed, 136 insertions(+), 104 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d8d268a..fa8993e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -38,6 +38,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval,
UTF8String}
trait DefinedByConstructorParams
+/** Global monitor that `ScalaReflection.isSubtype` synchronizes on around `<:<`. */
+private[catalyst] object ScalaSubtypeLock
+
/**
* A default version of ScalaReflection that uses the runtime universe.
*/
@@ -66,19 +69,32 @@ object ScalaReflection extends ScalaReflection {
*/
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
+ /**
+ * Thread-safe `tpe1 <:< tpe2`: returns whether `tpe1` conforms to `tpe2`.
+ * scala-reflect's `<:<` is not thread safe (e.g. 2.11.12, 2.12.8 and
+ * 2.13.0-M5), so every call made through this helper is serialized on the
+ * global [[ScalaSubtypeLock]]; direct `<:<` calls elsewhere are not guarded.
+ * See https://github.com/scala/bug/issues/10766
+ */
+ private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = {
+ ScalaSubtypeLock.synchronized {
+ tpe1 <:< tpe2
+ }
+ }
+
+
private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects {
tpe.dealias match {
- case t if t <:< definitions.NullTpe => NullType
- case t if t <:< definitions.IntTpe => IntegerType
- case t if t <:< definitions.LongTpe => LongType
- case t if t <:< definitions.DoubleTpe => DoubleType
- case t if t <:< definitions.FloatTpe => FloatType
- case t if t <:< definitions.ShortTpe => ShortType
- case t if t <:< definitions.ByteTpe => ByteType
- case t if t <:< definitions.BooleanTpe => BooleanType
- case t if t <:< localTypeOf[Array[Byte]] => BinaryType
- case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType
- case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT
+ case t if isSubtype(t, definitions.NullTpe) => NullType
+ case t if isSubtype(t, definitions.IntTpe) => IntegerType
+ case t if isSubtype(t, definitions.LongTpe) => LongType
+ case t if isSubtype(t, definitions.DoubleTpe) => DoubleType
+ case t if isSubtype(t, definitions.FloatTpe) => FloatType
+ case t if isSubtype(t, definitions.ShortTpe) => ShortType
+ case t if isSubtype(t, definitions.ByteTpe) => ByteType
+ case t if isSubtype(t, definitions.BooleanTpe) => BooleanType
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType
+ case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
CalendarIntervalType
+ case t if isSubtype(t, localTypeOf[Decimal]) =>
DecimalType.SYSTEM_DEFAULT
case _ =>
val className = getClassNameFromType(tpe)
className match {
@@ -101,13 +117,13 @@ object ScalaReflection extends ScalaReflection {
*/
private def arrayClassFor(tpe: `Type`): ObjectType =
cleanUpReflectionObjects {
val cls = tpe.dealias match {
- case t if t <:< definitions.IntTpe => classOf[Array[Int]]
- case t if t <:< definitions.LongTpe => classOf[Array[Long]]
- case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
- case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
- case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
- case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
- case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+ case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]]
+ case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]]
+ case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]]
+ case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]]
+ case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]]
+ case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]]
+ case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]]
case other =>
// There is probably a better way to do this, but I couldn't find it...
val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
@@ -161,68 +177,68 @@ object ScalaReflection extends ScalaReflection {
tpe.dealias match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
- case t if t <:< localTypeOf[Option[_]] =>
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newTypePath = walkedTypePath.recordOption(className)
WrapOption(deserializerFor(optType, path, newTypePath),
dataTypeFor(optType))
- case t if t <:< localTypeOf[java.lang.Integer] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Integer])
- case t if t <:< localTypeOf[java.lang.Long] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Long])
- case t if t <:< localTypeOf[java.lang.Double] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Double])
- case t if t <:< localTypeOf[java.lang.Float] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Float])
- case t if t <:< localTypeOf[java.lang.Short] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Short])
- case t if t <:< localTypeOf[java.lang.Byte] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Byte])
- case t if t <:< localTypeOf[java.lang.Boolean] =>
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
createDeserializerForTypesSupportValueOf(path,
classOf[java.lang.Boolean])
- case t if t <:< localTypeOf[java.time.LocalDate] =>
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
createDeserializerForLocalDate(path)
- case t if t <:< localTypeOf[java.sql.Date] =>
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
createDeserializerForSqlDate(path)
- case t if t <:< localTypeOf[java.time.Instant] =>
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
createDeserializerForInstant(path)
- case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)
- case t if t <:< localTypeOf[java.lang.String] =>
+ case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
createDeserializerForJavaBigDecimal(path, returnNullable = false)
- case t if t <:< localTypeOf[BigDecimal] =>
+ case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createDeserializerForScalaBigDecimal(path, returnNullable = false)
- case t if t <:< localTypeOf[java.math.BigInteger] =>
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
createDeserializerForJavaBigInteger(path, returnNullable = false)
- case t if t <:< localTypeOf[scala.math.BigInt] =>
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
createDeserializerForScalaBigInt(path)
- case t if t <:< localTypeOf[Array[_]] =>
+ case t if isSubtype(t, localTypeOf[Array[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
@@ -242,13 +258,13 @@ object ScalaReflection extends ScalaReflection {
val arrayCls = arrayClassFor(elementType)
val methodName = elementType match {
- case t if t <:< definitions.IntTpe => "toIntArray"
- case t if t <:< definitions.LongTpe => "toLongArray"
- case t if t <:< definitions.DoubleTpe => "toDoubleArray"
- case t if t <:< definitions.FloatTpe => "toFloatArray"
- case t if t <:< definitions.ShortTpe => "toShortArray"
- case t if t <:< definitions.ByteTpe => "toByteArray"
- case t if t <:< definitions.BooleanTpe => "toBooleanArray"
+ case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
+ case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
+ case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
+ case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
+ case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
+ case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
+ case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
// non-primitive
case _ => "array"
}
@@ -256,8 +272,8 @@ object ScalaReflection extends ScalaReflection {
// We serialize a `Set` to Catalyst array. When we deserialize a
Catalyst array
// to a `Set`, if there are duplicated elements, the elements will be
de-duplicated.
- case t if t <:< localTypeOf[Seq[_]] ||
- t <:< localTypeOf[scala.collection.Set[_]] =>
+ case t if isSubtype(t, localTypeOf[Seq[_]]) ||
+ isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
@@ -274,14 +290,14 @@ object ScalaReflection extends ScalaReflection {
val companion = t.dealias.typeSymbol.companion.typeSignature
val cls = companion.member(TermName("newBuilder")) match {
- case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
- case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
+ case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
+ case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]])
=>
classOf[scala.collection.Set[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
UnresolvedMapObjects(mapFunction, path, Some(cls))
- case t if t <:< localTypeOf[Map[_, _]] =>
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val classNameForKey = getClassNameFromType(keyType)
@@ -411,7 +427,7 @@ object ScalaReflection extends ScalaReflection {
tpe.dealias match {
case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
- case t if t <:< localTypeOf[Option[_]] =>
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newPath = walkedTypePath.recordOption(className)
@@ -421,15 +437,15 @@ object ScalaReflection extends ScalaReflection {
// Since List[_] also belongs to localTypeOf[Product], we put this case
before
// "case t if definedByConstructorParams(t)" to make sure it will match
to the
// case "localTypeOf[Seq[_]]"
- case t if t <:< localTypeOf[Seq[_]] =>
+ case t if isSubtype(t, localTypeOf[Seq[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
- case t if t <:< localTypeOf[Array[_]] =>
+ case t if isSubtype(t, localTypeOf[Array[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)
- case t if t <:< localTypeOf[Map[_, _]] =>
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyClsName = getClassNameFromType(keyType)
val valueClsName = getClassNameFromType(valueType)
@@ -448,7 +464,7 @@ object ScalaReflection extends ScalaReflection {
serializerFor(_, valueType, valuePath, seenTypeSet))
)
- case t if t <:< localTypeOf[scala.collection.Set[_]] =>
+ case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
// There's no corresponding Catalyst type for `Set`, we serialize a
`Set` to Catalyst array.
@@ -461,35 +477,41 @@ object ScalaReflection extends ScalaReflection {
toCatalystArray(newInput, elementType)
- case t if t <:< localTypeOf[String] =>
createSerializerForString(inputObject)
+ case t if isSubtype(t, localTypeOf[String]) =>
createSerializerForString(inputObject)
- case t if t <:< localTypeOf[java.time.Instant] =>
createSerializerForJavaInstant(inputObject)
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
+ createSerializerForJavaInstant(inputObject)
- case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createSerializerForSqlTimestamp(inputObject)
- case t if t <:< localTypeOf[java.time.LocalDate] =>
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
createSerializerForJavaLocalDate(inputObject)
- case t if t <:< localTypeOf[java.sql.Date] =>
createSerializerForSqlDate(inputObject)
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
createSerializerForSqlDate(inputObject)
- case t if t <:< localTypeOf[BigDecimal] =>
createSerializerForScalaBigDecimal(inputObject)
+ case t if isSubtype(t, localTypeOf[BigDecimal]) =>
+ createSerializerForScalaBigDecimal(inputObject)
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
createSerializerForJavaBigDecimal(inputObject)
- case t if t <:< localTypeOf[java.math.BigInteger] =>
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
createSerializerForJavaBigInteger(inputObject)
- case t if t <:< localTypeOf[scala.math.BigInt] =>
createSerializerForScalaBigInt(inputObject)
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
+ createSerializerForScalaBigInt(inputObject)
- case t if t <:< localTypeOf[java.lang.Integer] =>
createSerializerForInteger(inputObject)
- case t if t <:< localTypeOf[java.lang.Long] =>
createSerializerForLong(inputObject)
- case t if t <:< localTypeOf[java.lang.Double] =>
createSerializerForDouble(inputObject)
- case t if t <:< localTypeOf[java.lang.Float] =>
createSerializerForFloat(inputObject)
- case t if t <:< localTypeOf[java.lang.Short] =>
createSerializerForShort(inputObject)
- case t if t <:< localTypeOf[java.lang.Byte] =>
createSerializerForByte(inputObject)
- case t if t <:< localTypeOf[java.lang.Boolean] =>
createSerializerForBoolean(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
+ createSerializerForInteger(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
createSerializerForLong(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
+ createSerializerForDouble(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
createSerializerForFloat(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
createSerializerForShort(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
createSerializerForByte(inputObject)
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
+ createSerializerForBoolean(inputObject)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:=
typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t)
@@ -540,7 +562,7 @@ object ScalaReflection extends ScalaReflection {
*/
def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects {
tpe.dealias match {
- case t if t <:< localTypeOf[Option[_]] =>
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
val TypeRef(_, _, Seq(optType)) = t
definedByConstructorParams(optType)
case _ => false
@@ -606,7 +628,7 @@ object ScalaReflection extends ScalaReflection {
tpe.dealias match {
// this must be the first case, since all objects in scala are instances
of Null, therefore
// Null type would wrongly match the first of them, which is Option as
of now
- case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true)
+ case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable
= true)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:=
typeOf[SQLUserDefinedType]) =>
val udt =
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
getConstructor().newInstance()
@@ -615,54 +637,58 @@ object ScalaReflection extends ScalaReflection {
val udt =
UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
newInstance().asInstanceOf[UserDefinedType[_]]
Schema(udt, nullable = true)
- case t if t <:< localTypeOf[Option[_]] =>
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
- case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable
= true)
- case t if t <:< localTypeOf[Array[_]] =>
+ case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType,
nullable = true)
+ case t if isSubtype(t, localTypeOf[Array[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< localTypeOf[Seq[_]] =>
+ case t if isSubtype(t, localTypeOf[Seq[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< localTypeOf[Map[_, _]] =>
+ case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if t <:< localTypeOf[Set[_]] =>
+ case t if isSubtype(t, localTypeOf[Set[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< localTypeOf[String] => Schema(StringType, nullable =
true)
- case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType,
nullable = true)
- case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType,
nullable = true)
- case t if t <:< localTypeOf[java.time.LocalDate] => Schema(DateType,
nullable = true)
- case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable
= true)
- case t if t <:< localTypeOf[BigDecimal] =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ case t if isSubtype(t, localTypeOf[String]) => Schema(StringType,
nullable = true)
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
+ Schema(TimestampType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
+ Schema(TimestampType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
Schema(DateType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType,
nullable = true)
+ case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if t <:< localTypeOf[java.math.BigInteger] =>
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
Schema(DecimalType.BigIntDecimal, nullable = true)
- case t if t <:< localTypeOf[scala.math.BigInt] =>
+ case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
Schema(DecimalType.BigIntDecimal, nullable = true)
- case t if t <:< localTypeOf[Decimal] =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
- case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType,
nullable = true)
- case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable
= true)
- case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType,
nullable = true)
- case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType,
nullable = true)
- case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType,
nullable = true)
- case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable
= true)
- case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType,
nullable = true)
- case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable =
false)
- case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
- case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable =
false)
- case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable =
false)
- case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable =
false)
- case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
- case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable =
false)
+ case t if isSubtype(t, localTypeOf[Decimal]) =>
+ Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
Schema(IntegerType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType,
nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
Schema(DoubleType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
Schema(FloatType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
Schema(ShortType, nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType,
nullable = true)
+ case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
Schema(BooleanType, nullable = true)
+ case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType,
nullable = false)
+ case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable
= false)
+ case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType,
nullable = false)
+ case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType,
nullable = false)
+ case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType,
nullable = false)
+ case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable
= false)
+ case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType,
nullable = false)
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Schema(StructType(
@@ -715,9 +741,9 @@ object ScalaReflection extends ScalaReflection {
def definedByConstructorParams(tpe: Type): Boolean =
cleanUpReflectionObjects {
tpe.dealias match {
// `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a
struct type.
- case t if t <:< localTypeOf[Option[_]] =>
definedByConstructorParams(t.typeArgs.head)
- case _ => tpe.dealias <:< localTypeOf[Product] ||
- tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
definedByConstructorParams(t.typeArgs.head)
+ case _ => isSubtype(tpe.dealias, localTypeOf[Product]) ||
+ isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 80824cc..e8df031 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -145,6 +145,12 @@ class ScalaReflectionSuite extends SparkFunSuite {
private def deserializerFor[T: TypeTag]: Expression =
deserializerForType(ScalaReflection.localTypeOf[T])
+ test("isSubtype") {
+ assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[_]])) // conforms to wildcard
+ assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[Int]])) // reflexive
+ assert(!isSubtype(localTypeOf[Option[_]], localTypeOf[Option[Int]])) // converse fails
+ }
+
test("SQLUserDefinedType annotation on Scala structure") {
val schema = schemaFor[TestingUDT.NestedStruct]
assert(schema === Schema(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]