GitHub user liancheng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1143#discussion_r14007203
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---
    @@ -24,72 +24,89 @@ import org.apache.spark.sql.catalyst.types._
     /** Cast the child expression to the target data type. */
     case class Cast(child: Expression, dataType: DataType) extends 
UnaryExpression {
       override def foldable = child.foldable
    -  def nullable = (child.dataType, dataType) match {
    +
    +  override def nullable = (child.dataType, dataType) match {
         case (StringType, _: NumericType) => true
         case (StringType, TimestampType)  => true
         case _                            => child.nullable
       }
    +
       override def toString = s"CAST($child, $dataType)"
     
       type EvaluatedType = Any
     
    -  def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
    -    null
    -  } else {
    -    func(a.asInstanceOf[T])
    -  }
    +  // [[func]] assumes the input is no longer null because eval already 
does the null check.
    +  @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = 
func(a.asInstanceOf[T])
     
       // UDFToString
    -  def castToString: Any => Any = child.dataType match {
    -    case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
    -    case _ => nullOrCast[Any](_, _.toString)
    +  private[this] def castToString: Any => Any = child.dataType match {
    +    case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
    +    case _ => buildCast[Any](_, _.toString)
       }
     
       // BinaryConverter
    -  def castToBinary: Any => Any = child.dataType match {
    -    case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
    +  private[this] def castToBinary: Any => Any = child.dataType match {
    +    case StringType => buildCast[String](_, _.getBytes("UTF-8"))
       }
     
       // UDFToBoolean
    -  def castToBoolean: Any => Any = child.dataType match {
    -    case StringType => nullOrCast[String](_, _.length() != 0)
    -    case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 
|| b.getNanos() != 0)})
    -    case LongType => nullOrCast[Long](_, _ != 0)
    -    case IntegerType => nullOrCast[Int](_, _ != 0)
    -    case ShortType => nullOrCast[Short](_, _ != 0)
    -    case ByteType => nullOrCast[Byte](_, _ != 0)
    -    case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
    -    case DoubleType => nullOrCast[Double](_, _ != 0)
    -    case FloatType => nullOrCast[Float](_, _ != 0)
    +  private[this] def castToBoolean: Any => Any = child.dataType match {
    +    case StringType =>
    +      buildCast[String](_, _.length() != 0)
    +    case TimestampType =>
    +      buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
    +    case LongType =>
    +      buildCast[Long](_, _ != 0)
    +    case IntegerType =>
    +      buildCast[Int](_, _ != 0)
    +    case ShortType =>
    +      buildCast[Short](_, _ != 0)
    +    case ByteType =>
    +      buildCast[Byte](_, _ != 0)
    +    case DecimalType =>
    +      buildCast[BigDecimal](_, _ != 0)
    +    case DoubleType =>
    +      buildCast[Double](_, _ != 0)
    +    case FloatType =>
    +      buildCast[Float](_, _ != 0)
       }
     
       // TimestampConverter
    -  def castToTimestamp: Any => Any = child.dataType match {
    -    case StringType => nullOrCast[String](_, s => {
    -      // Throw away extra if more than 9 decimal places
    -      val periodIdx = s.indexOf(".");
    -      var n = s
    -      if (periodIdx != -1) {
    -        if (n.length() - periodIdx > 9) {
    -          n = n.substring(0, periodIdx + 10)
    +  private[this] def castToTimestamp: Any => Any = child.dataType match {
    +    case StringType =>
    +      buildCast[String](_, s => {
    +        // Throw away extra if more than 9 decimal places
    +        val periodIdx = s.indexOf(".")
    +        var n = s
    +        if (periodIdx != -1) {
    +          if (n.length() - periodIdx > 9) {
    +            n = n.substring(0, periodIdx + 10)
    +          }
             }
    -      }
    -      try Timestamp.valueOf(n) catch { case _: 
java.lang.IllegalArgumentException => null}
    -    })
    -    case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 
else 0) * 1000))
    -    case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
    -    case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
    -    case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
    -    case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
    +        try Timestamp.valueOf(n) catch { case _: 
java.lang.IllegalArgumentException => null }
    +      })
    +    case BooleanType =>
    +      buildCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
    +    case LongType =>
    +      buildCast[Long](_, l => new Timestamp(l * 1000))
    +    case IntegerType =>
    +      buildCast[Int](_, i => new Timestamp(i * 1000))
    +    case ShortType =>
    +      buildCast[Short](_, s => new Timestamp(s * 1000))
    +    case ByteType =>
    +      buildCast[Byte](_, b => new Timestamp(b * 1000))
         // TimestampWritable.decimalToTimestamp
    -    case DecimalType => nullOrCast[BigDecimal](_, d => 
decimalToTimestamp(d))
    +    case DecimalType =>
    +      buildCast[BigDecimal](_, d => decimalToTimestamp(d))
         // TimestampWritable.doubleToTimestamp
    -    case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
    +    case DoubleType =>
    +      buildCast[Double](_, d => decimalToTimestamp(d))
         // TimestampWritable.floatToTimestamp
    -    case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
    +    case FloatType =>
    +      buildCast[Float](_, f => decimalToTimestamp(f))
       }
     
    -  private def decimalToTimestamp(d: BigDecimal) = {
    +  private[this]  def decimalToTimestamp(d: BigDecimal) = {
         val seconds = d.longValue()
    --- End diff --
    
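    According to the [API documentation](http://www.scala-lang.org/api/current/index.html#scala.math.BigDecimal@longValue():Long), `d.longValue()` and `d.intValue()` do not fail when `d` is outside the range of `Long`/`Int`. Instead they return only the low-order bits, so the result can be wrong in both magnitude and sign (e.g. negative). `decimalToTimestamp` uses `d.longValue()` as its seconds value, so such inputs would silently produce a wrong timestamp: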
    
    ```scala
    scala> var d: BigDecimal = 1000000000
    d: BigDecimal = 1000000000
    
    scala> d = d * d * d * d
    d: BigDecimal = 1.000000000000000000000000000000000E+36
    
    scala> d.longValue
    res0: Long = -5527149226598858752
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

Reply via email to