vinodkc commented on code in PR #38419:
URL: https://github.com/apache/spark/pull/38419#discussion_r1096439062


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = 
newRight)
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * This overridden implementation delegates the overloaded TruncNumber.trunc 
methods based on
+   * data type of input values
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Trunc function accepts a second parameter to truncate the input 
number.
+      // If 0, it removes all the decimal values and returns only the integer.
+      // If negative, the number is truncated to the left side of the decimal 
point.
+      // Value  of decimal places to truncate can range from -ve to +ve
+      // 1) In the case of integral numbers, as there is no decimal part if 
the value of decimal
+      // places to truncate is +ve, then we can return that input value 
without any
+      // modification as there is no +ve decimal place to be truncated from an 
integral number
+      // Truncate the input only if the value of decimal places to truncate is 
< 0
+      case (ByteType, input: Byte) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if (scaleValue < 0) =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // 2) In the case of Float, Double, and Decimal , TruncNumber.trunc
+      // will accept both -ve and +ve values
+      case (FloatType, input: Float) =>
+        TruncNumber.trunc(input, scaleValue).floatValue
+      case (DoubleType, input: Double) =>
+        TruncNumber.trunc(input, scaleValue).doubleValue
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      case _ => input1
+    }
+  }
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression.
+   * This overridden implementation delegates the overloaded TruncNumber.trunc 
methods based on
+   * data type of input values
+   * @param ctx
+   *   a [[CodegenContext]]
+   * @param ev
+   *   an [[ExprCode]] with unique terms.
+   * @return
+   *   an [[ExprCode]] containing the Java source code to generate the given 
expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = 
"org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        // Trunc function accepts a second parameter to truncate the input 
number.
+        // If 0, it removes all the decimal values and returns only the 
integer.
+        // If negative, the number is truncated to the left side of the 
decimal point.
+        // Value  of decimal places to truncate can range from -ve to +ve
+        // 1) In the case of integral numbers, as there is no decimal part if 
the value of decimal
+        // places to truncate is +ve, then we can return that input value 
without any
+        // modification as there is no +ve decimal place to be truncated from 
an integral number
+        // Truncate the input only if the value of decimal places to truncate 
is < 0
+        dataType match {
+          case ByteType if (scaleValue < 0) =>
+            s"""(byte)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case ShortType if (scaleValue < 0) =>
+            s"""(short)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case IntegerType if (scaleValue < 0) =>
+            s"""(int)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case LongType if (scaleValue < 0) =>
+            s"""($methodName(
+               |$input, $scaleValue))""".stripMargin
+          // 2) In the case of Float, Double, and Decimal , TruncNumber.trunc
+          // will accept both -ve and +ve values
+          case FloatType =>
+            s"""$methodName(
+               |$input, $scaleValue).floatValue()""".stripMargin
+          case DoubleType =>
+            s"""$methodName(
+               |$input, $scaleValue).doubleValue()""".stripMargin
+          case DecimalType.Fixed(p, s) =>
+            s"""Decimal.apply(
+               |$methodName(
+               |${input}.toJavaBigDecimal(), $scaleValue), $p, 
$s)""".stripMargin
+          case _ => s"$input"
+        }
+      })
+}
+
+object TruncNumber {
+
+  /**
+   * To truncate whole numbers; byte, short, int, and long types.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else {
+      // Here we truncate the number by the absolute value of the position.
+      // For example, if the input is 123 and the scale is -2, then the result 
is 100.
+      val pow = Math.pow(10, Math.abs(position)).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * To truncate double and float type.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * To truncate decimal type.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) input
+    else {
+      val wholePart = input.toBigInteger
+      position match {
+        case pos if pos >= 0 =>
+          // Here we truncate only the decimal part by the value of the 
position.
+          val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+          // If the position is zero OR Decimal part is zero,
+          // we extract the whole part and return it.
+          // For example,
+          // if the input is 123.456 and the scale is 0, the result will be 
123.
+          // if the input is 123.000 and the scale is > 0, the result will be 
123.
+          val wholePartBD = new jm.BigDecimal(wholePart)
+          if (pos == 0 || jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+            wholePartBD
+          } else {
+            // To avoid overflow during multiplication, we extract the decimal 
part from the input,
+            // truncate it and then append it to the whole part.
+            // For example, if the input is 123.456 and the scale is 2, the 
result will be 123.45.
+            val pow = jm.BigDecimal.valueOf(Math.pow(10, pos).toLong)

Review Comment:
   Refactored



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = 
newRight)
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It 
is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * This overridden implementation delegates the overloaded TruncNumber.trunc 
methods based on
+   * data type of input values
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Trunc function accepts a second parameter to truncate the input 
number.
+      // If 0, it removes all the decimal values and returns only the integer.
+      // If negative, the number is truncated to the left side of the decimal 
point.
+      // Value  of decimal places to truncate can range from -ve to +ve
+      // 1) In the case of integral numbers, as there is no decimal part if 
the value of decimal
+      // places to truncate is +ve, then we can return that input value 
without any
+      // modification as there is no +ve decimal place to be truncated from an 
integral number
+      // Truncate the input only if the value of decimal places to truncate is 
< 0
+      case (ByteType, input: Byte) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if (scaleValue < 0) =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // 2) In the case of Float, Double, and Decimal , TruncNumber.trunc
+      // will accept both -ve and +ve values
+      case (FloatType, input: Float) =>
+        TruncNumber.trunc(input, scaleValue).floatValue
+      case (DoubleType, input: Double) =>
+        TruncNumber.trunc(input, scaleValue).doubleValue
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      case _ => input1
+    }
+  }
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression.
+   * This overridden implementation delegates the overloaded TruncNumber.trunc 
methods based on
+   * data type of input values
+   * @param ctx
+   *   a [[CodegenContext]]
+   * @param ev
+   *   an [[ExprCode]] with unique terms.
+   * @return
+   *   an [[ExprCode]] containing the Java source code to generate the given 
expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = 
"org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        // Trunc function accepts a second parameter to truncate the input 
number.
+        // If 0, it removes all the decimal values and returns only the 
integer.
+        // If negative, the number is truncated to the left side of the 
decimal point.
+        // Value  of decimal places to truncate can range from -ve to +ve
+        // 1) In the case of integral numbers, as there is no decimal part if 
the value of decimal
+        // places to truncate is +ve, then we can return that input value 
without any
+        // modification as there is no +ve decimal place to be truncated from 
an integral number
+        // Truncate the input only if the value of decimal places to truncate 
is < 0
+        dataType match {
+          case ByteType if (scaleValue < 0) =>
+            s"""(byte)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case ShortType if (scaleValue < 0) =>
+            s"""(short)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case IntegerType if (scaleValue < 0) =>
+            s"""(int)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case LongType if (scaleValue < 0) =>
+            s"""($methodName(
+               |$input, $scaleValue))""".stripMargin
+          // 2) In the case of Float, Double, and Decimal , TruncNumber.trunc
+          // will accept both -ve and +ve values
+          case FloatType =>
+            s"""$methodName(
+               |$input, $scaleValue).floatValue()""".stripMargin
+          case DoubleType =>
+            s"""$methodName(
+               |$input, $scaleValue).doubleValue()""".stripMargin
+          case DecimalType.Fixed(p, s) =>
+            s"""Decimal.apply(
+               |$methodName(
+               |${input}.toJavaBigDecimal(), $scaleValue), $p, 
$s)""".stripMargin
+          case _ => s"$input"
+        }
+      })
+}
+
+object TruncNumber {
+
+  /**
+   * To truncate whole numbers; byte, short, int, and long types.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else {
+      // Here we truncate the number by the absolute value of the position.
+      // For example, if the input is 123 and the scale is -2, then the result 
is 100.
+      val pow = Math.pow(10, Math.abs(position)).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * To truncate double and float type.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * To truncate decimal type.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) input
+    else {
+      val wholePart = input.toBigInteger
+      position match {
+        case pos if pos >= 0 =>
+          // Here we truncate only the decimal part by the value of the 
position.
+          val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+          // If the position is zero OR Decimal part is zero,
+          // we extract the whole part and return it.
+          // For example,
+          // if the input is 123.456 and the scale is 0, the result will be 
123.
+          // if the input is 123.000 and the scale is > 0, the result will be 
123.
+          val wholePartBD = new jm.BigDecimal(wholePart)
+          if (pos == 0 || jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+            wholePartBD
+          } else {
+            // To avoid overflow during multiplication, we extract the decimal 
part from the input,
+            // truncate it and then append it to the whole part.
+            // For example, if the input is 123.456 and the scale is 2, the 
result will be 123.45.
+            val pow = jm.BigDecimal.valueOf(Math.pow(10, pos).toLong)
+            val truncated = new 
jm.BigDecimal(decimalPart.multiply(pow).toBigInteger).divide(pow)
+            wholePartBD.add(truncated)
+          }
+        case pos if pos < 0 =>
+          // Here we truncate the whole part by the absolute value of the 
position.
+          // For example, if the input is 123.456 and the scale is -2, the 
result will be 100.
+          val pow = jm.BigInteger.valueOf(Math.pow(10, Math.abs(pos)).toLong)

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to