Repository: spark
Updated Branches:
refs/heads/master 049e639fc -> 9bd80ad6b
[SPARK-15776][SQL] Divide Expression inside Aggregation function is casted to
wrong type
## What changes were proposed in this pull request?
This PR fixes the problem that a Divide expression inside an aggregation function is
cast to the wrong type, which causes `select 1/2` and `select sum(1/2)` to return
different results.
**Before the change:**
```
scala> sql("select 1/2 as a").show()
+---+
| a|
+---+
|0.5|
+---+
scala> sql("select sum(1/2) as a").show()
+---+
| a|
+---+
|0 |
+---+
scala> sql("select sum(1 / 2) as a").schema
res4: org.apache.spark.sql.types.StructType =
StructType(StructField(a,LongType,true))
```
**After the change:**
```
scala> sql("select 1/2 as a").show()
+---+
| a|
+---+
|0.5|
+---+
scala> sql("select sum(1/2) as a").show()
+---+
| a|
+---+
|0.5|
+---+
scala> sql("select sum(1/2) as a").schema
res4: org.apache.spark.sql.types.StructType =
StructType(StructField(a,DoubleType,true))
```
## How was this patch tested?
Unit test.
This PR is based on https://github.com/apache/spark/pull/13524 by Sephiroth-Lin
Author: Sean Zhong <[email protected]>
Closes #13651 from clockfly/SPARK-15776.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9bd80ad6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9bd80ad6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9bd80ad6
Branch: refs/heads/master
Commit: 9bd80ad6bd43462d16ce24cda77cdfaa336c4e02
Parents: 049e639
Author: Sean Zhong <[email protected]>
Authored: Wed Jun 15 14:34:15 2016 -0700
Committer: Wenchen Fan <[email protected]>
Committed: Wed Jun 15 14:34:15 2016 -0700
----------------------------------------------------------------------
.../sql/catalyst/analysis/TypeCoercion.scala | 8 +++--
.../sql/catalyst/expressions/arithmetic.scala | 3 +-
.../sql/catalyst/analysis/AnalysisSuite.scala | 32 +++++++++++++++++
.../analysis/ExpressionTypeCheckingSuite.scala | 2 +-
.../catalyst/analysis/TypeCoercionSuite.scala | 37 ++++++++++++++++++--
.../expressions/ArithmeticExpressionSuite.scala | 19 +++++-----
.../plans/ConstraintPropagationSuite.scala | 4 +--
7 files changed, 86 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index a5b5b91..16df628 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -525,14 +525,16 @@ object TypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who has not been resolved yet,
// as this is an extra rule which should be applied at last.
- case e if !e.resolved => e
+ case e if !e.childrenResolved => e
// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
-
- case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right,
DoubleType))
+ case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
+ Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
+
+ private def isNumeric(ex: Expression): Boolean =
ex.dataType.isInstanceOf[NumericType]
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index b2df79a..4db1352 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression)
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
- override def inputType: AbstractDataType = NumericType
+ override def inputType: AbstractDataType = TypeCollection(DoubleType,
DecimalType)
override def symbol: String = "/"
override def decimalMethod: String = "$div"
@@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression)
private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
- case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
}
override def eval(input: InternalRow): Any = {
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 77ea29e..102c78b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest {
assertAnalysisSuccess(query)
}
+
+ private def assertExpressionType(
+ expression: Expression,
+ expectedDataType: DataType): Unit = {
+ val afterAnalyze =
+ Project(Seq(Alias(expression, "a")()),
OneRowRelation).analyze.expressions.head
+ if (!afterAnalyze.dataType.equals(expectedDataType)) {
+ fail(
+ s"""
+ |data type of expression $expression doesn't match expected:
+ |Actual data type:
+ |${afterAnalyze.dataType}
+ |
+ |Expected data type:
+ |${expectedDataType}
+ """.stripMargin)
+ }
+ }
+
+ test("SPARK-15776: test whether Divide expression's data type can be deduced
correctly by " +
+ "analyzer") {
+ assertExpressionType(sum(Divide(1, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
+ assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
+ assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
+ assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 660dc86..54436ea 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Subtract('booleanField, 'booleanField),
"requires (numeric or calendarinterval) type")
assertError(Multiply('booleanField, 'booleanField), "requires numeric
type")
- assertError(Divide('booleanField, 'booleanField), "requires numeric type")
+ assertError(Divide('booleanField, 'booleanField), "requires (double or
decimal) type")
assertError(Remainder('booleanField, 'booleanField), "requires numeric
type")
assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral
type")
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 7435399..971c99b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis
import java.sql.Timestamp
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division,
FunctionArgumentConversion}
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -199,9 +201,20 @@ class TypeCoercionSuite extends PlanTest {
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression,
transformed: Expression) {
+ ruleTest(Seq(rule), initial, transformed)
+ }
+
+ private def ruleTest(
+ rules: Seq[Rule[LogicalPlan]],
+ initial: Expression,
+ transformed: Expression): Unit = {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ val analyzer = new RuleExecutor[LogicalPlan] {
+ override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
+ }
+
comparePlans(
- rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+ analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
@@ -630,6 +643,26 @@ class TypeCoercionSuite extends PlanTest {
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
}
+
+ test("SPARK-15776 Divide expression's dataType should be casted to Double or
Decimal " +
+ "in aggregation function like sum") {
+ val rules = Seq(FunctionArgumentConversion, Division)
+ // Casts Integer to Double
+ ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3,
DoubleType))))
+ // Left expression is Double, right expression is Int. Another rule
ImplicitTypeCasts will
+ // cast the right expression to Double.
+ ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3)))
+ // Left expression is Int, right expression is Double
+ ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType),
Cast(3.0, DoubleType))))
+ // Casts Float to Double
+ ruleTest(
+ rules,
+ sum(Divide(4.0f, 3)),
+ sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType))))
+ // Left expression is Decimal, right expression is Int. Another rule
DecimalPrecision will cast
+ // the right expression to Decimal.
+ ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 72285c6..2e37887 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper
}
}
+ private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit =
{
+ testFunc(_.toDouble)
+ testFunc(Decimal(_))
+ }
+
test("/ (Divide) basic") {
- testNumericDataTypes { convert =>
+ testDecimalAndDoubleType { convert =>
val left = Literal(convert(2))
val right = Literal(convert(1))
val dataType = left.dataType
@@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite
with ExpressionEvalHelper
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by
zero
}
- DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
+ Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe)
}
}
- test("/ (Divide) for integral type") {
+ // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of
DecimalType.
+ // TODO: in future release, we should add a IntegerDivide to support
integral types.
+ ignore("/ (Divide) for integral type") {
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
@@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper
checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
}
- test("/ (Divide) for floating point") {
- checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
- checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
- checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))),
Decimal(0.5))
- }
-
test("% (Remainder)") {
testNumericDataTypes { convert =>
val left = Literal(convert(1))
http://git-wip-us.apache.org/repos/asf/spark/blob/9bd80ad6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 81cc6b1..0b73b5e 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") +
Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
- Cast(Cast(10, LongType), DoubleType) ===
+ Cast(10, DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
@@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") -
Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
- Cast(Cast(10, LongType), DoubleType) <
+ Cast(10, DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]