Repository: spark Updated Branches: refs/heads/master e3d434994 -> 94d671448
[SPARK-23907][SQL] Add regr_* functions ## What changes were proposed in this pull request? The PR introduces regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy, regr_avgx, regr_avgy, regr_count. The implementation of these functions mirrors Hive's in HIVE-15978. ## How was this patch tested? added UT (values compared with Hive) Author: Marco Gaido <[email protected]> Closes #21054 from mgaido91/SPARK-23907. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94d67144 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94d67144 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94d67144 Branch: refs/heads/master Commit: 94d671448240c8f6da11d2523ba9e4ae5b56a410 Parents: e3d4349 Author: Marco Gaido <[email protected]> Authored: Thu May 10 20:38:52 2018 +0900 Committer: Takuya UESHIN <[email protected]> Committed: Thu May 10 20:38:52 2018 +0900 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 9 + .../expressions/aggregate/Average.scala | 47 +++-- .../aggregate/CentralMomentAgg.scala | 60 +++--- .../catalyst/expressions/aggregate/Corr.scala | 52 ++--- .../catalyst/expressions/aggregate/Count.scala | 47 +++-- .../expressions/aggregate/Covariance.scala | 36 ++-- .../expressions/aggregate/regression.scala | 190 +++++++++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 172 +++++++++++++++++ .../sql-tests/inputs/udaf-regrfunctions.sql | 56 ++++++ .../results/udaf-regrfunctions.sql.out | 93 +++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 71 ++++++- 11 files changed, 721 insertions(+), 112 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 87b0911..087d000 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,15 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[RegrCount]("regr_count"), + expression[RegrSXX]("regr_sxx"), + expression[RegrSYY]("regr_syy"), + expression[RegrAvgX]("regr_avgx"), + expression[RegrAvgY]("regr_avgy"), + expression[RegrSXY]("regr_sxy"), + expression[RegrSlope]("regr_slope"), + expression[RegrR2]("regr_r2"), + expression[RegrIntercept]("regr_intercept"), // string functions expression[Ascii]("ascii"), http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbf..a133bc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { - - 
override def prettyName: String = "avg" - - override def children: Seq[Expression] = child :: Nil +abstract class AverageLike(child: Expression) extends DeclarativeAggregate { override def nullable: Boolean = true - // Return data type. override def dataType: DataType = resultType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") - private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) @@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* sum = */ - Add( - sum, - Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* count = */ If(IsNull(child), count, count + 1L) - ) - override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right @@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case _ => Cast(sum, resultType) / Cast(count, resultType) } + + protected def updateExpressionsDef: Seq[Expression] = Seq( + /* sum = */ + Add( + sum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* count = */ If(IsNull(child), count, count + 1L) + ) + + override lazy val updateExpressions = updateExpressionsDef +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") +case class Average(child: Expression) + extends AverageLike(child) with ImplicitCastInputTypes { + + override def prettyName: String = "avg" + + override def children: Seq[Expression] = child :: Nil + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function 
average") } http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 572d29c..6bbb083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression) override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val delta = child - avg - val deltaN = delta / newN - val newAvg = avg + deltaN - val newM2 = m2 + delta * (delta - deltaN) - - val delta2 = delta * delta - val deltaN2 = deltaN * deltaN - val newM3 = if (momentOrder >= 3) { - m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) - } else { - Literal(0.0) - } - val newM4 = if (momentOrder >= 4) { - m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + - delta * (delta * delta2 - deltaN * deltaN2) - } else { - Literal(0.0) - } - - trimHigherOrder(Seq( - If(IsNull(child), n, newN), - If(IsNull(child), avg, newAvg), - If(IsNull(child), m2, newM2), - If(IsNull(child), m3, newM3), - If(IsNull(child), m4, newM4) - )) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression) trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4)) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + 
Literal(1.0) + val delta = child - avg + val deltaN = delta / newN + val newAvg = avg + deltaN + val newM2 = m2 + delta * (delta - deltaN) + + val delta2 = delta * delta + val deltaN2 = deltaN * deltaN + val newM3 = if (momentOrder >= 3) { + m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2) + } else { + Literal(0.0) + } + val newM4 = if (momentOrder >= 4) { + m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 + + delta * (delta * delta2 - deltaN * deltaN2) + } else { + Literal(0.0) + } + + trimHigherOrder(Seq( + If(IsNull(child), n, newN), + If(IsNull(child), avg, newAvg), + If(IsNull(child), m2, newM2), + If(IsNull(child), m3, newM3), + If(IsNull(child), m4, newM4) + )) + } } // Compute the population standard deviation of a column http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 95a4a0d..3cdef72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ /** - * Compute Pearson correlation between two expressions. + * Base class for computing Pearson correlation between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. 
* * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") -// scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) +abstract class PearsonCorrelation(x: Expression, y: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) @@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0)) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef + + override val mergeExpressions: Seq[Expression] = { + val n1 = n.left + val n2 = n.right + val newN = n1 + n2 + val dx = xAvg.right - xAvg.left + val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) + val dy = yAvg.right - yAvg.left + val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) + val newXAvg = xAvg.left + dxN * n2 + val newYAvg = yAvg.left + dyN * n2 + val newCk = ck.left + ck.right + dx * dyN * n1 * n2 + val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 + val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 + + Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) + } + + protected def updateExpressionsDef: Seq[Expression] = { val newN = n + Literal(1.0) val dx = x - xAvg val dxN = dx / newN @@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression) If(isNull, yMk, newYMk) ) } +} - override val mergeExpressions: Seq[Expression] = { - - val n1 = n.left - val n2 = n.right - val newN = n1 + n2 - val dx = xAvg.right - xAvg.left - val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN) - val dy = yAvg.right - yAvg.left - val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN) - val newXAvg = xAvg.left + 
dxN * n2 - val newYAvg = yAvg.left + dyN * n2 - val newCk = ck.left + ck.right + dx * dyN * n1 * n2 - val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2 - val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2 - Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk) - } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") +// scalastyle:on line.size.limit +case class Corr(x: Expression, y: Expression) + extends PearsonCorrelation(x, y) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 1990f2f..40582d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = """ - _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. - - _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. - - _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. 
- """) -// scalastyle:on line.size.limit -case class Count(children: Seq[Expression]) extends DeclarativeAggregate { - +/** + * Base class for all counting aggregators. + */ +abstract class CountLike extends DeclarativeAggregate { override def nullable: Boolean = false // Return data type. override def dataType: DataType = LongType - private lazy val count = AttributeReference("count", LongType, nullable = false)() + protected lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) + override lazy val mergeExpressions = Seq( + /* count = */ count.left + count.right + ) + + override lazy val evaluateExpression = count + + override def defaultResult: Option[Literal] = Option(Literal(0L)) +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null. + + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null. + + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null. 
+ """) +// scalastyle:on line.size.limit +case class Count(children: Seq[Expression]) extends CountLike { + override lazy val updateExpressions = { val nullableChildren = children.filter(_.nullable) if (nullableChildren.isEmpty) { @@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { ) } } - - override lazy val mergeExpressions = Seq( - /* count = */ count.left + count.right - ) - - override lazy val evaluateExpression = count - - override def defaultResult: Option[Literal] = Option(Literal(0L)) } object Count { http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index fc6c34b..72a7c62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression) override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0)) - override lazy val updateExpressions: Seq[Expression] = { - val newN = n + Literal(1.0) - val dx = x - xAvg - val dy = y - yAvg - val dyN = dy / newN - val newXAvg = xAvg + dx / newN - val newYAvg = yAvg + dyN - val newCk = ck + dx * (y - newYAvg) - - val isNull = IsNull(x) || IsNull(y) - Seq( - If(isNull, n, newN), - If(isNull, xAvg, newXAvg), - If(isNull, yAvg, newYAvg), - If(isNull, ck, newCk) - ) - } + override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef override val mergeExpressions: Seq[Expression] = { @@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression) 
Seq(newN, newXAvg, newYAvg, newCk) } + + protected def updateExpressionsDef: Seq[Expression] = { + val newN = n + Literal(1.0) + val dx = x - xAvg + val dy = y - yAvg + val dyN = dy / newN + val newXAvg = xAvg + dx / newN + val newYAvg = yAvg + dyN + val newCk = ck + dx * (y - newYAvg) + + val isNull = IsNull(x) || IsNull(y) + Seq( + If(isNull, n, newN), + If(isNull, xAvg, newXAvg), + If(isNull, yAvg, newYAvg), + If(isNull, ck, newCk) + ) + } } @ExpressionDescription( http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala new file mode 100644 index 0000000..d8f4505 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{AbstractDataType, DoubleType} + +/** + * Base trait for all regression functions. + */ +trait RegrLike extends AggregateFunction with ImplicitCastInputTypes { + def y: Expression + def x: Expression + + override def children: Seq[Expression] = Seq(y, x) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + + protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = { + assert(aggBufferAttributes.length == exprs.length) + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + exprs + } else { + exprs.zip(aggBufferAttributes).map { case (e, a) => + If(nullableChildren.map(IsNull).reduce(Or), a, e) + } + } + } +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the number of non-null pairs.", + since = "2.4.0") +case class RegrCount(y: Expression, x: Expression) + extends CountLike with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L)) + + override def prettyName: String = "regr_count" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSXX(y: Expression, x: Expression) + extends CentralMomentAgg(x) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_sxx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. 
Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrSYY(y: Expression, x: Expression) + extends CentralMomentAgg(y) with RegrLike { + + override protected def momentOrder = 2 + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), m2) + } + + override def prettyName: String = "regr_syy" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgX(y: Expression, x: Expression) + extends AverageLike(x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgx" +} + + +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.", + since = "2.4.0") +case class RegrAvgY(y: Expression, x: Expression) + extends AverageLike(y) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override def prettyName: String = "regr_avgy" +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSXY(y: Expression, x: Expression) + extends Covariance(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0), Literal.create(null, DoubleType), ck) + } + + override def prettyName: String = "regr_sxy" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. 
Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrSlope(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk) + } + + override def prettyName: String = "regr_slope" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrR2(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk)) + } + + override def prettyName: String = "regr_r2" +} + + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. 
Any pair with a NULL is ignored.", + since = "2.4.0") +// scalastyle:on line.size.limit +case class RegrIntercept(y: Expression, x: Expression) + extends PearsonCorrelation(y, x) with RegrLike { + + override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef) + + override val evaluateExpression: Expression = { + If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), + xAvg - (ck / yMk) * yAvg) + } + + override def prettyName: String = "regr_intercept" +} http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8f9e4ae..28cf705 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -775,6 +775,178 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: Column, x: Column): Column = withAggregateFunction { + RegrCount(y.expr, x.expr) + } + + /** + * Aggregate function: returns the number of non-null pairs. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { + RegrSXX(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. 
+ * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: Column, x: Column): Column = withAggregateFunction { + RegrSYY(y.expr, x.expr) + } + + /** + * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of y. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { + RegrAvgX(y.expr, x.expr) + } + + /** + * Aggregate function: returns the average of x. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { + RegrSXY(y.expr, x.expr) + } + + /** + * Aggregate function: returns the covariance of y and x multiplied for the number of items in + * the dataset. Any pair with a NULL is ignored. 
+ * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: Column, x: Column): Column = withAggregateFunction { + RegrSlope(y.expr, x.expr) + } + + /** + * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is + * ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: Column, x: Column): Column = withAggregateFunction { + RegrR2(y.expr, x.expr) + } + + /** + * Aggregate function: returns the coefficient of determination (also called R-squared or + * goodness of fit) for the regression line. Any pair with a NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. + * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { + RegrIntercept(y.expr, x.expr) + } + + /** + * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a + * NULL is ignored. 
+ * + * @group agg_funcs + * @since 2.4.0 + */ + def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) + + + ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql new file mode 100644 index 0000000..92c7e26 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -0,0 +1,56 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- + +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x); + +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px; + + +select id, regr_count(y,x) over (partition by px) from t1 order by id; http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out new file mode 100644 index 0000000..d7d009a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 3 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (101, 1, 1, 1), + (201, 2, 1, 1), + (301, 3, 1, 1), + (401, 4, 1, 11), + (501, 5, 1, null), + (601, 6, null, 1), + (701, 6, null, null), + (102, 1, 2, 2), + (202, 2, 1, 2), + (302, 3, 2, 1), + (402, 4, 2, 12), + (502, 5, 2, null), + (602, 6, null, 2), + (702, 6, null, null), + (103, 1, 3, 3), + (203, 2, 1, 
3), + (303, 3, 3, 1), + (403, 4, 3, 13), + (503, 5, 3, null), + (603, 6, null, 3), + (703, 6, null, null), + (104, 1, 4, 4), + (204, 2, 1, 4), + (304, 3, 4, 1), + (404, 4, 4, 14), + (504, 5, 4, null), + (604, 6, null, 4), + (704, 6, null, null), + (800, 7, 1, 1) +as t1(id, px, y, x) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), + regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), + regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) +from t1 group by px order by px +-- !query 1 schema +struct<px:int,var_pop(CAST(x AS DOUBLE)):double,var_pop(CAST(y AS DOUBLE)):double,corr(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,covar_samp(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,covar_pop(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):bigint,regr_slope(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_intercept(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_r2(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_sxx(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_syy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_sxy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_avgx(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_avgy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):bigint> +-- !query 1 output +1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 +2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 +3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 +4 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 -10.0 1.0 5.0 5.0 5.0 12.5 2.5 4 +5 NULL 1.25 NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +6 1.25 NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL 0 +7 0.0 0.0 NaN NaN 0.0 1 NULL NULL NULL 0.0 0.0 0.0 1.0 1.0 1 + + +-- !query 2 +select id, 
regr_count(y,x) over (partition by px) from t1 order by id +-- !query 2 schema +struct<id:int,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)) OVER (PARTITION BY px ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint> +-- !query 2 output +101 4 +102 4 +103 4 +104 4 +201 4 +202 4 +203 4 +204 4 +301 4 +302 4 +303 4 +304 4 +401 4 +402 4 +403 4 +404 4 +501 0 +502 0 +503 0 +504 0 +601 0 +602 0 +603 0 +604 0 +701 0 +702 0 +703 0 +704 0 +800 1 http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7776e3..4337fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ + val absTol = 1e-8 + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("moments") { - val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol) @@ -686,4 +687,72 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-23907: regression functions") { + val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") + val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) + .toDF("a", "b") + val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( + (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", 
"b") + checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) + checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) + checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) + checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) + 
checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) + + checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) + checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) + checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), + Row(null), absTol) + + + checkAggregatesWithTol(correlatedData.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), + absTol) + checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( + regr_count("a", "b"), + regr_avgx("a", "b"), + regr_avgy("a", "b"), + regr_sxx("a", "b"), + regr_syy("a", "b"), + regr_sxy("a", "b"), + regr_slope("a", "b"), + regr_r2("a", "b"), + regr_intercept("a", "b")), + Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), + absTol) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
