Repository: spark
Updated Branches:
  refs/heads/branch-2.2 a1c1199e1 -> 86609a95a


[SPARK-21567][SQL] Dataset should work with type alias

If we create a type alias for a type that works with Dataset, the type alias 
itself doesn't work with Dataset.

A reproducible case looks like:

    object C {
      type TwoInt = (Int, Int)
      def tupleTypeAlias: TwoInt = (1, 1)
    }

    Seq(1).toDS().map(_ => ("", C.tupleTypeAlias))

It throws an exception like:

    type T1 is not a class
    scala.ScalaReflectionException: type T1 is not a class
      at scala.reflect.api.Symbols$SymbolApi$class.asClass(Symbols.scala:275)
      ...

This patch dealiases the type in many places in `ScalaReflection` to fix it.
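
For illustration, a minimal sketch (not part of this patch) of what `dealias` 
does in scala.reflect; the `DealiasDemo` object and the expected output described 
in the comments are assumptions for this example, while `Type.dealias` itself is 
the standard scala.reflect call the patch uses:

    import scala.reflect.runtime.universe._

    object DealiasDemo {
      type TwoInt = (Int, Int)

      def main(args: Array[String]): Unit = {
        val aliased = typeOf[TwoInt]
        // Prints the type as written, i.e. the TwoInt alias.
        println(aliased)
        // dealias expands the top-level alias to its underlying type, so
        // calls like typeSymbol.asClass resolve to the Tuple2 class instead
        // of failing on the alias symbol.
        println(aliased.dealias)
      }
    }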

Added test case.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #18813 from viirya/SPARK-21567.

(cherry picked from commit ee1304199bcd9c1d5fc94f5b06fdd5f6fe7336a1)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86609a95
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86609a95
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86609a95

Branch: refs/heads/branch-2.2
Commit: 86609a95af4b700e83638b7416c7e3706c2d64c6
Parents: a1c1199
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue Aug 8 16:12:41 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Aug 8 16:17:05 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 27 ++++++++++----------
 .../org/apache/spark/sql/DatasetSuite.scala     | 24 +++++++++++++++++
 2 files changed, 38 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86609a95/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c887634..fe994b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
   def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
 
   private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
       case t if t <:< definitions.IntTpe => IntegerType
       case t if t <:< definitions.LongTpe => LongType
       case t if t <:< definitions.DoubleTpe => DoubleType
@@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection {
    * JVM form instead of the Scala Array that handles auto boxing.
    */
   private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
-    val cls = tpe match {
+    val cls = tpe.dealias match {
       case t if t <:< definitions.IntTpe => classOf[Array[Int]]
       case t if t <:< definitions.LongTpe => classOf[Array[Long]]
       case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
@@ -192,7 +192,7 @@ object ScalaReflection extends ScalaReflection {
       case _ => UpCast(expr, expected, walkedTypePath)
     }
 
-    tpe match {
+    tpe.dealias match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
       case t if t <:< localTypeOf[Option[_]] =>
@@ -479,7 +479,7 @@ object ScalaReflection extends ScalaReflection {
       }
     }
 
-    tpe match {
+    tpe.dealias match {
       case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
 
       case t if t <:< localTypeOf[Option[_]] =>
@@ -633,7 +633,7 @@ object ScalaReflection extends ScalaReflection {
    * we also treat [[DefinedByConstructorParams]] as product type.
    */
   def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
         definedByConstructorParams(optType)
@@ -680,7 +680,7 @@ object ScalaReflection extends ScalaReflection {
   /*
    * Retrieves the runtime class corresponding to the provided type.
    */
-  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass)
+  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
@@ -695,7 +695,7 @@ object ScalaReflection extends ScalaReflection {
 
   /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
   def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
       case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
         val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
         Schema(udt, nullable = true)
@@ -761,7 +761,7 @@ object ScalaReflection extends ScalaReflection {
    * Whether the fields of the given type is defined entirely by its constructor parameters.
    */
   def definedByConstructorParams(tpe: Type): Boolean = {
-    tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
+    tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
   }
 
   private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
@@ -816,7 +816,7 @@ trait ScalaReflection {
    * synthetic classes, emulating behaviour in Java bytecode.
    */
   def getClassNameFromType(tpe: `Type`): String = {
-    tpe.erasure.typeSymbol.asClass.fullName
+    tpe.dealias.erasure.typeSymbol.asClass.fullName
   }
 
   /**
@@ -835,9 +835,10 @@ trait ScalaReflection {
    * support inner class.
    */
   def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
-    val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
-    val TypeRef(_, _, actualTypeArgs) = tpe
-    val params = constructParams(tpe)
+    val dealiasedTpe = tpe.dealias
+    val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
+    val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
+    val params = constructParams(dealiasedTpe)
     // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
     if (actualTypeArgs.nonEmpty) {
       params.map { p =>
@@ -851,7 +852,7 @@ trait ScalaReflection {
   }
 
   protected def constructParams(tpe: Type): Seq[Symbol] = {
-    val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
+    val constructorSymbol = tpe.dealias.member(nme.CONSTRUCTOR)
     val params = if (constructorSymbol.isMethod) {
       constructorSymbol.asMethod.paramss
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/86609a95/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 060fbe7..a2deb85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,16 @@ import org.apache.spark.sql.types._
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
 
+object TestForTypeAlias {
+  type TwoInt = (Int, Int)
+  type ThreeInt = (TwoInt, Int)
+  type SeqOfTwoInt = Seq[TwoInt]
+
+  def tupleTypeAlias: TwoInt = (1, 1)
+  def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2)
+  def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2))
+}
+
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -1210,6 +1220,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.orderBy($"id"), expected)
     checkAnswer(df.orderBy('id), expected)
   }
+
+  test("SPARK-21567: Dataset should work with type alias") {
+    checkDataset(
+      Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)),
+      ("", (1, 1)))
+
+    checkDataset(
+      Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)),
+      ("", ((1, 1), 2)))
+
+    checkDataset(
+      Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)),
+      ("", Seq((1, 1), (2, 2))))
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])

