Github user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/6938#discussion_r34649492
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
 ---
    @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: 
Expression)
         """
       }
     }
    +
    +/**
    + * Round the `child`'s result to `scale` decimal place when `scale` >= 0
    + * or round at integral part when `scale` < 0.
    + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) 
would eval to 30.
    + *
    + * Child of IntegralType would eval to itself when `scale` >= 0.
    + * Child of FractionalType whose value is NaN or Infinite would always 
eval to itself.
    + *
    + * Round's dataType would always equal to `child`'s dataType except for 
[[DecimalType.Fixed]],
    + * which leads to scale update in DecimalType's [[PrecisionInfo]]
    + *
    + * @param child expr to be round, all [[NumericType]] is allowed as Input
    + * @param scale new scale to be round to, this should be a constant int at 
runtime
    + */
    +case class Round(child: Expression, scale: Expression)
    +  extends BinaryExpression with ExpectsInputTypes {
    +
    +  import BigDecimal.RoundingMode.HALF_UP
    +
    +  def this(child: Expression) = this(child, Literal(0))
    +
    +  override def left: Expression = child
    +  override def right: Expression = scale
    +
    +  // round of Decimal would eval to null if it fails to `changePrecision`
    +  override def nullable: Boolean = true
    +
    +  override def foldable: Boolean = child.foldable
    +
    +  override lazy val dataType: DataType = child.dataType match {
    +    // if the new scale is bigger which means we are scaling up,
    +    // keep the original scale as `Decimal` does
    +    case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else 
_scale)
    +    case t => t
    +  }
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    super.checkInputDataTypes() match {
    +      case TypeCheckSuccess =>
    +        if (scale.foldable) {
    +          TypeCheckSuccess
    +        } else {
    +          TypeCheckFailure("Only foldable Expression is allowed for scale 
arguments")
    +        }
    +      case f => f
    +    }
    +  }
    +
    +  // Avoid repeated evaluation since `scale` is a constant int,
    +  // avoid unnecessary `child` evaluation in both codegen and non-codegen 
eval
    +  // by checking if scaleV == null as well.
    +  private lazy val scaleV: Any = scale.eval(EmptyRow)
    +  private lazy val _scale: Int = scaleV.asInstanceOf[Int]
    +
    +  override def eval(input: InternalRow): Any = {
    +    if (scaleV == null) { // if scale is null, no need to eval its child 
at all
    +      null
    +    } else {
    +      val evalE = child.eval(input)
    +      if (evalE == null) {
    +        null
    +      } else {
    +        nullSafeEval(evalE)
    +      }
    +    }
    +  }
    +
    +  // not overriding since _scale is a constant int at runtime
    +  def nullSafeEval(input1: Any): Any = {
    +    child.dataType match {
    +      case _: DecimalType =>
    +        val decimal = input1.asInstanceOf[Decimal]
    +        if (decimal.changePrecision(decimal.precision, _scale)) decimal 
else null
    +      case ByteType =>
    +        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, 
HALF_UP).toByte
    +      case ShortType =>
    +        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, 
HALF_UP).toShort
    +      case IntegerType =>
    +        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, 
HALF_UP).toInt
    +      case LongType =>
    +        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, 
HALF_UP).toLong
    +      case FloatType =>
    +        val f = input1.asInstanceOf[Float]
    +        if (f.isNaN || f.isInfinite) {
    +          f
    +        } else {
    +          BigDecimal(f).setScale(_scale, HALF_UP).toFloat
    +        }
    +      case DoubleType =>
    +        val d = input1.asInstanceOf[Double]
    +        if (d.isNaN || d.isInfinite) {
    +          d
    +        } else {
    +          BigDecimal(d).setScale(_scale, HALF_UP).toDouble
    +        }
    +    }
    +  }
    +
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
    +    val ce = child.gen(ctx)
    +
    +    val evaluationCode = child.dataType match {
    +      case _: DecimalType =>
    +        s"""
    +        if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), 
${_scale})) {
    +          ${ev.primitive} = ${ce.primitive};
    +        } else {
    +          ${ev.isNull} = true;
    +        }"""
    +      case ByteType =>
    +        if (_scale < 0) {
    +          s"""
    +          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
    +            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
    +        } else {
    +          s"${ev.primitive} = ${ce.primitive};"
    +        }
    +      case ShortType =>
    +        if (_scale < 0) {
    +          s"""
    +          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
    +            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
    +        } else {
    +          s"${ev.primitive} = ${ce.primitive};"
    +        }
    +      case IntegerType =>
    +        if (_scale < 0) {
    +          s"""
    +          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
    +            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
    +        } else {
    +          s"${ev.primitive} = ${ce.primitive};"
    +        }
    +      case LongType =>
    +        if (_scale < 0) {
    +          s"""
    +          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
    +            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
    +        } else {
    +          s"${ev.primitive} = ${ce.primitive};"
    +        }
    +      case FloatType => // if child eval to NaN or Infinity, just return 
it.
    +        if (_scale == 0) {
    +          s"""
    +            if (Float.isNaN(${ce.primitive}) || 
Float.isInfinite(${ce.primitive})){
    +              ${ev.primitive} = ${ce.primitive};
    +            } else {
    +              ${ev.primitive} = Math.round(${ce.primitive});
    +            }"""
    +        } else {
    +          s"""
    +            if (Float.isNaN(${ce.primitive}) || 
Float.isInfinite(${ce.primitive})){
    +              ${ev.primitive} = ${ce.primitive};
    +            } else {
    +              ${ev.primitive} = 
java.math.BigDecimal.valueOf(${ce.primitive}).
    +                setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).floatValue();
    +            }"""
    +        }
    +      case DoubleType => // if child eval to NaN or Infinity, just return 
it.
    +        if (_scale == 0) {
    +          s"""
    +            if (Double.isNaN(${ce.primitive}) || 
Double.isInfinite(${ce.primitive})){
    +              ${ev.primitive} = ${ce.primitive};
    +            } else {
    +              ${ev.primitive} = Math.round(${ce.primitive});
    +            }"""
    +        } else {
    +          s"""
    +            if (Double.isNaN(${ce.primitive}) || 
Double.isInfinite(${ce.primitive})){
    +              ${ev.primitive} = ${ce.primitive};
    +            } else {
    +              ${ev.primitive} = 
java.math.BigDecimal.valueOf(${ce.primitive}).
    +                setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
    +            }"""
    +        }
    +    }
    +
    +    if (scaleV == null) { // if scale is null, no need to eval its child 
at all
    +      s"""
    +        boolean ${ev.isNull} = true;
    +        ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
    +      """
    +    } else {
    +      s"""
    +        ${ce.code}
    +        boolean ${ev.isNull} = ${ce.isNull};
    +        ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
    +        if (!${ev.isNull}) {
    +          $evaluationCode
    +        }
    +      """
    +    }
    +  }
    +
    +  override def prettyName: String = "round"
    --- End diff --
    
    You can remove this `prettyName` override, since the expression is already named Round.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to