Maybe a UDF to flatten is an interesting option as well:
http://stackoverflow.com/q/42888711/2587904. Would a UDAF be more
performant?
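
For comparison, a minimal sketch of that flattening approach (untested; it assumes `data` is the DataFrame built in the code further down the thread, and uses only built-in Spark SQL functions):

import org.apache.spark.sql.functions.{col, count, explode}

// Flatten to one row per (teamid, goal), then a plain groupBy/count.
// This yields a long-format result rather than the per-team map the
// UDAF below builds, but needs no custom aggregation code.
val flattened = data.select(col("teamid"), explode(col("goals")).as("goal"))
val counts = flattened.groupBy("teamid", "goal").agg(count("*").as("cnt"))
counts.show(truncate = false)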
shyla deshpande <deshpandesh...@gmail.com> wrote on Fri., Mar 24, 2017 at
04:04:

> Thanks a million Yong. Great help!!! It solved my problem.
>
> On Thu, Mar 23, 2017 at 6:00 PM, Yong Zhang <java8...@hotmail.com> wrote:
>
> Change:
>
> val arrayinput = input.getAs[Array[String]](0)
>
> to:
>
> val arrayinput = input.getAs[Seq[String]](0)
>
> (Spark hands an ArrayType column to the UDAF as a
> scala.collection.mutable.WrappedArray, which implements Seq[String] but is
> not a Java Array[String]; hence the ClassCastException.)
>
>
> Yong
>
>
> ------------------------------
> From: shyla deshpande <deshpandesh...@gmail.com>
> Sent: Thursday, March 23, 2017 8:18 PM
> To: user
> Subject: Spark dataframe, UserDefinedAggregateFunction(UDAF) help!!
>
> This is my input data. The UDAF needs to aggregate the goals for a team
> and return a map that gives the count for every goal in the team.
> I am getting the following error:
>
> java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef
> cannot be cast to [Ljava.lang.String;
> at com.whil.common.GoalAggregator.update(GoalAggregator.scala:27)
>
> +------+--------------+
> |teamid|goals         |
> +------+--------------+
> |t1    |[Goal1, Goal2]|
> |t1    |[Goal1, Goal3]|
> |t2    |[Goal1, Goal2]|
> |t3    |[Goal2, Goal3]|
> +------+--------------+
>
> root
>  |-- teamid: string (nullable = true)
>  |-- goals: array (nullable = true)
>  |    |-- element: string (containsNull = true)
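>
> For reference, the counts this should produce per team (worked out by hand
> from the rows above) are:
>
> t1 -> Map(Goal1 -> 2, Goal2 -> 1, Goal3 -> 1)
> t2 -> Map(Goal1 -> 1, Goal2 -> 1)
> t3 -> Map(Goal2 -> 1, Goal3 -> 1)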
>
> /////////////////////////Calling the UDAF//////////
>
> import org.apache.spark.sql.SparkSession
> import org.apache.spark.sql.functions.col
>
> object TestUDAF {
>   def main(args: Array[String]): Unit = {
>
>     val spark = SparkSession
>       .builder
>       .getOrCreate()
>
>     import spark.implicits._
>
>     val data = Seq(
>       ("t1", Seq("Goal1", "Goal2")),
>       ("t1", Seq("Goal1", "Goal3")),
>       ("t2", Seq("Goal1", "Goal2")),
>       ("t3", Seq("Goal2", "Goal3"))).toDF("teamid","goals")
>
>     data.show(truncate = false)
>     data.printSchema()
>
>     val sumgoals = new GoalAggregator
>     val result = data.groupBy("teamid").agg(sumgoals(col("goals")))
>
>     result.show(truncate = false)
>
>   }
> }
>
> ///////////////UDAF/////////////////
>
> import org.apache.spark.sql.expressions.MutableAggregationBuffer
> import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
> import org.apache.spark.sql.Row
> import org.apache.spark.sql.types._
>
> class GoalAggregator extends UserDefinedAggregateFunction{
>
>   override def inputSchema: org.apache.spark.sql.types.StructType =
>   StructType(StructField("value", ArrayType(StringType)) :: Nil)
>
>   override def bufferSchema: StructType = StructType(
>       StructField("combined", MapType(StringType,IntegerType)) :: Nil
>   )
>
>   override def dataType: DataType = MapType(StringType,IntegerType)
>
>   override def deterministic: Boolean = true
>
>   override def initialize(buffer: MutableAggregationBuffer): Unit = {
>     buffer.update(0, Map[String, Int]())
>   }
>
>   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
>     val mapbuf = buffer.getAs[Map[String, Int]](0)
>     // GoalAggregator.scala:27 (per the stack trace): this is the cast that
>     // fails; Spark supplies a WrappedArray (a Seq), not a Java array.
>     val arrayinput = input.getAs[Array[String]](0)
>     val result = mapbuf ++ arrayinput.map(goal => {
>       val cnt  = mapbuf.get(goal).getOrElse(0) + 1
>       goal -> cnt
>     })
>     buffer.update(0, result)
>   }
>
>   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
>     val map1 = buffer1.getAs[Map[String, Int]](0)
>     val map2 = buffer2.getAs[Map[String, Int]](0)
>     // Sum the counts from the two partial maps; note this adds map2's
>     // count v (the original added a constant 1, which undercounts).
>     val result = map1 ++ map2.map { case (k, v) =>
>       k -> (map1.getOrElse(k, 0) + v)
>     }
>     buffer1.update(0, result)
>   }
>
>   override def evaluate(buffer: Row): Any = {
>     buffer.getAs[Map[String, Int]](0)
>   }
> }
