cloud-fan commented on a change in pull request #25024: [SPARK-27296][SQL] Allows Aggregator to be registered as a UDF

 File path: 
 @@ -0,0 +1,404 @@
+package org.apache.spark.sql.hive.execution
+import java.lang.{Double => jlDouble, Integer => jlInt, Long => jlLong}
+import scala.collection.JavaConverters._
+import scala.util.Random
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.expressions.{Aggregator}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+/**
+ * Test aggregator over nullable java.lang.Double inputs.
+ *
+ * The buffer is a (running sum, count of non-null inputs) pair, and null
+ * inputs are skipped. `finish` returns 100.0 plus the mean of the non-null
+ * inputs, or null when no non-null input was seen.
+ * NOTE(review): the +100.0 offset appears to exist so that results are
+ * distinguishable from the built-in avg in the tests -- confirm.
+ */
+class MyDoubleAvgAggBase extends Aggregator[jlDouble, (Double, Long), 
jlDouble] {
+  def zero: (Double, Long) = (0.0, 0L)
+  def reduce(b: (Double, Long), a: jlDouble): (Double, Long) = {
+    // Only non-null inputs contribute to the sum and the count.
+    if (a != null) (b._1 + a, b._2 + 1L) else b
+  }
+  def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  // A zero count means no non-null input was seen, so the result is SQL NULL.
+  def finish(r: (Double, Long)): jlDouble =
+    if (r._2 > 0L) 100.0 + (r._1 / r._2.toDouble) else null
+  def bufferEncoder: Encoder[(Double, Long)] =
+    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
+  // Boxed java.lang.Double output so that finish can return null.
+  def outputEncoder: Encoder[jlDouble] = Encoders.DOUBLE
+/** Offset-average aggregator, registered as the SQL function `mydoubleavg`. */
+object MyDoubleAvgAgg extends MyDoubleAvgAggBase
+/**
+ * Variant of MyDoubleAvgAggBase whose finish returns the plain sum of the
+ * non-null inputs: no +100.0 offset and no division by the count.
+ * Registered as the SQL function `mydoublesum`.
+ * NOTE(review): the else branch below is cut off here; the tests expect
+ * null when there are no non-null inputs (e.g. `mydoublesum(null)`) -- confirm.
+ */
+object MyDoubleSumAgg extends MyDoubleAvgAggBase {
+  override def finish(r: (Double, Long)): jlDouble = if (r._2 > 0L) r._1 else 
+/**
+ * Two-argument test aggregator that sums `a._1 * a._2` over its input pairs.
+ * Any pair in which either element is null is skipped. Registered as the SQL
+ * function `longProductSum`.
+ *
+ * The buffer starts at 0L and finish returns it unchanged, so a group with
+ * no contributing pairs yields 0 rather than null.
+ */
+object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
+  def zero: Long = 0L
+  def reduce(b: Long, a: (jlLong, jlLong)): Long = {
+    if ((a._1 != null) && (a._2 != null)) b + (a._1 * a._2) else b
+  }
+  def merge(b1: Long, b2: Long): Long = b1 + b2
+  def finish(r: Long): jlLong = r
+  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
+  def outputEncoder: Encoder[jlLong] = Encoders.LONG
+/**
+ * Buffer and output type for CountSerDeAgg.
+ *
+ * `nSer` and `nDeSer` are incremented by CountSerDeUDT on every serialize and
+ * deserialize call, respectively. `sum` carries the aggregated data.
+ */
+@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
+case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
+/**
+ * UDT for CountSerDeSQL that counts its own round trips.
+ *
+ * `serialize` increments `nSer` and `deserialize` increments `nDeSer`, so a
+ * test can observe how many times a value crossed the UDT boundary.
+ */
+class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
+  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]
+  override def typeName: String = "count-ser-de"
+  private[spark] override def asNullable: CountSerDeUDT = this
+  // Stored as a struct of three non-nullable ints, mirroring the case class fields.
+  def sqlType: DataType = StructType(
+    StructField("nSer", IntegerType, false) ::
+    StructField("nDeSer", IntegerType, false) ::
+    StructField("sum", IntegerType, false) ::
+    Nil)
+  def serialize(sql: CountSerDeSQL): Any = {
+    val row = new GenericInternalRow(3)
+    // Count this serialization in the stored nSer field.
+    row.setInt(0, 1 + sql.nSer)
+    row.setInt(1, sql.nDeSer)
+    row.setInt(2, sql.sum)
+    row
+  }
+  def deserialize(any: Any): CountSerDeSQL = any match {
+    // Count this deserialization in the returned nDeSer field.
+    case row: InternalRow if (row.numFields == 3) =>
+      CountSerDeSQL(row.getInt(0), 1 + row.getInt(1), row.getInt(2))
+    case u => throw new Exception(s"failed to deserialize: $u")
+  }
+  // All instances are interchangeable: equality is by class, and the hash is
+  // derived from the class name to stay consistent with equals.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case _: CountSerDeUDT => true
+      case _ => false
+    }
+  }
+  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
+/** Companion object that is itself a CountSerDeUDT instance. */
+case object CountSerDeUDT extends CountSerDeUDT
+/**
+ * Sums Int inputs into a CountSerDeSQL buffer.
+ *
+ * `reduce` only updates `sum`. `merge` also adds the `nSer`/`nDeSer`
+ * counters from both sides, so the final result reflects the ser/de calls
+ * recorded by CountSerDeUDT across partitions.
+ * NOTE(review): the encoder definitions below are cut off here; they
+ * presumably use ExpressionEncoder (imported at the top) -- confirm.
+ */
+object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
+  def zero: CountSerDeSQL = CountSerDeSQL(0, 0, 0)
+  def reduce(b: CountSerDeSQL, a: Int): CountSerDeSQL = b.copy(sum = b.sum + a)
+  def merge(b1: CountSerDeSQL, b2: CountSerDeSQL): CountSerDeSQL =
+    CountSerDeSQL(b1.nSer + b2.nSer, b1.nDeSer + b2.nDeSer, b1.sum + b2.sum)
+  def finish(r: CountSerDeSQL): CountSerDeSQL = r
+  def bufferEncoder: Encoder[CountSerDeSQL] = 
+  def outputEncoder: Encoder[CountSerDeSQL] = 
+abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with 
TestHiveSingleton {
+  import testImplicits._
+  /**
+   * Creates the fixtures and registers the test Aggregators as SQL UDAFs.
+   *
+   * Fixtures: tables agg1 (key, value), agg2 (key, value1, value2) and
+   * agg3 (array key, value1, value2), plus the empty temp view emptyTable.
+   * UDAFs: mydoublesum, mydoubleavg and longProductSum, registered via `udaf`.
+   */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // agg1: nullable (key, value) pairs; key 3 has only null values.
+    val data1 = Seq[(Integer, Integer)](
+      (1, 10),
+      (null, -60),
+      (1, 20),
+      (1, 30),
+      (2, 0),
+      (null, -10),
+      (2, -1),
+      (2, null),
+      (2, null),
+      (null, 100),
+      (3, null),
+      (null, null),
+      (3, null)).toDF("key", "value")
+    data1.write.saveAsTable("agg1")
+    // agg2: nullable (key, value1, value2) rows used by the distinct tests.
+    val data2 = Seq[(Integer, Integer, Integer)](
+      (1, 10, -10),
+      (null, -60, 60),
+      (1, 30, -30),
+      (1, 30, 30),
+      (2, 1, 1),
+      (null, -10, 10),
+      (2, -1, null),
+      (2, 1, 1),
+      (2, null, 1),
+      (null, 100, -10),
+      (3, null, 3),
+      (null, null, null),
+      (3, null, null)).toDF("key", "value1", "value2")
+    data2.write.saveAsTable("agg2")
+    // agg3: same value columns as agg2, but keyed by a nullable integer array.
+    val data3 = Seq[(Seq[Integer], Integer, Integer)](
+      (Seq[Integer](1, 1), 10, -10),
+      (Seq[Integer](null), -60, 60),
+      (Seq[Integer](1, 1), 30, -30),
+      (Seq[Integer](1), 30, 30),
+      (Seq[Integer](2), 1, 1),
+      (null, -10, 10),
+      (Seq[Integer](2, 3), -1, null),
+      (Seq[Integer](2, 3), 1, 1),
+      (Seq[Integer](2, 3, 4), null, 1),
+      (Seq[Integer](null), 100, -10),
+      (Seq[Integer](3), null, 3),
+      (null, null, null),
+      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    data3.write.saveAsTable("agg3")
+    // Empty (key STRING, value INT) relation exposed as a temp view.
+    val emptyDF = spark.createDataFrame(
+      sparkContext.emptyRDD[Row],
+      StructType(StructField("key", StringType) :: StructField("value", 
IntegerType) :: Nil))
+    emptyDF.createOrReplaceTempView("emptyTable")
+    // Register the Aggregators as SQL UDAFs under the names the tests query.
+    spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
+    spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
+    spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+  }
+  /** Drops the fixture tables and temp view; super.afterAll runs even if a drop fails. */
+  override def afterAll(): Unit = {
+    try {
+      spark.sql("DROP TABLE IF EXISTS agg1")
+      spark.sql("DROP TABLE IF EXISTS agg2")
+      spark.sql("DROP TABLE IF EXISTS agg3")
+      spark.catalog.dropTempView("emptyTable")
+    } finally {
+      super.afterAll()
+    }
+  }
+  // Aggregator UDAFs and built-in avg in one GROUP BY query. mydoubleavg adds
+  // 100.0 to the mean, e.g. key 1 gives 120.0 where avg(value) gives 20.0.
+  test("aggregators") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  mydoubleavg(value),
+          |  avg(value - key),
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+        Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+        Row(3, null, null, null, null, null) ::
+        Row(null, null, 110.0, null, null, 10.0) :: Nil)
+  }
+  // A non-deterministic argument (rand()) to an Aggregator UDAF is expected to
+  // be rejected during analysis with an AnalysisException.
+  test("non-deterministic children expressions of aggregator") {
+    val e = intercept[AnalysisException] {
+      spark.sql(
+        """
+          |SELECT mydoublesum(value + 1.5 * key + rand())
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin)
+    }.getMessage
+    assert(Seq("nondeterministic expression",
+      "should not appear in the arguments of an aggregate 
+  }
+  // mydoublesum grouped, ungrouped, and over a literal null. A null-only input
+  // has no non-null values, so the expected result is null.
+  test("interpreted aggregate function") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value) FROM agg1
+        """.stripMargin),
+      Row(89.0) :: Nil)
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+  // Mixes the Aggregator-based UDAF with the built-in avg in the same query,
+  // with the grouping key in different positions of the select list.
+  test("interpreted and expression-based aggregation functions") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT mydoublesum(value), key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1, 20.0) ::
+        Row(-1.0, 2, -0.5) ::
+        Row(null, 3, null) ::
+        Row(30.0, null, 10.0) :: Nil)
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoublesum(value + 1.5 * key),
+          |  avg(value - key),
+          |  key,
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(64.5, 19.0, 1, 55.5, 20.0) ::
+        Row(5.0, -2.5, 2, -7.0, -0.5) ::
+        Row(null, null, 3, null, null) ::
+        Row(null, null, null, null, 10.0) :: Nil)
+  }
+  // DISTINCT and non-DISTINCT calls of the Aggregator UDAFs on agg2, where all
+  // DISTINCT aggregates share a single distinct column (value1).
+  test("single distinct column set") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  mydoubleavg(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  key,
+          |  mydoubleavg(value1 - 1),
+          |  mydoubleavg(distinct value1) * 0.1,
+          |  avg(value1 + value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+        Row(null, null, 3.0, 3, null, null, null) ::
+        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoubleavg(distinct value1),
+          |  mydoublesum(value2),
+          |  mydoublesum(distinct value1),
+          |  mydoubleavg(distinct value1),
+          |  mydoubleavg(value1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+        Row(3, null, 3.0, null, null, null) ::
+        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+  }
+  // Several distinct column sets in one query, including the two-argument
+  // longProductSum with and without DISTINCT. An all-null group yields 0 from
+  // longProductSum, because its buffer starts at 0.
+  test("multiple distinct multiple columns sets") {
+    checkAnswer(
+      spark.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1),
+          |  sum(distinct value1),
+          |  count(distinct value2),
+          |  sum(distinct value2),
+          |  count(distinct value1, value2),
+          |  longProductSum(distinct value1, value2),
+          |  count(value1),
+          |  sum(value1),
+          |  count(value2),
+          |  sum(value2),
+          |  longProductSum(value1, value2),
+          |  count(*),
+          |  count(1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
+  }
+  // Aggregates 1..100 over 3 partitions with CountSerDeAgg; sum is 5050.
+  // NOTE(review): the expected nSer/nDeSer of 4/4 encode how many times the
+  // buffer passed through CountSerDeUDT for this plan. Presumably this is tied
+  // to the 3 partitions plus the final merge -- confirm against the physical plan.
+  test("verify aggregator ser/de behavior") {
+    val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
+    val agg = udaf(CountSerDeAgg)
+    checkAnswer(
+      data.agg(agg($"value1")),
+      Row(CountSerDeSQL(4, 4, 5050)) :: Nil)
+  }
 Review comment:
   Can we add a negative test that fails the type check? For example, the input
   column is boolean but the aggregator requires long.

Reply via email to