Repository: spark
Updated Branches:
  refs/heads/branch-1.5 252eb6193 -> 29ace3bbf


[SPARK-9664] [SQL] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction.

https://issues.apache.org/jira/browse/SPARK-9664

Author: Yin Huai <[email protected]>

Closes #7982 from yhuai/udafRegister and squashes the following commits:

0cc2287 [Yin Huai] Remove UDAFRegistration and add apply to UserDefinedAggregateFunction.

(cherry picked from commit d5a9af3230925c347d0904fe7f2402e468e80bc8)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29ace3bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29ace3bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29ace3bb

Branch: refs/heads/branch-1.5
Commit: 29ace3bbf06bb786904f2243a5a4d1204f53c3a0
Parents: 252eb61
Author: Yin Huai <[email protected]>
Authored: Wed Aug 5 21:50:35 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Wed Aug 5 21:50:41 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLContext.scala |  3 --
 .../org/apache/spark/sql/UDAFRegistration.scala | 36 --------------------
 .../org/apache/spark/sql/UDFRegistration.scala  | 16 +++++++++
 .../spark/sql/execution/aggregate/udaf.scala    |  8 ++---
 .../org/apache/spark/sql/expressions/udaf.scala | 32 ++++++++++++++++-
 .../scala/org/apache/spark/sql/functions.scala  |  1 +
 .../spark/sql/hive/JavaDataFrameSuite.java      | 26 ++++++++++++++
 .../hive/execution/AggregationQuerySuite.scala  |  4 +--
 8 files changed, 80 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ffc2baf..6f8ffb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -291,9 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   val udf: UDFRegistration = new UDFRegistration(this)
 
-  @transient
-  val udaf: UDAFRegistration = new UDAFRegistration(this)
-
   /**
    * Returns true if the table is currently cached in-memory.
    * @group cachemgmt

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
deleted file mode 100644
index 0d4e30f..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.execution.aggregate.ScalaUDAF
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
-
-class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
-
-  private val functionRegistry = sqlContext.functionRegistry
-
-  def register(
-      name: String,
-      func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
-    def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
-    functionRegistry.registerFunction(name, builder)
-    func
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 7cd7421..1f27056 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -26,6 +26,8 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -52,6 +54,20 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
     functionRegistry.registerFunction(name, udf.builder)
   }
 
+  /**
+   * Registers a user-defined aggregate function (UDAF) in the function registry under `name`.
+   * @param name the name the UDAF can be looked up by (e.g. through `functions.callUDF`).
+   * @param udaf the UDAF to register; every call site is wrapped in a `ScalaUDAF` expression.
+   * @return the same `udaf` instance that was passed in, so callers can keep a handle to it.
+   */
+  def register(
+      name: String,
+      udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+    val udafBuilder = (inputs: Seq[Expression]) => ScalaUDAF(inputs, udaf)
+    functionRegistry.registerFunction(name, udafBuilder)
+    udaf
+  }
+
   // scalastyle:off
 
   /* register 0-22 were generated by this script

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 5fafc91..7619f3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -316,7 +316,7 @@ private[sql] case class ScalaUDAF(
 
   override lazy val cloneBufferAttributes = 
bufferAttributes.map(_.newInstance())
 
-  private[this] val childrenSchema: StructType = {
+  private[this] lazy val childrenSchema: StructType = {
     val inputFields = children.zipWithIndex.map {
       case (child, index) =>
         StructField(s"input$index", child.dataType, child.nullable, 
Metadata.empty)
@@ -337,16 +337,16 @@ private[sql] case class ScalaUDAF(
     }
   }
 
-  private[this] val inputToScalaConverters: Any => Any =
+  private[this] lazy val inputToScalaConverters: Any => Any =
     CatalystTypeConverters.createToScalaConverter(childrenSchema)
 
-  private[this] val bufferValuesToCatalystConverters: Array[Any => Any] = {
+  private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = 
{
     bufferSchema.fields.map { field =>
       CatalystTypeConverters.createToCatalystConverter(field.dataType)
     }
   }
 
-  private[this] val bufferValuesToScalaConverters: Array[Any => Any] = {
+  private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = {
     bufferSchema.fields.map { field =>
       CatalystTypeConverters.createToScalaConverter(field.dataType)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 278dd43..5180871 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, 
AggregateExpression2}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.{Column, Row}
 import org.apache.spark.sql.types._
 import org.apache.spark.annotation.Experimental
 
@@ -87,6 +90,33 @@ abstract class UserDefinedAggregateFunction extends Serializable {
    * aggregation buffer.
    */
   def evaluate(buffer: Row): Any
+
+  /**
+   * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments.
+   * This is the non-distinct form; it delegates to `apply(isDistinct, exprs)` with
+   * `isDistinct = false`, so both overloads build the aggregate expression the same way.
+   * @param exprs the columns whose values are passed to this UDAF as input.
+   */
+  @scala.annotation.varargs
+  def apply(exprs: Column*): Column = apply(false, exprs: _*)
+
+  /**
+   * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments.
+   * If `isDistinct` is true, this UDAF is working on distinct input values.
+   *
+   * @param isDistinct whether the UDAF should only see distinct input values.
+   * @param exprs the columns whose values are passed to this UDAF as input.
+   * @return a [[Column]] wrapping this UDAF in an aggregate expression (`Complete` mode).
+   */
+  @scala.annotation.varargs
+  def apply(isDistinct: Boolean, exprs: Column*): Column = {
+    val aggregateExpression =
+      AggregateExpression2(
+        ScalaUDAF(exprs.map(_.expr), this),
+        Complete,
+        isDistinct = isDistinct)
+    Column(aggregateExpression)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5a10c38..39aa905 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2500,6 +2500,7 @@ object functions {
    * @group udf_funcs
    * @since 1.5.0
    */
+  @scala.annotation.varargs
   def callUDF(udfName: String, cols: Column*): Column = {
     UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index 613b2bc..21b053f 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -29,8 +29,12 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
 import org.apache.spark.sql.expressions.Window;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import static org.apache.spark.sql.functions.*;
 import org.apache.spark.sql.hive.HiveContext;
 import org.apache.spark.sql.hive.test.TestHive$;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum;
 
 public class JavaDataFrameSuite {
   private transient JavaSparkContext sc;
@@ -77,4 +81,26 @@ public class JavaDataFrameSuite {
         "      ROWS BETWEEN 1 preceding and 1 following) " +
         "FROM window_table").collectAsList());
   }
+
+  @Test
+  public void testUDAF() {
+    DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+    UserDefinedAggregateFunction udaf = new MyDoubleSum();
+    UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
+    // Create Columns for the UDAF. For now, callUDF does not take an argument to specify
+    // whether we want to use distinct aggregation.
+    DataFrame aggregatedDF =
+      df.groupBy()
+        .agg(
+          udaf.apply(true, col("value")),
+          udaf.apply(col("value")),
+          registeredUDAF.apply(col("value")),
+          callUDF("mydoublesum", col("value")));
+    // Each id in [0, 100) appears twice (unionAll): distinct expects 4950.0, the rest 9900.0.
+    List<Row> expectedResult = new ArrayList<Row>();
+    expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));
+    checkAnswer(
+      aggregatedDF,
+      expectedResult);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29ace3bb/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 6f0db27..4b35c8f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -73,8 +73,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     emptyDF.registerTempTable("emptyTable")
 
     // Register UDAFs
-    sqlContext.udaf.register("mydoublesum", new MyDoubleSum)
-    sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)
+    sqlContext.udf.register("mydoublesum", new MyDoubleSum)
+    sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
   }
 
   override def afterAll(): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to