Repository: spark
Updated Branches:
  refs/heads/master 1fdf659d2 -> 27b2821cf


[SPARK-1610] [SQL] Fix Cast to use the exact type value when casting from BooleanType to NumericType.

Every `Cast` from `BooleanType` to a `NumericType` currently produces an `Int` value, regardless of the
target type. This causes a `ClassCastException` when the cast value is consumed by a subsequent
evaluation that expects the target type (e.g. `Short`), as in the code below:

``` scala
scala> import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst._

scala> import types._
import types._

scala> import expressions._
import expressions._

scala> Add(Cast(Literal(true), ShortType), Literal(1.toShort)).eval()
java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Short
        at scala.runtime.BoxesRunTime.unboxToShort(BoxesRunTime.java:102)
        at scala.math.Numeric$ShortIsIntegral$.plus(Numeric.scala:72)
        at org.apache.spark.sql.catalyst.expressions.Add$$anonfun$eval$2.apply(arithmetic.scala:58)
        at org.apache.spark.sql.catalyst.expressions.Add$$anonfun$eval$2.apply(arithmetic.scala:58)
        at org.apache.spark.sql.catalyst.expressions.Expression.n2(Expression.scala:114)
        at org.apache.spark.sql.catalyst.expressions.Add.eval(arithmetic.scala:58)
        at .<init>(<console>:17)
        at .<clinit>(<console>)
        at .<init>(<console>:7)
        at .<clinit>(<console>)
        at $print(<console>)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:483)
        at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734)
        at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983)
        at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573)
        at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604)
        at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:568)
        at scala.tools.nsc.interpreter.ILoop.reallyInterpret$1(ILoop.scala:760)
        at scala.tools.nsc.interpreter.ILoop.interpretStartingWith(ILoop.scala:805)
        at scala.tools.nsc.interpreter.ILoop.command(ILoop.scala:717)
        at scala.tools.nsc.interpreter.ILoop.processLine$1(ILoop.scala:581)
        at scala.tools.nsc.interpreter.ILoop.innerLoop$1(ILoop.scala:588)
        at scala.tools.nsc.interpreter.ILoop.loop(ILoop.scala:591)
        at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply$mcZ$sp(ILoop.scala:882)
        at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:837)
        at scala.tools.nsc.interpreter.ILoop$$anonfun$process$1.apply(ILoop.scala:837)
        at scala.tools.nsc.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:135)
        at scala.tools.nsc.interpreter.ILoop.process(ILoop.scala:837)
        at scala.tools.nsc.MainGenericRunner.runTarget$1(MainGenericRunner.scala:83)
        at scala.tools.nsc.MainGenericRunner.process(MainGenericRunner.scala:96)
        at scala.tools.nsc.MainGenericRunner$.main(MainGenericRunner.scala:105)
        at scala.tools.nsc.MainGenericRunner.main(MainGenericRunner.scala)
```

Author: Takuya UESHIN <[email protected]>

Closes #533 from ueshin/issues/SPARK-1610 and squashes the following commits:

70f36e8 [Takuya UESHIN] Fix Cast to use exact type value when casting from BooleanType to NumericType.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27b2821c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27b2821c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27b2821c

Branch: refs/heads/master
Commit: 27b2821cf16948962c7a6f513621a1eba60b8cf3
Parents: 1fdf659
Author: Takuya UESHIN <[email protected]>
Authored: Thu Apr 24 09:57:28 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Apr 24 09:57:28 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/catalyst/expressions/Cast.scala  | 10 +++++-----
 .../catalyst/expressions/ExpressionEvaluationSuite.scala  |  7 +++++++
 2 files changed, 12 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27b2821c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1f3fab0..8b79b0c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -111,7 +111,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toLong catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
     case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
@@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toShort catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 
0.toShort)
     case TimestampType => nullOrCast[Timestamp](_, t => 
timestampToLong(t).toShort)
     case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
     case x: NumericType => b => 
x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
@@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toByte catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 
0.toByte)
     case TimestampType => nullOrCast[Timestamp](_, t => 
timestampToLong(t).toByte)
     case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
     case x: NumericType => b => 
x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
@@ -162,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toDouble catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
     case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
     case x: NumericType => b => 
x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
@@ -172,7 +172,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toFloat catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
     case TimestampType => nullOrCast[Timestamp](_, t => 
timestampToDouble(t).toFloat)
     case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)

http://git-wip-us.apache.org/repos/asf/spark/blob/27b2821c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 2cd0d2b..4ce0dff 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -237,6 +237,13 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation("2012-12-11" cast DoubleType, null)
     checkEvaluation(Literal(123) cast IntegerType, 123)
 
+    checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+    checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
+    checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
+    checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
+    checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
+    checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+
     intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
   }
 

Reply via email to