vinodkc commented on code in PR #38419:
URL: https://github.com/apache/spark/pull/38419#discussion_r1089829294
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,247 @@ case class RoundCeil(child: Expression, scale: Expression)
copy(child = newLeft, scale = newRight)
}
+case class TruncNumber(child: Expression, scale: Expression)
+ extends BaseBinaryExpression with NullIntolerant {
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+ /**
+ * Returns Java source code that can be compiled to evaluate this expression. The default
+ * behavior is to call the eval method of the expression. Concrete expression implementations
+ * should override this to do actual code generation.
+ *
+ * @param ctx
+ * a [[CodegenContext]]
+ * @param ev
+ * an [[ExprCode]] with unique terms.
+ * @return
+ * an [[ExprCode]] containing the Java source code to generate the given expression
+ */
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ defineCodeGen(
+ ctx,
+ ev,
+ (input, _) => {
+ dataType match {
+ case ByteType if (scaleValue <= 0) =>
+ s"""(byte)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case ShortType if (scaleValue <= 0) =>
+ s"""(short)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case IntegerType if (scaleValue <= 0) =>
+ s"""(int)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case LongType if (scaleValue <= 0) =>
+ s"""(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue))""".stripMargin
+ case FloatType if (scaleValue <= 0) =>
+ s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue).floatValue()""".stripMargin
+ case DoubleType if (scaleValue <= 0) =>
+ s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue).doubleValue()""".stripMargin
+ case DecimalType.Fixed(p, s) =>
+ s"""Decimal.apply(
+ |org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |${input}.toJavaBigDecimal(), $scaleValue), $p, $s)""".stripMargin
+ case _ => s"$input"
+ }
+ })
+
+ /**
+ * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query
+ * the dataType of an unresolved expression (i.e., when `resolved` == false).
+ */
+ override lazy val dataType: DataType = child.dataType
+
+ /**
+ * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * nullability, they can override this method to save null-check code. If we need full control
+ * of evaluation process, we should override [[eval]].
+ */
+ override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+ (dataType, input1) match {
+ case (ByteType, input: Byte) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).toByte
+ case (ShortType, input: Short) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).shortValue
+ case (IntegerType, input: Int) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).intValue
+ case (LongType, input: Long) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input, scaleValue).longValue
+ case (FloatType, input: Float) =>
+ TruncNumber.trunc(input, scaleValue).floatValue
+ case (DoubleType, input: Double) =>
+ TruncNumber.trunc(input, scaleValue).doubleValue
+ case (DecimalType.Fixed(p, s), input: Decimal) =>
+ Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+ case _ => input1
+ }
+ }
+}
+
+object TruncNumber {
+ /**
+ * To truncate whole numbers; byte, short, int, and long types.
+ */
+ def trunc(input: Long, position: Int): Long = {
+ if (position >= 0) {
+ input
+ } else {
+ // Here we truncate the number by the absolute value of the position.
+ // For example, if the input is 123 and the scale is -2, then the result is 100.
+ val pow = Math.pow(10, Math.abs(position)).toLong
+ (input / pow) * pow
+ }
+ }
+
+ /**
+ * To truncate double and float type
+ */
+ def trunc(input: Double, position: Int): BigDecimal = {
+ trunc(jm.BigDecimal.valueOf(input), position)
+ }
+
+ /**
+ * To truncate decimal type
+ */
+ def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+ if (input.scale < position) {
+ input
+ } else {
+ val wholePart = input.toBigInteger
+ if (position > 0) {
+ // Here we truncate only the decimal part by the value of the position.
+ val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+ // To avoid overflow during multiplication, we extract the decimal part first,
+ // truncate it and then add the whole part.
+ // For example, if the input is 123.456 and the scale is 2, the result should be 123.45.
+ if (jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+ new jm.BigDecimal(wholePart)
+ } else {
+ val pow = jm.BigDecimal.valueOf(Math.pow(10, position).toLong)
+ val newRemainder = new jm.BigDecimal(decimalPart.multiply(pow).toBigInteger).divide(pow)
+ new jm.BigDecimal(wholePart).add(newRemainder)
+ }
+ } else if (position == 0) {
+ // The position is zero, so we extract the whole part.
+ // For example, if the input is 123.456 and the scale is 0, the result is 123.
+ new jm.BigDecimal(wholePart)
+ } else {
+ // Here we truncate the whole part by the absolute value of the position.
+ // For example, if the input is 123.456 and the scale is -2, the result is 100.
+ if (jm.BigInteger.ZERO.compareTo(wholePart) == 0) {
Review Comment:
Refactored to avoid code duplication.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,247 @@ case class RoundCeil(child: Expression, scale: Expression)
copy(child = newLeft, scale = newRight)
}
+case class TruncNumber(child: Expression, scale: Expression)
+ extends BaseBinaryExpression with NullIntolerant {
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+ /**
+ * Returns Java source code that can be compiled to evaluate this expression. The default
+ * behavior is to call the eval method of the expression. Concrete expression implementations
+ * should override this to do actual code generation.
+ *
+ * @param ctx
+ * a [[CodegenContext]]
+ * @param ev
+ * an [[ExprCode]] with unique terms.
+ * @return
+ * an [[ExprCode]] containing the Java source code to generate the given expression
+ */
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ defineCodeGen(
+ ctx,
+ ev,
+ (input, _) => {
+ dataType match {
+ case ByteType if (scaleValue <= 0) =>
+ s"""(byte)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case ShortType if (scaleValue <= 0) =>
+ s"""(short)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case IntegerType if (scaleValue <= 0) =>
+ s"""(int)(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |(long)$input, $scaleValue))""".stripMargin
+ case LongType if (scaleValue <= 0) =>
+ s"""(org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue))""".stripMargin
+ case FloatType if (scaleValue <= 0) =>
+ s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue).floatValue()""".stripMargin
+ case DoubleType if (scaleValue <= 0) =>
+ s"""org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |$input, $scaleValue).doubleValue()""".stripMargin
+ case DecimalType.Fixed(p, s) =>
+ s"""Decimal.apply(
+ |org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc(
+ |${input}.toJavaBigDecimal(), $scaleValue), $p, $s)""".stripMargin
+ case _ => s"$input"
+ }
+ })
+
+ /**
+ * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query
+ * the dataType of an unresolved expression (i.e., when `resolved` == false).
+ */
+ override lazy val dataType: DataType = child.dataType
+
+ /**
+ * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * nullability, they can override this method to save null-check code. If we need full control
+ * of evaluation process, we should override [[eval]].
+ */
+ override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+ (dataType, input1) match {
+ case (ByteType, input: Byte) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).toByte
+ case (ShortType, input: Short) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).shortValue
+ case (IntegerType, input: Int) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input.toLong, scaleValue).intValue
+ case (LongType, input: Long) if (scaleValue <= 0) =>
+ TruncNumber.trunc(input, scaleValue).longValue
+ case (FloatType, input: Float) =>
+ TruncNumber.trunc(input, scaleValue).floatValue
+ case (DoubleType, input: Double) =>
+ TruncNumber.trunc(input, scaleValue).doubleValue
+ case (DecimalType.Fixed(p, s), input: Decimal) =>
+ Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+ case _ => input1
+ }
+ }
+}
+
+object TruncNumber {
+ /**
+ * To truncate whole numbers; byte, short, int, and long types.
+ */
+ def trunc(input: Long, position: Int): Long = {
+ if (position >= 0) {
+ input
+ } else {
+ // Here we truncate the number by the absolute value of the position.
+ // For example, if the input is 123 and the scale is -2, then the result is 100.
+ val pow = Math.pow(10, Math.abs(position)).toLong
+ (input / pow) * pow
+ }
+ }
+
+ /**
+ * To truncate double and float type
+ */
+ def trunc(input: Double, position: Int): BigDecimal = {
+ trunc(jm.BigDecimal.valueOf(input), position)
+ }
+
+ /**
+ * To truncate decimal type
+ */
+ def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+ if (input.scale < position) {
+ input
+ } else {
+ val wholePart = input.toBigInteger
+ if (position > 0) {
+ // Here we truncate only the decimal part by the value of the position.
+ val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+ // To avoid overflow during multiplication, we extract the decimal part first,
+ // truncate it and then add the whole part.
+ // For example, if the input is 123.456 and the scale is 2, the result should be 123.45.
+ if (jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+ new jm.BigDecimal(wholePart)
+ } else {
+ val pow = jm.BigDecimal.valueOf(Math.pow(10, position).toLong)
+ val newRemainder = new jm.BigDecimal(decimalPart.multiply(pow).toBigInteger).divide(pow)
+ new jm.BigDecimal(wholePart).add(newRemainder)
+ }
+ } else if (position == 0) {
+ // The position is zero, so we extract the whole part.
+ // For example, if the input is 123.456 and the scale is 0, the result is 123.
+ new jm.BigDecimal(wholePart)
+ } else {
+ // Here we truncate the whole part by the absolute value of the position.
+ // For example, if the input is 123.456 and the scale is -2, the result is 100.
+ if (jm.BigInteger.ZERO.compareTo(wholePart) == 0) {
Review Comment:
Refactored to avoid code duplication.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]