A simple example: we have a Scala file:

    package com.myorg.example

    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.functions.{rand, sum}
    import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

    class PerformSumUDAF() extends UserDefinedAggregateFunction {

      def inputSchema: StructType = StructType(Array(StructField("item", DoubleType)))

      def bufferSchema: StructType = StructType(Array(StructField("sum", DoubleType)))

      def dataType: DataType = DoubleType

      def deterministic: Boolean = true

      def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0.toDouble
      }

      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      }

      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
      }

      def evaluate(buffer: Row): Any = {
        buffer.getDouble(0)
      }
    }

We place the file under myroot/src/main/scala/com/myorg/example/ExampleUDAF.scala

Under myroot we create a pom file (sorry for not cleaning it up, it includes some stuff you probably do not need, like guava and avro):

    <project>
      <groupId>edu.berkeley</groupId>
      <artifactId>simple-project</artifactId>
      <modelVersion>4.0.0</modelVersion>
      <name>example packages</name>
      <packaging>jar</packaging>
      <version>1.0</version>
      <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
      </properties>
      <dependencies>
        <dependency>
          <groupId>com.google.guava</groupId>
          <artifactId>guava</artifactId>
          <version>19.0</version>
        </dependency>
        <dependency> <!-- Spark dependency -->
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-core_2.11</artifactId>
          <version>2.0.0</version>
          <scope>provided</scope>
        </dependency>
        <dependency>
          <groupId>org.postgresql</groupId>
          <artifactId>postgresql</artifactId>
          <version>9.4.1208</version>
        </dependency>
        <dependency>
          <groupId>com.databricks</groupId>
          <artifactId>spark-avro_2.11</artifactId>
          <version>3.0.0-preview2</version>
        </dependency>
        <dependency>
          <groupId>org.apache.spark</groupId>
          <artifactId>spark-sql_2.11</artifactId>
          <version>2.0.0</version>
          <scope>provided</scope>
        </dependency>
        <dependency>
          <groupId>org.scala-lang</groupId>
          <artifactId>scala-library</artifactId>
          <version>2.11.8</version>
          <scope>provided</scope>
        </dependency>
      </dependencies>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-shade-plugin</artifactId>
            <version>2.4.3</version>
            <executions>
              <execution>
                <phase>package</phase>
                <goals>
                  <goal>shade</goal>
                </goals>
                <configuration>
                  <relocations>
                    <relocation>
                      <pattern>com.google.common</pattern>
                      <shadedPattern>com.myorg.shaded.com.google.common</shadedPattern>
                    </relocation>
                  </relocations>
                  <finalName>simple-project-1.0-jar-with-dependencies</finalName>
                </configuration>
              </execution>
            </executions>
          </plugin>
          <plugin>
            <groupId>org.scala-tools</groupId>
            <artifactId>maven-scala-plugin</artifactId>
            <version>2.15.2</version>
            <executions>
              <execution>
                <goals>
                  <goal>compile</goal>
                </goals>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </project>

Now you can compile the Scala code like so:

    mvn clean install

(I assume you have maven installed.)

Now we want to call this from Python (assuming spark is your Spark session):

    # get a reference dataframe to do the example on:
    df = spark.range(20)

    # get the jvm pointer
    jvm = spark.sparkContext._gateway.jvm

    # import the class
    from py4j.java_gateway import java_import
    java_import(jvm, "com.myorg.example.PerformSumUDAF")

    # create an object from the class:
    udafObj = jvm.com.myorg.example.PerformSumUDAF()

    # define a python function to do the aggregation.
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    def pythonudaf(c):
        # the _to_seq portion is because we need to convert this to a sequence of
        # input columns the way scala (java) expects them. The returned
        # value must then be converted to a pyspark Column
        return Column(udafObj.apply(_to_seq(spark.sparkContext, [c], _to_java_column)))

    # now let's use the function
    df.agg(pythonudaf(df.id)).show()

Lastly, when you run, make sure to pass the jar created from Scala to both --jars and --driver-class-path so that it is available on all nodes.

From: Tobi Bosede [mailto:ani.to...@gmail.com]
Sent: Monday, October 17, 2016 10:15 PM
To: Mendelson, Assaf
Cc: Holden Karau; user
Subject: Re: Aggregate UDF (UDAF) in Python

Thanks Assaf. Yes, please provide an example of how to wrap code for Python. I am leaning towards Scala.

On Mon, Oct 17, 2016 at 1:50 PM, Mendelson, Assaf <assaf.mendel...@rsa.com> wrote:

A possible (bad) workaround would be to use the collect_list function. This will give you all the values in an array (list), and you can then create a UDF to do the aggregation yourself. This would be very slow and cost a lot of memory, but it would work if your cluster can handle it.
This is the only workaround I can think of; otherwise you will need to write the UDAF in Java/Scala and wrap it for Python use. If you need an example of how to do so, I can provide one.
Assaf.

From: Tobi Bosede [mailto:ani.to...@gmail.com]
Sent: Sunday, October 16, 2016 7:49 PM
To: Holden Karau
Cc: user
Subject: Re: Aggregate UDF (UDAF) in Python

OK, I misread the year on the dev list. Can you comment on workarounds? (I.e. the question about whether Scala/Java are the only option.)

On Sun, Oct 16, 2016 at 12:09 PM, Holden Karau <hol...@pigscanfly.ca> wrote:

The comment on the developer list is from earlier this week.
I'm not sure why UDAF support hasn't made the hop to Python. While I work a fair amount on PySpark, it's mostly in core & ML and not a lot with SQL, so there could be good reasons I'm just not familiar with. We can try pinging Davies or Michael on the JIRA to see what their thoughts are.

On Sunday, October 16, 2016, Tobi Bosede <ani.to...@gmail.com> wrote:

Thanks for the info Holden.

So it seems both the JIRA and the comment on the developer list are over a year old. More surprising, the JIRA has no assignee. Any particular reason for the lack of activity in this area?

Is writing Scala/Java the only workaround for this? I hear a lot of people say Python is the gateway language to Scala. It is because of issues like this that people use Scala for Spark rather than Python, or eventually abandon Python for Scala. It just takes too long for features to get ported over from Scala/Java.

On Sun, Oct 16, 2016 at 8:42 AM, Holden Karau <hol...@pigscanfly.ca> wrote:

I don't believe UDAFs are available in PySpark, as this came up on the developer list while I was asking what features people were missing in PySpark - see http://apache-spark-developers-list.1001551.n3.nabble.com/Python-Spark-Improvements-forked-from-Spark-Improvement-Proposals-td19422.html . The JIRA for tracking this issue is at https://issues.apache.org/jira/browse/SPARK-10915

On Sat, Oct 15, 2016 at 7:20 PM, Tobi Bosede <ani.to...@gmail.com> wrote:

Hello,

I am trying to use a UDF that calculates the inter-quartile range (IQR) for pivot() and SQL in pyspark, and got the error that my function wasn't an aggregate function in both scenarios. Does anyone know if UDAF functionality is available in Python? If not, what can I do as a workaround?

Thanks,
Tobi

--
Cell : 425-233-8271
Twitter: https://twitter.com/holdenkarau