Repository: spark Updated Branches: refs/heads/master 5f9633dc9 -> 327bb3007
[SPARK-23911][SQL] Add aggregate function. ## What changes were proposed in this pull request? This PR adds the `aggregate` function, which applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function. ```sql > SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x); 6 > SELECT aggregate(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10); 60 ``` ## How was this patch tested? Added tests. Author: Takuya UESHIN <ues...@databricks.com> Closes #21982 from ueshin/issues/SPARK-23911/aggregate. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/327bb300 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/327bb300 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/327bb300 Branch: refs/heads/master Commit: 327bb30075834c873cdb78061c9b647e5e13b8a6 Parents: 5f9633d Author: Takuya UESHIN <ues...@databricks.com> Authored: Sun Aug 5 08:58:35 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Sun Aug 5 08:58:35 2018 +0900 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 95 +++++++++++++++ .../expressions/HigherOrderFunctionsSuite.scala | 50 ++++++++ .../sql-tests/inputs/higher-order-functions.sql | 12 ++ .../results/higher-order-functions.sql.out | 40 +++++- .../spark/sql/DataFrameFunctionsSuite.scala | 121 +++++++++++++++++++ 6 files changed, 318 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d0efe97..35f8de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -442,6 +442,7 @@ object FunctionRegistry { expression[ArrayDistinct]("array_distinct"), expression[ArrayTransform]("transform"), expression[ArrayFilter]("filter"), + expression[ArrayAggregate]("aggregate"), CreateStruct.registryEntry, // misc functions http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index e15225f..20c7f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} @@ -76,6 +77,13 @@ case class LambdaFunction( override def eval(input: InternalRow): Any = function.eval(input) } +object LambdaFunction { + val identity: LambdaFunction = { + val id = UnresolvedAttribute.quoted("id") + 
LambdaFunction(id, Seq(id)) + } +} + /** * A higher order function takes one or more (lambda) functions and applies these to some objects. * The function produces a number of variables which can be consumed by some lambda function. @@ -270,3 +278,90 @@ case class ArrayFilter( override def prettyName: String = "filter" } + +/** + * Applies a binary operator to a start value and all elements in the array. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(expr, start, merge, finish) - Applies a binary operator to an initial state and all + elements in the array, and reduces this to a single state. The final state is converted + into the final result by applying a finish function. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x); + 6 + > SELECT _FUNC_(array(1, 2, 3), 0, (acc, x) -> acc + x, acc -> acc * 10); + 60 + """, + since = "2.4.0") +case class ArrayAggregate( + input: Expression, + zero: Expression, + merge: Expression, + finish: Expression) + extends HigherOrderFunction with CodegenFallback { + + def this(input: Expression, zero: Expression, merge: Expression) = { + this(input, zero, merge, LambdaFunction.identity) + } + + override def inputs: Seq[Expression] = input :: zero :: Nil + + override def functions: Seq[Expression] = merge :: finish :: Nil + + override def nullable: Boolean = input.nullable || finish.nullable + + override def dataType: DataType = finish.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (!ArrayType.acceptsType(input.dataType)) { + TypeCheckResult.TypeCheckFailure( + s"argument 1 requires ${ArrayType.simpleString} type, " + + s"however, '${input.sql}' is of ${input.dataType.catalogString} type.") + } else if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") 
+ } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { + // Be very conservative with nullable. We cannot be sure that the accumulator does not + // evaluate to null. So we always set nullable to true here. + val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType) + val acc = zero.dataType -> true + val newMerge = f(merge, acc :: elem :: Nil) + val newFinish = f(finish, acc :: Nil) + copy(merge = newMerge, finish = newFinish) + } + + @transient lazy val LambdaFunction(_, + Seq(accForMergeVar: NamedLambdaVariable, elementVar: NamedLambdaVariable), _) = merge + @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish + + override def eval(input: InternalRow): Any = { + val arr = this.input.eval(input).asInstanceOf[ArrayData] + if (arr == null) { + null + } else { + val Seq(mergeForEval, finishForEval) = functionsForEval + accForMergeVar.value.set(zero.eval(input)) + var i = 0 + while (i < arr.numElements()) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + accForMergeVar.value.set(mergeForEval.eval(input)) + i += 1 + } + accForFinishVar.value.set(accForMergeVar.value.get) + finishForEval.eval(input) + } + } + + override def prettyName: String = "aggregate" +} http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index d1330c7..40cfc0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -59,6 +59,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) } + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression, + finish: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + val zeroType = zero.dataType + ArrayAggregate( + expr, + zero, + createLambda(zeroType, true, at.elementType, at.containsNull, merge), + createLambda(zeroType, true, finish)) + } + + def aggregate( + expr: Expression, + zero: Expression, + merge: (Expression, Expression) => Expression): Expression = { + aggregate(expr, zero, merge, identity) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -131,4 +152,33 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)), Seq(Seq(1, 3), null, Seq(5))) } + + test("ArrayAggregate") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai2 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + checkEvaluation(aggregate(ai0, 0, (acc, elem) => acc + elem, acc => acc * 10), 60) + checkEvaluation(aggregate(ai1, 0, (acc, elem) => acc + coalesce(elem, 0), acc => acc * 10), 40) + checkEvaluation(aggregate(ai2, 0, (acc, elem) => acc + elem, acc => acc * 10), 0) + checkEvaluation(aggregate(ain, 0, (acc, elem) => acc + elem, acc => acc * 10), null) + + val as0 = 
Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq.empty[String], ArrayType(StringType, containsNull = false)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + checkEvaluation(aggregate(as0, "", (acc, elem) => Concat(Seq(acc, elem))), "abc") + checkEvaluation(aggregate(as1, "", (acc, elem) => Concat(Seq(acc, coalesce(elem, "x")))), "axc") + checkEvaluation(aggregate(as2, "", (acc, elem) => Concat(Seq(acc, elem))), "") + checkEvaluation(aggregate(asn, "", (acc, elem) => Concat(Seq(acc, elem))), null) + + val aai = Literal.create(Seq[Seq[Integer]](Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + aggregate(aai, 0, + (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), + 15) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index f833aa5..136396d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -33,3 +33,15 @@ select filter(cast(null as array<int>), y -> true) as v; -- Filter nested arrays select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested; + +-- Aggregate. +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested; + +-- Aggregate average. 
+select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested; + +-- Aggregate nested arrays +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested; + +-- Aggregate a null array +select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v; http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 4c5d972..e6f62f2 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 15 -- !query 0 @@ -107,3 +107,41 @@ struct<v:array<array<int>>> [[96,65],[]] [[99],[123],[]] [[]] + + +-- !query 11 +select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested +-- !query 11 schema +struct<v:int> +-- !query 11 output +131 +15 +5 + + +-- !query 12 +select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested +-- !query 12 schema +struct<v:double> +-- !query 12 output +0.5 +12.0 +64.5 + + +-- !query 13 +select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested +-- !query 13 schema +struct<v:array<int>> +-- !query 13 output +[1010880,8] +[17] +[4752,20664,1] + + +-- !query 14 +select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v +-- !query 14 schema +struct<v:int> +-- !query 14 output +NULL 
http://git-wip-us.apache.org/repos/asf/spark/blob/327bb300/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1d5707a..af3301b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1896,6 +1896,127 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) } + test("aggregate function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> acc * 10)"), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("aggregate function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x)"), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(i, 0, (acc, x) -> acc + x, acc -> coalesce(acc, 0) * 10)"), + Seq( + Row(250), + 
Row(0), + Row(0), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("aggregate function - array for non-primitive type") { + val df = Seq( + (Seq("c", "a", "b"), "a"), + (Seq("b", null, "c", null), "b"), + (Seq.empty, "c"), + (null, "d") + ).toDF("ss", "s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x))"), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.selectExpr("aggregate(ss, s, (acc, x) -> concat(acc, x), acc -> coalesce(acc , ''))"), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("aggregate function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, '', (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("aggregate(i, 0, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("aggregate(s, 0, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) 
+ } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org