[ https://issues.apache.org/jira/browse/SPARK-12878?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16513647#comment-16513647 ]

Liang-Chi Hsieh commented on SPARK-12878:
-----------------------------------------


Is this a real issue? It seems to me that you can't write a nested UDT the way the
example code in the description does.

The nested UDT example should look like the following; you need to serialize the
nested UDT objects when you serialize the wrapper object:

{code:scala}
// Imports needed to compile this example.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

@SQLUserDefinedType(udt = classOf[WrapperUDT])
case class Wrapper(list: Seq[Element])

class WrapperUDT extends UserDefinedType[Wrapper] {
  override def sqlType: DataType = StructType(Seq(StructField("list",
    ArrayType(new ElementUDT(), containsNull = false), nullable = true)))
  override def userClass: Class[Wrapper] = classOf[Wrapper]
  override def serialize(obj: Wrapper): Any = obj match {
    case Wrapper(list) =>
      val row = new GenericInternalRow(1)
      val elementUDT = new ElementUDT()
      // Serialize each nested Element with its own UDT before storing the array.
      val serializedElements = list.map((e: Element) => elementUDT.serialize(e))
      row.update(0, new GenericArrayData(serializedElements.toArray))
      row
  }

  override def deserialize(datum: Any): Wrapper = datum match {
    case row: InternalRow =>
      val elementUDT = new ElementUDT()
      // Read the underlying array and deserialize each entry back into an Element.
      Wrapper(row.getArray(0).toArray[Any](elementUDT).map((e: Any) => elementUDT.deserialize(e)))
  }
}

@SQLUserDefinedType(udt = classOf[ElementUDT])
case class Element(num: Int)

class ElementUDT extends UserDefinedType[Element] {
  override def sqlType: DataType =
    StructType(Seq(StructField("num", IntegerType, nullable = false)))
  override def userClass: Class[Element] = classOf[Element]
  override def serialize(obj: Element): Any = obj match {
    case Element(num) =>
      val row = new GenericInternalRow(1)
      row.setInt(0, num)
      row
  }

  override def deserialize(datum: Any): Element = datum match {
    case row: InternalRow => Element(row.getInt(0))
  }
}

// Assumes a SparkSession's implicits are in scope (e.g. in spark-shell), so toDF is available.
val data = Seq(Wrapper(Seq(Element(1), Element(2))), Wrapper(Seq(Element(3), Element(4))))
val df = sparkContext.parallelize((1 to 2).zip(data)).toDF("id", "b")
df.collect().foreach(println)
{code}

{code}
[1,Wrapper(ArraySeq(Element(1), Element(2)))]
[2,Wrapper(ArraySeq(Element(3), Element(4)))]
{code}
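
The key difference from the code in the description is in the wrapper UDT's serialize (and the matching deserialize): each nested object has to go through the element UDT instead of being stored as a raw case-class instance. A sketch of just that change, reusing the class names and 1.x API (GenericMutableRow) from the description; the failing line is kept as a comment:

{code:scala}
override def serialize(obj: Any): Any = obj match {
  case A(list) =>
    val row = new GenericMutableRow(1)
    // Fails: stores the raw B case-class instances, so the array does not hold InternalRows.
    // row.update(0, new GenericArrayData(list.map(_.asInstanceOf[Any]).toArray))
    // Works: converts each B with BUDT.serialize before building the array.
    row.update(0, new GenericArrayData(list.map(b => BUDT.serialize(b)).toArray))
    row
}

override def deserialize(datum: Any): A = datum match {
  case row: InternalRow =>
    // Read the array back with BUDT as the element type and deserialize each entry.
    A(row.getArray(0).toArray[Any](BUDT).map(e => BUDT.deserialize(e)).toSeq)
}
{code}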

> Dataframe fails with nested User Defined Types
> ----------------------------------------------
>
>                 Key: SPARK-12878
>                 URL: https://issues.apache.org/jira/browse/SPARK-12878
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 1.6.0
>            Reporter: Joao Duarte
>            Priority: Major
>
> Spark 1.6.0 crashes when using nested User Defined Types in a Dataframe. 
> In version 1.5.2 the code below worked just fine:
> import org.apache.spark.{SparkConf, SparkContext}
> import org.apache.spark.sql.catalyst.InternalRow
> import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
> import org.apache.spark.sql.types._
> @SQLUserDefinedType(udt = classOf[AUDT])
> case class A(list:Seq[B])
> class AUDT extends UserDefinedType[A] {
>   override def sqlType: DataType = StructType(Seq(StructField("list", ArrayType(BUDT, containsNull = false), nullable = true)))
>   override def userClass: Class[A] = classOf[A]
>   override def serialize(obj: Any): Any = obj match {
>     case A(list) =>
>       val row = new GenericMutableRow(1)
>       row.update(0, new GenericArrayData(list.map(_.asInstanceOf[Any]).toArray))
>       row
>   }
>   override def deserialize(datum: Any): A = {
>     datum match {
>       case row: InternalRow => new A(row.getArray(0).toArray(BUDT).toSeq)
>     }
>   }
> }
> object AUDT extends AUDT
> @SQLUserDefinedType(udt = classOf[BUDT])
> case class B(text:Int)
> class BUDT extends UserDefinedType[B] {
>   override def sqlType: DataType = StructType(Seq(StructField("num", IntegerType, nullable = false)))
>   override def userClass: Class[B] = classOf[B]
>   override def serialize(obj: Any): Any = obj match {
>     case B(text) =>
>       val row = new GenericMutableRow(1)
>       row.setInt(0, text)
>       row
>   }
>   override def deserialize(datum: Any): B = {
>     datum match {  case row: InternalRow => new B(row.getInt(0))  }
>   }
> }
> object BUDT extends BUDT
> object Test {
>   def main(args:Array[String]) = {
>     val col = Seq(new A(Seq(new B(1), new B(2))),
>       new A(Seq(new B(3), new B(4))))
>     val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("TestSpark"))
>     val sqlContext = new org.apache.spark.sql.SQLContext(sc)
>     import sqlContext.implicits._
>     val df = sc.parallelize(1 to 2 zip col).toDF("id","b")
>     df.select("b").show()
>     df.collect().foreach(println)
>   }
> }
> In the new version (1.6.0) I needed to include the following import:
> import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
> However, Spark crashes in runtime:
> 16/01/18 14:36:22 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
> java.lang.ClassCastException: scala.runtime.BoxedUnit cannot be cast to org.apache.spark.sql.catalyst.InternalRow
>       at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getStruct(rows.scala:51)
>       at org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getStruct(rows.scala:248)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
>       at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:51)
>       at org.apache.spark.sql.execution.Project$$anonfun$1$$anonfun$apply$1.apply(basicOperators.scala:49)
>       at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
>       at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
>       at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
>       at scala.collection.Iterator$$anon$10.next(Iterator.scala:312)
>       at scala.collection.Iterator$class.foreach(Iterator.scala:727)
>       at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
>       at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
>       at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
>       at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
>       at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
>       at scala.collection.AbstractIterator.to(Iterator.scala:1157)
>       at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
>       at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
>       at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
>       at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:212)
>       at org.apache.spark.sql.execution.SparkPlan$$anonfun$5.apply(SparkPlan.scala:212)
>       at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858)
>       at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858)
>       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
>       at org.apache.spark.scheduler.Task.run(Task.scala:89)
>       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
>       at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
>       at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
>       at java.lang.Thread.run(Thread.java:745)


