A simple example:

We have a Scala file:

package com.myorg.example

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.{rand, sum}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

class PerformSumUDAF() extends UserDefinedAggregateFunction {

  // the input: a single double column to sum over
  def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

  // the aggregation buffer: the running sum
  def bufferSchema: StructType = StructType(Array(StructField("sum", DoubleType)))

  // the type of the final result
  def dataType: DataType = DoubleType

  def deterministic: Boolean = true

  // start every buffer at zero
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.toDouble
  }

  // add one input row to the buffer
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
  }

  // combine two partial buffers
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // return the final sum
  def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }
}

We place the file under the project's source tree (with the standard Maven layout that would be myroot/src/main/scala/com/myorg/example/).
Under myroot we create a pom file (sorry for not cleaning it up; it includes 
some things you probably don't need, like guava and avro):

  <name>example packages</name>
    <dependency> <!-- Spark dependency -->
      ...
    </dependency>

Now you can compile the Scala code like so: mvn clean install (I assume you have 
Maven installed).

Now we want to call this from python (assuming spark is your spark session):
# get a reference dataframe to do the example on:
df = spark.range(20)

# get the jvm pointer
jvm = spark.sparkContext._gateway.jvm
# import the class
from py4j.java_gateway import java_import
java_import(jvm, "com.myorg.example.PerformSumUDAF")

# create an object from the class:
udafObj = jvm.com.myorg.example.PerformSumUDAF()
# define a python function to do the aggregation.
from pyspark.sql.column import Column, _to_java_column, _to_seq
def pythonudaf(c):
    # the _to_seq portion is because we need to convert this to a sequence of
    # input columns the way scala (java) expects them. The returned
    # value must then be converted to a pyspark Column
    return Column(udafObj.apply(_to_seq(spark.sparkContext, [c], _to_java_column)))

# now let's use the function
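# illustrative usage (not in the original message): aggregate over the whole
# dataframe; "id" is the column produced by spark.range above
df.agg(pythonudaf(df.id)).show()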

Lastly, when you run, make sure to use both --jars and --driver-class-path with 
the jar created from the Scala build, to make sure it is available on all nodes.
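For example, a launch could look roughly like this (the jar path below is only a 
placeholder for whatever the Maven build actually produced under myroot/target):

spark-submit --jars /path/to/example.jar --driver-class-path /path/to/example.jar your_script.py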

From: Tobi Bosede [mailto:ani.to...@gmail.com]
Sent: Monday, October 17, 2016 10:15 PM
To: Mendelson, Assaf
Cc: Holden Karau; user
Subject: Re: Aggregate UDF (UDAF) in Python

Thanks Assaf. Yes please provide an example of how to wrap code for python. I 
am leaning towards scala.

On Mon, Oct 17, 2016 at 1:50 PM, Mendelson, Assaf 
<assaf.mendel...@rsa.com> wrote:
A possible (bad) workaround would be to use the collect_list function. This 
will give you all the values in an array (list) and you can then create a UDF 
to do the aggregation yourself. This would be very slow and cost a lot of 
memory, but it would work if your cluster can handle it.
This is the only workaround I can think of; otherwise you will need to write 
the UDAF in java/scala and wrap it for python use. If you need an example on 
how to do so I can provide one.
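For illustration, a rough sketch of the collect_list workaround described above, 
using an inter-quartile-range aggregation as the example. The dataframe and 
column names (df, "key", "value") are assumptions for the sketch, not from the 
original post:

from pyspark.sql.functions import collect_list, udf
from pyspark.sql.types import DoubleType
import numpy as np

def iqr(values):
    # inter-quartile range over the values collected for one group
    return float(np.percentile(values, 75) - np.percentile(values, 25))

iqr_udf = udf(iqr, DoubleType())

# collect all values per group into a list, then aggregate the list in Python
result = (df.groupBy("key")
            .agg(collect_list("value").alias("values"))
            .withColumn("iqr", iqr_udf("values")))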

From: Tobi Bosede [mailto:ani.to...@gmail.com]
Sent: Sunday, October 16, 2016 7:49 PM
To: Holden Karau
Cc: user
Subject: Re: Aggregate UDF (UDAF) in Python

OK, I misread the year on the dev list. Can you comment on workarounds? (I.e., 
the question of whether scala/java are the only option.)

On Sun, Oct 16, 2016 at 12:09 PM, Holden Karau 
<hol...@pigscanfly.ca> wrote:
The comment on the developer list is from earlier this week. I'm not sure why 
UDAF support hasn't made the hop to Python - while I work a fair amount on 
PySpark it's mostly in core & ML and not a lot with SQL so there could be good 
reasons I'm just not familiar with. We can try pinging Davies or Michael on the 
JIRA to see what their thoughts are.

On Sunday, October 16, 2016, Tobi Bosede 
<ani.to...@gmail.com> wrote:
Thanks for the info Holden.

So it seems both the jira and the comment on the developer list are over a year 
old. More surprising, the jira has no assignee. Any particular reason for the 
lack of activity in this area?

Is writing scala/java the only work around for this? I hear a lot of people say 
python is the gateway language to scala. It is because of issues like this that 
people use scala for Spark rather than python or eventually abandon python for 
scala. It just takes too long for features to get ported over from scala/java.

On Sun, Oct 16, 2016 at 8:42 AM, Holden Karau 
<hol...@pigscanfly.ca> wrote:
I don't believe UDAFs are available in PySpark, as this came up on the developer 
list while I was asking what features people were missing in PySpark - see 
 . The JIRA for tracking this issue is at 

On Sat, Oct 15, 2016 at 7:20 PM, Tobi Bosede 
<ani.to...@gmail.com> wrote:

I am trying to use a UDF that calculates the inter-quartile range (IQR) for pivot() 
and SQL in pyspark, and in both scenarios I got an error that my function wasn't an 
aggregate function. Does anyone know if UDAF functionality is available 
in python? If not, what can I do as a workaround?


Cell : 425-233-8271
Twitter: https://twitter.com/holdenkarau
