Github user liancheng commented on a diff in the pull request:
https://github.com/apache/spark/pull/1143#discussion_r14006911
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
---
@@ -24,72 +24,89 @@ import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends
UnaryExpression {
override def foldable = child.foldable
- def nullable = (child.dataType, dataType) match {
+
+ override def nullable = (child.dataType, dataType) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case _ => child.nullable
}
+
override def toString = s"CAST($child, $dataType)"
type EvaluatedType = Any
- def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
- null
- } else {
- func(a.asInstanceOf[T])
- }
+ // [[func]] assumes the input is no longer null because eval already
does the null check.
+ @inline private[this] def buildCast[T](a: Any, func: T => Any): Any =
func(a.asInstanceOf[T])
// UDFToString
- def castToString: Any => Any = child.dataType match {
- case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
- case _ => nullOrCast[Any](_, _.toString)
+ private[this] def castToString: Any => Any = child.dataType match {
+ case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
+ case _ => buildCast[Any](_, _.toString)
}
// BinaryConverter
- def castToBinary: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
+ private[this] def castToBinary: Any => Any = child.dataType match {
+ case StringType => buildCast[String](_, _.getBytes("UTF-8"))
}
// UDFToBoolean
- def castToBoolean: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, _.length() != 0)
- case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0
|| b.getNanos() != 0)})
- case LongType => nullOrCast[Long](_, _ != 0)
- case IntegerType => nullOrCast[Int](_, _ != 0)
- case ShortType => nullOrCast[Short](_, _ != 0)
- case ByteType => nullOrCast[Byte](_, _ != 0)
- case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
- case DoubleType => nullOrCast[Double](_, _ != 0)
- case FloatType => nullOrCast[Float](_, _ != 0)
+ private[this] def castToBoolean: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, _.length() != 0)
+ case TimestampType =>
+ buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
+ case LongType =>
+ buildCast[Long](_, _ != 0)
+ case IntegerType =>
+ buildCast[Int](_, _ != 0)
+ case ShortType =>
+ buildCast[Short](_, _ != 0)
+ case ByteType =>
+ buildCast[Byte](_, _ != 0)
+ case DecimalType =>
+ buildCast[BigDecimal](_, _ != 0)
+ case DoubleType =>
+ buildCast[Double](_, _ != 0)
+ case FloatType =>
+ buildCast[Float](_, _ != 0)
}
// TimestampConverter
- def castToTimestamp: Any => Any = child.dataType match {
- case StringType => nullOrCast[String](_, s => {
- // Throw away extra if more than 9 decimal places
- val periodIdx = s.indexOf(".");
- var n = s
- if (periodIdx != -1) {
- if (n.length() - periodIdx > 9) {
- n = n.substring(0, periodIdx + 10)
+ private[this] def castToTimestamp: Any => Any = child.dataType match {
+ case StringType =>
+ buildCast[String](_, s => {
+ // Throw away extra if more than 9 decimal places
+ val periodIdx = s.indexOf(".")
+ var n = s
+ if (periodIdx != -1) {
+ if (n.length() - periodIdx > 9) {
--- End diff --
How about merging these two nested `if` statements into one with `&&`, e.g. `if (periodIdx != -1 && n.length() - periodIdx > 9)`?
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---