Repository: spark Updated Branches: refs/heads/master 9b670bcae -> af9789a4f
[SPARK-18632][SQL] AggregateFunction should not implement ImplicitCastInputTypes ## What changes were proposed in this pull request? `AggregateFunction` currently implements `ImplicitCastInputTypes` (which enables implicit input type casting). There are actually quite a few situations in which we don't need this, or require more control over our input. A recent example is the aggregate for `CountMinSketch` which should only take string, binary, or integral type inputs. This PR removes `ImplicitCastInputTypes` from the `AggregateFunction` and makes a case-by-case decision on what kind of input validation we should use. ## How was this patch tested? Refactoring only. Existing tests. Author: Herman van Hovell <hvanhov...@databricks.com> Closes #16066 from hvanhovell/SPARK-18632. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af9789a4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af9789a4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af9789a4 Branch: refs/heads/master Commit: af9789a4f5d00b3141f102e9f0ca52217e26c082 Parents: 9b670bc Author: Herman van Hovell <hvanhov...@databricks.com> Authored: Tue Nov 29 20:05:15 2016 -0800 Committer: Reynold Xin <r...@databricks.com> Committed: Tue Nov 29 20:05:15 2016 -0800 ---------------------------------------------------------------------- .../aggregate/ApproximatePercentile.scala | 7 +++--- .../expressions/aggregate/Average.scala | 2 +- .../aggregate/CentralMomentAgg.scala | 3 ++- .../catalyst/expressions/aggregate/Corr.scala | 3 ++- .../catalyst/expressions/aggregate/Count.scala | 3 --- .../aggregate/CountMinSketchAgg.scala | 5 ++-- .../expressions/aggregate/Covariance.scala | 3 ++- .../catalyst/expressions/aggregate/First.scala | 26 ++++++++++++++------ .../aggregate/HyperLogLogPlusPlus.scala | 2 -- .../catalyst/expressions/aggregate/Last.scala | 26 ++++++++++++++------ .../catalyst/expressions/aggregate/Max.scala | 3 --- 
.../catalyst/expressions/aggregate/Min.scala | 3 --- .../expressions/aggregate/Percentile.scala | 9 ++++--- .../expressions/aggregate/PivotFirst.scala | 2 -- .../catalyst/expressions/aggregate/Sum.scala | 2 +- .../expressions/aggregate/collect.scala | 2 -- .../expressions/aggregate/interfaces.scala | 2 +- .../expressions/windowExpressions.scala | 2 -- .../aggregate/TypedAggregateExpression.scala | 2 -- .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../spark/sql/CountMinSketchAggQuerySuite.scala | 8 +++--- .../sql/TypedImperativeAggregateSuite.scala | 7 +++--- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 --- .../sql/hive/execution/TestingTypedCount.scala | 2 -- 24 files changed, 67 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 692cbd7..c2cd895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -22,11 +22,11 @@ import java.nio.ByteBuffer import com.google.common.primitives.{Doubles, Ints, Longs} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ 
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} @@ -71,7 +71,8 @@ case class ApproximatePercentile( percentageExpression: Expression, accuracyExpression: Expression, override val mutableAggBufferOffset: Int, - override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] { + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes { def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { this(child, percentageExpression, accuracyExpression, 0, 0) http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index d523420..c423e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ @ExpressionDescription( usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.") -case class Average(child: Expression) extends DeclarativeAggregate { +case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def prettyName: String = "avg" 
http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 1a93f45..572d29c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -42,7 +42,8 @@ import org.apache.spark.sql.types._ * * @param child to compute central moments of. */ -abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate { +abstract class CentralMomentAgg(child: Expression) + extends DeclarativeAggregate with ImplicitCastInputTypes { /** * The central moment order to be computed. 
http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 657f519..95a4a0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.types._ @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.") // scalastyle:on line.size.limit -case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { +case class Corr(x: Expression, y: Expression) + extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index bcae0dc..1990f2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -38,9 +38,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = LongType - // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index 1bfae9e..f5f185f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -22,7 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch @@ -52,7 +52,8 @@ case class CountMinSketchAgg( confidenceExpression: Expression, seedExpression: Expression, override val mutableAggBufferOffset: Int, - override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] { + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes { def this( child: Expression, 
http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index ae5ed77..fc6c34b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.types._ * Compute the covariance between two expressions. * When applied on empty data (i.e., count is zero), it returns NULL. */ -abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate { +abstract class Covariance(x: Expression, y: Expression) + extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(x, y) override def nullable: Boolean = true http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 29b8947..bfc58c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -33,16 +34,11 @@ import org.apache.spark.sql.types._ _FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows. If `isIgnoreNull` is true, returns only non-null values. """) -case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { +case class First(child: Expression, ignoreNullsExpr: Expression) + extends DeclarativeAggregate with ExpectsInputTypes { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -56,6 +52,20 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!ignoreNullsExpr.foldable) { + TypeCheckFailure( + s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}") + } else { + TypeCheckSuccess + } + } + + private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean] + private lazy val first = AttributeReference("first", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 77b7eb2..d5c9166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -140,8 +140,6 @@ case class HyperLogLogPlusPlus( override def dataType: DataType = LongType - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) /** Allocate enough words to store all registers. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index b0a363e..96a6ec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -33,16 +34,11 @@ import org.apache.spark.sql.types._ _FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows. If `isIgnoreNull` is true, returns only non-null values. """) -case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { +case class Last(child: Expression, ignoreNullsExpr: Expression) + extends DeclarativeAggregate with ExpectsInputTypes { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -56,6 +52,20 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!ignoreNullsExpr.foldable) { + TypeCheckFailure( + s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}") + } else { + TypeCheckSuccess + } + } + + private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean] + private lazy val last = AttributeReference("last", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index f32c9c6..58fd1d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -33,9 +33,6 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = child.dataType - // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function max") http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 9ef42b9..b2724ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -33,9 +33,6 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = child.dataType - // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForOrderingExpr(child.dataType, "function min") http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 356e088..b51b553 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -54,10 +54,11 @@ import org.apache.spark.util.collection.OpenHashMap be between 0.0 and 1.0. 
""") case class Percentile( - child: Expression, - percentageExpression: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] { + child: Expression, + percentageExpression: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes { def this(child: Expression, percentageExpression: Expression) = { this(child, percentageExpression, 0, 0) http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 0876060..9ad3124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -77,8 +77,6 @@ case class PivotFirst( override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil - override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType) - override val nullable: Boolean = false val valueDataType = valueColumn.dataType http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f3731d4..96e8cee 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ @ExpressionDescription( usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.") -case class Sum(child: Expression) extends DeclarativeAggregate { +case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { override def children: Seq[Expression] = child :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index d2880d5..b176e2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -44,8 +44,6 @@ abstract class Collect extends ImperativeAggregate { override def dataType: DataType = ArrayType(child.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def supportsPartial: Boolean = false override def aggBufferAttributes: Seq[AttributeReference] = Nil http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index f3fd58b..7397b60 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -155,7 +155,7 @@ case class AggregateExpression( * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { +sealed abstract class AggregateFunction extends Expression { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 3cbbcdf..c0d6a6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -443,7 +443,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF abstract class RowNumberLike extends AggregateWindowFunction { override def children: Seq[Expression] = Nil - override def inputTypes: Seq[AbstractDataType] = Nil protected val zero = Literal(0) protected val one = Literal(1) protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() @@ -600,7 +599,6 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow * This documentation has been based upon similar documentation for the Hive and Presto projects. 
*/ abstract class RankLike extends AggregateWindowFunction { - override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) /** Store the values of the window 'order' expressions. */ protected val orderAttrs = children.map { expr => http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 6f7f2f8..9911c0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -81,8 +81,6 @@ case class TypedAggregateExpression( override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) - override def inputTypes: Seq[AbstractDataType] = Nil - private def aggregatorLiteral = Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]])) http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 67760f3..ae5e2c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -324,7 +324,7 @@ case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with NonSQLExpression with Logging { + 
extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala index 4cc5060..3e715a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala @@ -110,9 +110,11 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext { withTempView(table) { val rdd: RDD[Row] = spark.sparkContext.parallelize(data) spark.createDataFrame(rdd, schema).createOrReplaceTempView(table) - val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)") - .mkString(", ") - val result = sql(s"SELECT $cmsSql FROM $table").head() + + val cmsSql = schema.fieldNames.map { col => + s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)" + } + val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head() schema.indices.foreach { i => val binaryData = result.getAs[Array[Byte]](i) val in = new ByteArrayInputStream(binaryData) http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 0759915..70c3951 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -21,13 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType} +import org.apache.spark.sql.types._ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { @@ -231,7 +231,8 @@ object TypedImperativeAggregateSuite { child: Expression, nullable: Boolean = false, mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] { + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes { override def createAggregationBuffer(): MaxValue = { http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 90e8695..349faae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -378,10 +378,6 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe - // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our - // catalyst type checking framework. - override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType) - override def nullable: Boolean = true override def supportsPartial: Boolean = true http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala index a3d48d9..d27287b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala @@ -71,8 +71,6 @@ case class TestingTypedCount( TestingTypedCount.State(dataStream.readLong()) } - override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org