Repository: spark
Updated Branches:
  refs/heads/branch-1.0 a803e009c -> 86c4f5af9


[SPARK-2209][SQL] Cast shouldn't do null check twice.

Also took the chance to clean up Cast a little bit: there were too many
`=>` arrows on each line before!

Author: Reynold Xin <[email protected]>

Closes #1143 from rxin/cast and squashes the following commits:

dd006cb [Reynold Xin] Code review feedback.
c2b88ae [Reynold Xin] [SPARK-2209][SQL] Cast shouldn't do null check twice.

(cherry picked from commit c55bbb49f7ec653f0ff635015d3bc789ca26c4eb)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86c4f5af
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86c4f5af
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86c4f5af

Branch: refs/heads/branch-1.0
Commit: 86c4f5af9de76727549f2b145d964722407cb927
Parents: a803e00
Author: Reynold Xin <[email protected]>
Authored: Fri Jun 20 00:01:19 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Jun 20 00:01:25 2014 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 274 +++++++++++--------
 1 file changed, 159 insertions(+), 115 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86c4f5af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 0b3a4e7..1f9716e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._
 /** Cast the child expression to the target data type. */
 case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
   override def foldable = child.foldable
-  def nullable = (child.dataType, dataType) match {
+
+  override def nullable = (child.dataType, dataType) match {
     case (StringType, _: NumericType) => true
     case (StringType, TimestampType)  => true
     case _                            => child.nullable
   }
+
   override def toString = s"CAST($child, $dataType)"
 
   type EvaluatedType = Any
 
-  def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
-    null
-  } else {
-    func(a.asInstanceOf[T])
-  }
+  // [[func]] assumes the input is no longer null because eval already does the null check.
+  @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
   // UDFToString
-  def castToString: Any => Any = child.dataType match {
-    case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
-    case _ => nullOrCast[Any](_, _.toString)
+  private[this] def castToString: Any => Any = child.dataType match {
+    case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
+    case _ => buildCast[Any](_, _.toString)
   }
 
   // BinaryConverter
-  def castToBinary: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
+  private[this] def castToBinary: Any => Any = child.dataType match {
+    case StringType => buildCast[String](_, _.getBytes("UTF-8"))
   }
 
   // UDFToBoolean
-  def castToBoolean: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, _.length() != 0)
-    case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
-    case LongType => nullOrCast[Long](_, _ != 0)
-    case IntegerType => nullOrCast[Int](_, _ != 0)
-    case ShortType => nullOrCast[Short](_, _ != 0)
-    case ByteType => nullOrCast[Byte](_, _ != 0)
-    case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
-    case DoubleType => nullOrCast[Double](_, _ != 0)
-    case FloatType => nullOrCast[Float](_, _ != 0)
+  private[this] def castToBoolean: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, _.length() != 0)
+    case TimestampType =>
+      buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
+    case LongType =>
+      buildCast[Long](_, _ != 0)
+    case IntegerType =>
+      buildCast[Int](_, _ != 0)
+    case ShortType =>
+      buildCast[Short](_, _ != 0)
+    case ByteType =>
+      buildCast[Byte](_, _ != 0)
+    case DecimalType =>
+      buildCast[BigDecimal](_, _ != 0)
+    case DoubleType =>
+      buildCast[Double](_, _ != 0)
+    case FloatType =>
+      buildCast[Float](_, _ != 0)
   }
 
   // TimestampConverter
-  def castToTimestamp: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => {
-      // Throw away extra if more than 9 decimal places
-      val periodIdx = s.indexOf(".");
-      var n = s
-      if (periodIdx != -1) {
-        if (n.length() - periodIdx > 9) {
+  private[this] def castToTimestamp: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => {
+        // Throw away extra if more than 9 decimal places
+        val periodIdx = s.indexOf(".")
+        var n = s
+        if (periodIdx != -1 && n.length() - periodIdx > 9) {
           n = n.substring(0, periodIdx + 10)
         }
-      }
-      try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
-    case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
-    case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
-    case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
-    case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
+        try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000))
+    case LongType =>
+      buildCast[Long](_, l => new Timestamp(l * 1000))
+    case IntegerType =>
+      buildCast[Int](_, i => new Timestamp(i * 1000))
+    case ShortType =>
+      buildCast[Short](_, s => new Timestamp(s * 1000))
+    case ByteType =>
+      buildCast[Byte](_, b => new Timestamp(b * 1000))
     // TimestampWritable.decimalToTimestamp
-    case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
+    case DecimalType =>
+      buildCast[BigDecimal](_, d => decimalToTimestamp(d))
     // TimestampWritable.doubleToTimestamp
-    case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
+    case DoubleType =>
+      buildCast[Double](_, d => decimalToTimestamp(d))
     // TimestampWritable.floatToTimestamp
-    case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
+    case FloatType =>
+      buildCast[Float](_, f => decimalToTimestamp(f))
   }
 
-  private def decimalToTimestamp(d: BigDecimal) = {
+  private[this]  def decimalToTimestamp(d: BigDecimal) = {
     val seconds = d.longValue()
     val bd = (d - seconds) * 1000000000
     val nanos = bd.intValue()
@@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
   }
 
   // Timestamp to long, converting milliseconds to seconds
-  private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
+  private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000
 
-  private def timestampToDouble(ts: Timestamp) = {
+  private[this] def timestampToDouble(ts: Timestamp) = {
     // First part is the seconds since the beginning of time, followed by nanosecs.
     ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
   }
 
-  def castToLong: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toLong catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
-    case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
-  }
-
-  def castToInt: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toInt catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
-    case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
-  }
-
-  def castToShort: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toShort catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
-    case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
-  }
-
-  def castToByte: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toByte catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
-    case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
-  }
-
-  def castToDecimal: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
+  private[this] def castToLong: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toLong catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1L else 0L)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToLong(t))
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toLong)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+  }
+
+  private[this] def castToInt: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toInt catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1 else 0)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToLong(t).toInt)
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toInt)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+  }
+
+  private[this] def castToShort: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toShort catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToLong(t).toShort)
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toShort)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+  }
+
+  private[this] def castToByte: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toByte catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToLong(t).toByte)
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toByte)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+  }
+
+  private[this] def castToDecimal: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
     case TimestampType =>
       // Note that we lose precision here.
-      nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
-    case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
-  }
-
-  def castToDouble: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toDouble catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
-    case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
-  }
-
-  def castToFloat: Any => Any = child.dataType match {
-    case StringType => nullOrCast[String](_, s => try s.toFloat catch {
-      case _: NumberFormatException => null
-    })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
-    case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
-    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+      buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+    case x: NumericType =>
+      b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+  }
+
+  private[this] def castToDouble: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toDouble catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1d else 0d)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToDouble(t))
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toDouble)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+  }
+
+  private[this] def castToFloat: Any => Any = child.dataType match {
+    case StringType =>
+      buildCast[String](_, s => try s.toFloat catch {
+        case _: NumberFormatException => null
+      })
+    case BooleanType =>
+      buildCast[Boolean](_, b => if (b) 1f else 0f)
+    case TimestampType =>
+      buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
+    case DecimalType =>
+      buildCast[BigDecimal](_, _.toFloat)
+    case x: NumericType =>
+      b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
   }
 
-  private lazy val cast: Any => Any = dataType match {
+  private[this] lazy val cast: Any => Any = dataType match {
     case StringType => castToString
     case BinaryType => castToBinary
     case DecimalType => castToDecimal
@@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
 
   override def eval(input: Row): Any = {
     val evaluated = child.eval(input)
-    if (evaluated == null) {
-      null
-    } else {
-      cast(evaluated)
-    }
+    if (evaluated == null) null else cast(evaluated)
   }
 }

Reply via email to