Repository: spark
Updated Branches:
  refs/heads/master 5f9633dc9 -> 327bb3007


[SPARK-23911][SQL] Add aggregate function.

## What changes were proposed in this pull request?

This PR adds an `aggregate` function, which applies a binary operator to an initial 
state and all elements in the array, reducing them to a single state. The 
final state is converted into the final result by applying a finish function.

```sql
> SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x);
 6
> SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
 60
```
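
Conceptually, `aggregate(expr, start, merge, finish)` is a left fold over the array followed by a finish step. A minimal Scala sketch of that contract (an illustration only, not the Catalyst implementation):

```scala
// Sketch of the aggregate contract: fold the array from the start value,
// then apply the finish function to the final accumulator.
def aggregateSketch[A, B, C](arr: Seq[A], start: B)(merge: (B, A) => B)(finish: B => C): C =
  finish(arr.foldLeft(start)(merge))

aggregateSketch(Seq(1, 2, 3), 0)(_ + _)(identity) // 6
aggregateSketch(Seq(1, 2, 3), 0)(_ + _)(_ * 10)   // 60
```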

## How was this patch tested?

Added tests.

Author: Takuya UESHIN <ues...@databricks.com>

Closes #21982 from ueshin/issues/SPARK-23911/aggregate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/327bb300
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/327bb300
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/327bb300

Branch: refs/heads/master
Commit: 327bb30075834c873cdb78061c9b647e5e13b8a6
Parents: 5f9633d
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Sun Aug 5 08:58:35 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Sun Aug 5 08:58:35 2018 +0900

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/higherOrderFunctions.scala      |  95 +++++++++++++++
 .../expressions/HigherOrderFunctionsSuite.scala |  50 ++++++++
 .../sql-tests/inputs/higher-order-functions.sql |  12 ++
 .../results/higher-order-functions.sql.out      |  40 +++++-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 121 +++++++++++++++++++
 6 files changed, 318 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d0efe97..35f8de1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -442,6 +442,7 @@ object FunctionRegistry {
     expression[ArrayDistinct]("array_distinct"),
     expression[ArrayTransform]("transform"),
     expression[ArrayFilter]("filter"),
+    expression[ArrayAggregate]("aggregate"),
     CreateStruct.registryEntry,
 
     // misc functions
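
Registering `ArrayAggregate` under the name `aggregate` makes it resolvable like any other built-in. A hedged sketch of invoking it through SQL text (assumes a local `SparkSession` on Spark 2.4+, where this commit lands; illustration only):

```scala
import org.apache.spark.sql.SparkSession

// With the expression registered, "aggregate" resolves in SQL text.
val spark = SparkSession.builder().master("local[*]").appName("aggregate-demo").getOrCreate()
spark.sql("SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x) AS v").show()
// +---+
// |  v|
// +---+
// |  6|
// +---+
```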

http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index e15225f..20c7f7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
@@ -76,6 +77,13 @@ case class LambdaFunction(
   override def eval(input: InternalRow): Any = function.eval(input)
 }
 
+object LambdaFunction {
+  val identity: LambdaFunction = {
+    val id = UnresolvedAttribute.quoted("id")
+    LambdaFunction(id, Seq(id))
+  }
+}
+
 /**
  * A higher order function takes one or more (lambda) functions and applies these to some objects.
  * The function produces a number of variables which can be consumed by some lambda function.
@@ -270,3 +278,90 @@ case class ArrayFilter(
 
   override def prettyName: String = "filter"
 }
+
+/**
+ * Applies a binary operator to a start value and all elements in the array.
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all
+      elements in the array, and reduces this to a single state. The final state is converted
+      into the final result by applying a finish function.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x);
+       6
+      > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10);
+       60
+  """,
+  since = "2.4.0")
+case class ArrayAggregate(
+    input: Expression,
+    zero: Expression,
+    merge: Expression,
+    finish: Expression)
+  extends HigherOrderFunction with CodegenFallback {
+
+  def this(input: Expression, zero: Expression, merge: Expression) = {
+    this(input, zero, merge, LambdaFunction.identity)
+  }
+
+  override def inputs: Seq[Expression] = input :: zero :: Nil
+
+  override def functions: Seq[Expression] = merge :: finish :: Nil
+
+  override def nullable: Boolean = input.nullable || finish.nullable
+
+  override def dataType: DataType = finish.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!ArrayType.acceptsType(input.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"argument 1 requires ${ArrayType.simpleString} type, " +
+          s"however, '${input.sql}' is of ${input.dataType.catalogString} type.")
+    } else if (!DataType.equalsStructurally(
+        zero.dataType, merge.dataType, ignoreNullability = true)) {
+      TypeCheckResult.TypeCheckFailure(
+        s"argument 3 requires ${zero.dataType.simpleString} type, " +
+          s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
+    // Be very conservative with nullable. We cannot be sure that the accumulator does not
+    // evaluate to null. So we always set nullable to true here.
+    val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
+    val acc = zero.dataType -> true
+    val newMerge = f(merge, acc :: elem :: Nil)
+    val newFinish = f(finish, acc :: Nil)
+    copy(merge = newMerge, finish = newFinish)
+  }
+
+  @transient lazy val LambdaFunction(_,
+    Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge
+  @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish
+
+  override def eval(input: InternalRow): Any = {
+    val arr = this.input.eval(input).asInstanceOf[ArrayData]
+    if (arr == null) {
+      null
+    } else {
+      val Seq(mergeForEval, finishForEval) = functionsForEval
+      accForMergeVar.value.set(zero.eval(input))
+      var i = 0
+      while (i < arr.numElements()) {
+        elementVar.value.set(arr.get(i, elementVar.dataType))
+        accForMergeVar.value.set(mergeForEval.eval(input))
+        i += 1
+      }
+      accForFinishVar.value.set(accForMergeVar.value.get)
+      finishForEval.eval(input)
+    }
+  }
+
+  override def prettyName: String = "aggregate"
+}
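
The `eval` loop above threads the accumulator through mutable `NamedLambdaVariable` slots rather than building closures. A hedged mini-interpreter mirroring that shape (the `Cell` type is hypothetical, standing in for the lambda variable slots):

```scala
// Mutable cell standing in for a NamedLambdaVariable's value slot.
final class Cell[T](var value: T)

// Mirror of ArrayAggregate.eval's shape: seed the accumulator cell, loop over
// elements re-evaluating merge, then hand the final state to finish.
def evalAggregate[A, B, C](arr: Seq[A], zero: B, merge: (Cell[B], A) => B, finish: Cell[B] => C): C = {
  val acc = new Cell(zero)
  var i = 0
  while (i < arr.length) {
    acc.value = merge(acc, arr(i)) // merge reads the current accumulator from the cell
    i += 1
  }
  finish(acc)
}

evalAggregate[Int, Int, Int](Seq(1, 2, 3), 0, (acc, x) => acc.value + x, _.value * 10) // 60
```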

http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index d1330c7..40cfc0c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -59,6 +59,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
   }
 
+  def aggregate(
+      expr: Expression,
+      zero: Expression,
+      merge: (Expression, Expression) => Expression,
+      finish: Expression => Expression): Expression = {
+    val at = expr.dataType.asInstanceOf[ArrayType]
+    val zeroType = zero.dataType
+    ArrayAggregate(
+      expr,
+      zero,
+      createLambda(zeroType, true, at.elementType, at.containsNull, merge),
+      createLambda(zeroType, true, finish))
+  }
+
+  def aggregate(
+      expr: Expression,
+      zero: Expression,
+      merge: (Expression, Expression) => Expression): Expression = {
+    aggregate(expr, zero, merge, identity)
+  }
+
   test("ArrayTransform") {
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
     val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -131,4 +152,33 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
       Seq(Seq(1, 3), null, Seq(5)))
   }
+
+  test("ArrayAggregate") {
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
+    val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
+    val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+    val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
+
+    checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60)
+    checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40)
+    checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0)
+    checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null)
+
+    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
+    val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
+    val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false))
+    val asn = Literal.create(null, ArrayType(StringType, containsNull = false))
+
+    checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc")
+    checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc")
+    checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "")
+    checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null)
+
+    val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)),
+      ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+    checkEvaluation(
+      aggregate(aai, 0,
+        (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
+      15)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index f833aa5..136396d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -33,3 +33,15 @@ select filter(cast(null as array<int>), y -> true) as v;
 
 -- Filter nested arrays
 select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
+
+-- Aggregate.
+select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested;
+
+-- Aggregate average.
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested;
+
+-- Aggregate nested arrays
+select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested;
+
+-- Aggregate a null array
+select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;
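
The running-average query above relies on a struct accumulator: carry `(sum, n)` through the merge and divide only in the finish step. The same idiom in plain Scala (illustration only, not Spark code):

```scala
// (sum, n) accumulator idiom: the fold carries both pieces of state,
// and the finish step combines them into the final average.
def average(xs: Seq[Int]): Double = {
  val (sum, n) = xs.foldLeft((0, 0)) { case ((s, c), x) => (s + x, c + 1) }
  sum.toDouble / n
}

average(Seq(1, 2, 3)) // 2.0
```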

http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index 4c5d972..e6f62f2 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 11
+-- Number of queries: 15
 
 
 -- !query 0
@@ -107,3 +107,41 @@ struct<v:array<array<int>>>
 [[96,65],[]]
 [[99],[123],[]]
 [[]]
+
+
+-- !query 11
+select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested
+-- !query 11 schema
+struct<v:int>
+-- !query 11 output
+131
+15
+5
+
+
+-- !query 12
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested
+-- !query 12 schema
+struct<v:double>
+-- !query 12 output
+0.5
+12.0
+64.5
+
+
+-- !query 13
+select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested
+-- !query 13 schema
+struct<v:array<int>>
+-- !query 13 output
+[1010880,8]
+[17]
+[4752,20664,1]
+
+
+-- !query 14
+select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v
+-- !query 14 schema
+struct<v:int>
+-- !query 14 output
+NULL

http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1d5707a..af3301b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1896,6 +1896,127 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
   }
 
+  test("aggregate function - array for primitive type not containing null") {
+    val df = Seq(
+      Seq(1, 9, 8, 7),
+      Seq(5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"),
+        Seq(
+          Row(25),
+          Row(31),
+          Row(0),
+          Row(null)))
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"),
+        Seq(
+          Row(250),
+          Row(310),
+          Row(0),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeNotContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeNotContainsNull()
+  }
+
+  test("aggregate function - array for primitive type containing null") {
+    val df = Seq[Seq[Integer]](
+      Seq(1, 9, 8, 7),
+      Seq(5, null, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"),
+        Seq(
+          Row(25),
+          Row(null),
+          Row(0),
+          Row(null)))
+      checkAnswer(
+        df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"),
+        Seq(
+          Row(250),
+          Row(0),
+          Row(0),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArrayOfPrimitiveTypeContainsNull()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testArrayOfPrimitiveTypeContainsNull()
+  }
+
+  test("aggregate function - array for non-primitive type") {
+    val df = Seq(
+      (Seq("c", "a", "b"), "a"),
+      (Seq("b", null, "c", null), "b"),
+      (Seq.empty, "c"),
+      (null, "d")
+    ).toDF("ss", "s")
+
+    def testNonPrimitiveType(): Unit = {
+      checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"),
+        Seq(
+          Row("acab"),
+          Row(null),
+          Row("c"),
+          Row(null)))
+      checkAnswer(
+        df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"),
+        Seq(
+          Row("acab"),
+          Row(""),
+          Row("c"),
+          Row(null)))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testNonPrimitiveType()
+    // Test with cached relation, the Project will be evaluated with codegen
+    df.cache()
+    testNonPrimitiveType()
+  }
+
+  test("aggregate function - invalid") {
+    val df = Seq(
+      (Seq("c", "a", "b"), 1),
+      (Seq("b", null, "c", null), 2),
+      (Seq.empty, 3),
+      (null, 4)
+    ).toDF("s", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, '', x -> x)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)")
+    }
+    assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(i, 0, (acc, x) -> x)")
+    }
+    assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
+    }
+    assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
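
Each new test above runs its assertions twice: once against the local relation, where the Project is evaluated without codegen, and once after `df.cache()`, so the Project is planned through the codegen path. A hedged sketch of that run-twice pattern (the helper name is hypothetical):

```scala
import org.apache.spark.sql.DataFrame

// Run the same checks through both evaluation paths: interpreted first,
// then cached so the projection goes through codegen.
def checkBothPaths(df: DataFrame)(check: => Unit): Unit = {
  check      // local relation: evaluated without codegen
  df.cache() // cached relation: evaluated with codegen
  check
}
```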

