Maybe a UDF to flatten is an interesting option as well: http://stackoverflow.com/q/42888711/2587904 — would a UDAF be much more performant? shyla deshpande <deshpandesh...@gmail.com> schrieb am Fr. 24. März 2017 um 04:04:
> Thanks a million Yong. Great help!!! It solved my problem. > > On Thu, Mar 23, 2017 at 6:00 PM, Yong Zhang <java8...@hotmail.com> wrote: > > Change: > > val arrayinput = input.getAs[Array[String]](0) > > to: > > val arrayinput = input.getAs[*Seq*[String]](0) > > > Yong > > > ------------------------------ > *From:* shyla deshpande <deshpandesh...@gmail.com> > *Sent:* Thursday, March 23, 2017 8:18 PM > *To:* user > *Subject:* Spark dataframe, UserDefinedAggregateFunction(UDAF) help!! > > This is my input data. The UDAF needs to aggregate the goals for a team > and return a map that gives the count for every goal in the team. > I am getting the following error > > java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef > cannot be cast to [Ljava.lang.String; > at com.whil.common.GoalAggregator.update(GoalAggregator.scala:27) > > +------+--------------+ > |teamid|goals | > +------+--------------+ > |t1 |[Goal1, Goal2]| > |t1 |[Goal1, Goal3]| > |t2 |[Goal1, Goal2]| > |t3 |[Goal2, Goal3]| > +------+--------------+ > > root > |-- teamid: string (nullable = true) > |-- goals: array (nullable = true) > | |-- element: string (containsNull = true) > > /////////////////////////Calling the UDAF////////// > > object TestUDAF { > def main(args: Array[String]): Unit = { > > val spark = SparkSession > .builder > .getOrCreate() > > val sc: SparkContext = spark.sparkContext > val sqlContext = spark.sqlContext > > import sqlContext.implicits._ > > val data = Seq( > ("t1", Seq("Goal1", "Goal2")), > ("t1", Seq("Goal1", "Goal3")), > ("t2", Seq("Goal1", "Goal2")), > ("t3", Seq("Goal2", "Goal3"))).toDF("teamid","goals") > > data.show(truncate = false) > data.printSchema() > > import spark.implicits._ > > val sumgoals = new GoalAggregator > val result = data.groupBy("teamid").agg(sumgoals(col("goals"))) > > result.show(truncate = false) > > } > } > > ///////////////UDAF///////////////// > > import org.apache.spark.sql.expressions.MutableAggregationBuffer > import 
org.apache.spark.sql.expressions.UserDefinedAggregateFunction > import org.apache.spark.sql.Row > import org.apache.spark.sql.types._ > > class GoalAggregator extends UserDefinedAggregateFunction{ > > override def inputSchema: org.apache.spark.sql.types.StructType = > StructType(StructField("value", ArrayType(StringType)) :: Nil) > > override def bufferSchema: StructType = StructType( > StructField("combined", MapType(StringType,IntegerType)) :: Nil > ) > > override def dataType: DataType = MapType(StringType,IntegerType) > > override def deterministic: Boolean = true > > override def initialize(buffer: MutableAggregationBuffer): Unit = { > buffer.update(0, Map[String, Integer]()) > } > > override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { > val mapbuf = buffer.getAs[Map[String, Int]](0) > val arrayinput = input.getAs[Array[String]](0) > val result = mapbuf ++ arrayinput.map(goal => { > val cnt = mapbuf.get(goal).getOrElse(0) + 1 > goal -> cnt > }) > buffer.update(0, result) > } > > override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = > { > val map1 = buffer1.getAs[Map[String, Int]](0) > val map2 = buffer2.getAs[Map[String, Int]](0) > val result = map1 ++ map2.map { case (k,v) => > val cnt = map1.get(k).getOrElse(0) + 1 > k -> cnt > } > buffer1.update(0, result) > } > > override def evaluate(buffer: Row): Any = { > buffer.getAs[Map[String, Int]](0) > } > } > > > > >