Repository: spark Updated Branches: refs/heads/master 7daa70292 -> b71d3254e
[SPARK-8075] [SQL] apply type check interface to more expressions a follow up of https://github.com/apache/spark/pull/6405. Note: It's not a big change, a lot of changing is due to I swap some code in `aggregates.scala` to make aggregate functions right below its corresponding aggregate expressions. Author: Wenchen Fan <cloud0...@outlook.com> Closes #6723 from cloud-fan/type-check and squashes the following commits: 2124301 [Wenchen Fan] fix tests 5a658bb [Wenchen Fan] add tests 287d3bb [Wenchen Fan] apply type check interface to more expressions Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b71d3254 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b71d3254 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b71d3254 Branch: refs/heads/master Commit: b71d3254e50838ccae43bdb0ff186fda25f03152 Parents: 7daa702 Author: Wenchen Fan <cloud0...@outlook.com> Authored: Wed Jun 24 16:26:00 2015 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Wed Jun 24 16:26:00 2015 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../spark/sql/catalyst/expressions/Cast.scala | 11 +- .../sql/catalyst/expressions/Expression.scala | 4 +- .../sql/catalyst/expressions/ExtractValue.scala | 10 +- .../sql/catalyst/expressions/aggregates.scala | 420 ++++++++++--------- .../sql/catalyst/expressions/arithmetic.scala | 2 - .../expressions/complexTypeCreator.scala | 30 +- .../catalyst/expressions/decimalFunctions.scala | 17 +- .../sql/catalyst/expressions/generators.scala | 13 +- .../spark/sql/catalyst/expressions/math.scala | 4 +- .../catalyst/expressions/namedExpressions.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 27 +- .../spark/sql/catalyst/expressions/sets.scala | 10 +- .../catalyst/expressions/stringOperations.scala | 2 - .../expressions/windowExpressions.scala | 3 +- .../spark/sql/catalyst/util/TypeUtils.scala | 9 + .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../analysis/ExpressionTypeCheckingSuite.scala | 163 +++++++ .../ExpressionTypeCheckingSuite.scala | 141 ------- .../apache/spark/sql/execution/pythonUdfs.scala | 2 +- .../hive/execution/HiveTypeCoercionSuite.scala | 6 - 22 files changed, 476 insertions(+), 429 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b06759f..cad2c2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -587,8 +587,8 @@ class Analyzer( failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) => Some(g, names) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) case _ => None } } http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d4ab1fc..4ef7341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -317,6 +317,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) } } @@ -590,11 +591,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !a.resolved => - val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) - CreateArray( - children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) + case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case None => a + } // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. @@ -620,12 +622,11 @@ trait HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) - case None => - sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + case None => c } } } http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d271434..8bd7fc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) + override def checkInputDataTypes(): TypeCheckResult = { + if (resolve(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") + } + } override def foldable: Boolean = child.foldable http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a10a959..f59db3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] { /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. - * Note: it's not valid to call this method until `childrenResolved == true` - * TODO: we should remove the default implementation and implement it for all - * expressions with proper error message. + * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d6c1c2..4d7c95f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -96,6 +96,11 @@ object ExtractValue { } } +/** + * A common interface of all kinds of extract value expressions. + * Note: concrete extract value expressions are created only by `ExtractValue.apply`, + * we don't need to do type check for them. + */ trait ExtractValue extends UnaryExpression { self: Product => } @@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - override lazy val resolved = childrenResolved && - child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] - protected def evalNotNull(value: Any, ordinal: Any) = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives @@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] - protected def evalNotNull(value: Any, ordinal: Any) = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 00d2e49..a9fc54c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MinFunction = new MinFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MaxFunction = new MaxFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } +case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var count: Long = _ + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1L + } + } + + override def eval(input: InternalRow): Any = count +} + case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) @@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate } } +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen.size.toLong +} + case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) @@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctPartitionFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + hyperLogLog.offer(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = hyperLogLog +} + case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { @@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctMergeFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + } + + override def eval(input: InternalRow): Any = hyperLogLog.cardinality() +} + case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -349,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } override def newInstance(): AverageFunction = new AverageFunction(child, this) -} - -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - - override def toString: String = s"SUM($child)" - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) -} - -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } -} -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -551,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - var count: Long = _ + override def nullable: Boolean = true - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType } - override def eval(input: InternalRow): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. + override def toString: String = s"SUM($child)" - private val hyperLogLog = new HyperLogLog(relativeSD) + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(CombineSum(partialSum.toAttribute), dataType), + partialSum :: Nil) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + CombineSum(partialSum.toAttribute), + partialSum :: Nil) } } - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } + override def newInstance(): SumFunction = new SumFunction(child, this) - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -633,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) +} + case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -670,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } +case class SumDistinct(child: Expression) + extends PartialAggregate with trees.UnaryNode[Expression] { + + def this() = this(null) + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -696,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CountDistinctFunction( - @transient expr: Seq[Expression], +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, @transient base: AggregateExpression) extends AggregateFunction { @@ -705,17 +684,39 @@ case class CountDistinctFunction( val seen = new OpenHashSet[Any]() - @transient - val distinctValue = new InterpretedProjection(expr) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) } } - override def eval(input: InternalRow): Any = seen.size.toLong + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) + } + } +} + +case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" + + override def asPartial: SplitEvaluation = { + val partialFirst = Alias(First(child), "PartialFirst")() + SplitEvaluation( + First(partialFirst.toAttribute), + partialFirst :: Nil) + } + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -732,6 +733,21 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance(): LastFunction = new LastFunction(child, this) +} + case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ace8427..3d4d9e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types._ abstract class UnaryArithmetic extends UnaryExpression { self: Product => - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = { http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e0bf07e..5def57b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - /** * Returns an Array containing the evaluation of all children expressions. */ @@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") ArrayType( - childTypes.headOption.getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) } @@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) } + } StructType(fields) } http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2bc893a..f5c2dde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ -/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class UnscaledValue(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"UnscaledValue($child)" override def eval(input: InternalRow): Any = { @@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } -/** Create a Decimal from an unscaled Long value */ +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: InternalRow): Decimal = { http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f30cb42..356560e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -100,9 +100,14 @@ case class UserDefinedGenerator( case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { - override lazy val resolved = - child.resolved && - (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"input to function explode should be array or map type, not ${child.dataType}") + } + } override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 250564d..5694afc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -224,7 +222,7 @@ case class Bin(child: Expression) def funcName: String = name.toLowerCase - override def eval(input: catalyst.InternalRow): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9cacdce..6f56a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} @@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) - override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator] override def eval(input: InternalRow): Any = child.eval(input) http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 98acaf2..5d59114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,33 +17,32 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - override def nullable: Boolean = !children.exists(!_.nullable) + override def nullable: Boolean = children.forall(_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) - // Only resolved if all the children are of the same type. - override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + } + } override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = if (resolved) { - children.head.dataType - } else { - val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") - throw new UnresolvedException( - this, s"Coalesce cannot have children of different types. $childTypes") - } + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { - var i = 0 var result: Any = null val childIterator = children.iterator while (childIterator.hasNext && result == null) { @@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { } case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def toString: String = s"IS NOT NULL $child" http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 30e4167..efc6f50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { /** * Adds an item to a set. * For performance, this expression mutates its input during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { @@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = set.dataType override def eval(input: InternalRow): Any = { val itemEval = item.eval(input) @@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { /** * Combines the elements of two sets. * For performance, this expression mutates its left input set during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable || right.nullable - override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres /** * Returns the number of elements in the input set. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CountSet(child: Expression) extends UnaryExpression { http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 315c63e..44416e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes { def convert(v: UTF8String): UTF8String - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = StringType override def expectedChildTypes: Seq[DataType] = Seq(StringType) http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 896e383..12023ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -68,7 +68,8 @@ case class WindowSpecDefinition( override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = - childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] + childrenResolved && checkInputDataTypes().isSuccess && + frameSpecification.isInstanceOf[SpecifiedWindowFrame] override def toString: String = simpleString http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 04857a2..8656cc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -48,6 +48,15 @@ object TypeUtils { } } + def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { + if (types.distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e09cd79..77ca080 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) errorTest( "non-boolean filters", @@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", StringType)(exprId = ExprId(2)))) + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) assert(plan.resolved) http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala new file mode 100644 index 0000000..bc1537b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + } + + test("check types for aggregates") { + // We will cast String to Double for sum and average + assertSuccess(Sum('stringField)) + assertSuccess(SumDistinct('stringField)) + assertSuccess(Average('stringField)) + + assertError(Min('complexField), "function min accepts non-complex type") + assertError(Max('complexField), "function max accepts non-complex type") + assertError(Sum('booleanField), "function sum accepts numeric type") + assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type") + assertError(Average('booleanField), "function average accepts numeric type") + } + + test("check types for others") { + assertError(CreateArray(Seq('intField, 'booleanField)), + "input to function array should all be the same type") + assertError(Coalesce(Seq('intField, 'booleanField)), + "input to function coalesce should all be the same type") + assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Explode('intField), + "input to function explode should be array or map type") + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala deleted file mode 100644 index 49b1119..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType - -class ExpressionTypeCheckingSuite extends SparkFunSuite { - - val testRelation = LocalRelation( - 'intField.int, - 'stringField.string, - 'booleanField.boolean, - 'complexField.array(StringType)) - - def assertError(expr: Expression, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - assertSuccess(expr) - } - assert(e.getMessage.contains( - s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) - assert(e.getMessage.contains(errorMessage)) - } - - def assertSuccess(expr: Expression): Unit = { - val analyzed = testRelation.select(expr.as("c")).analyze - SimpleAnalyzer.checkAnalysis(analyzed) - } - - def assertErrorForDifferingTypes(expr: Expression): Unit = { - assertError(expr, - s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") - } - - test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertError(Abs('stringField), "function abs accepts numeric type") - assertError(BitwiseNot('stringField), "operator ~ accepts integral type") - } - - test("check types for binary arithmetic") { - // We will cast String to Double for binary arithmetic - assertSuccess(Add('intField, 'stringField)) - assertSuccess(Subtract('intField, 'stringField)) - assertSuccess(Multiply('intField, 'stringField)) - assertSuccess(Divide('intField, 'stringField)) - assertSuccess(Remainder('intField, 'stringField)) - // checkAnalysis(BitwiseAnd('intField, 'stringField)) - - assertErrorForDifferingTypes(Add('intField, 'booleanField)) - assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) - assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) - assertErrorForDifferingTypes(Divide('intField, 'booleanField)) - assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) - assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) - assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) - assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) - assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) - assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") - assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") - assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") - - assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") - - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") - } - - test("check types for predicates") { - // We will cast String to Double for binary comparison - assertSuccess(EqualTo('intField, 'stringField)) - assertSuccess(EqualNullSafe('intField, 'stringField)) - assertSuccess(LessThan('intField, 'stringField)) - assertSuccess(LessThanOrEqual('intField, 'stringField)) - assertSuccess(GreaterThan('intField, 'stringField)) - assertSuccess(GreaterThanOrEqual('intField, 'stringField)) - - // We will transform EqualTo with numeric and boolean types to CaseKeyWhen - assertSuccess(EqualTo('intField, 'booleanField)) - assertSuccess(EqualNullSafe('intField, 'booleanField)) - - assertError(EqualTo('intField, 'complexField), "differing types") - assertError(EqualNullSafe('intField, 'complexField), "differing types") - - assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) - assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) - assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) - assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") - - assertError( - If('intField, 'stringField, 'stringField), - "type of predicate expression in If should be boolean") - assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) - - assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), - "THEN and ELSE expressions should all be same type or coercible to a common type") - assertError( - CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), - "THEN and ELSE expressions should all be same type or coercible to a common type") - assertError( - CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), - "WHEN expressions in CaseWhen should all be boolean type") - - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 6db551c..f9c3fe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -55,7 +55,7 @@ private[spark] case class PythonUDF( override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - def nullable: Boolean = true + override def nullable: Boolean = true override def eval(input: InternalRow): Any = { throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") http://git-wip-us.apache.org/repos/asf/spark/blob/b71d3254/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8..197e9bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org