ulysses-you commented on code in PR #37207:
URL: https://github.com/apache/spark/pull/37207#discussion_r924043508


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala:
##########
@@ -247,3 +247,77 @@ case class DecimalAddNoOverflowCheck(
       newLeft: Expression, newRight: Expression): DecimalAddNoOverflowCheck =
     copy(left = newLeft, right = newRight)
 }
+
+/**
+ * A divide expression for decimal values which is only used internally by Avg.
+ *
+ * When nullOnOverflow is false, it fails in either of the following cases:
+ *   - left (the sum in avg) is null because it exceeded the max precision of 38;
+ *     right (the count in avg) should never be null
+ *   - the result of the division overflows
+ */
+case class DecimalDivideWithOverflowCheck(
+    left: Expression,
+    right: Expression,
+    override val dataType: DecimalType,
+    avgQueryContext: String,
+    nullOnOverflow: Boolean) extends BinaryExpression with SupportQueryContext 
{
+  override def nullable: Boolean = nullOnOverflow
+  override def initQueryContext(): String = avgQueryContext
+  def decimalMethod: String = "$div"
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      if (nullOnOverflow)  {
+        null
+      } else {
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+      }
+    } else {
+      val value2 = right.eval(input)
+      dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, 
value2).asInstanceOf[Decimal]
+        .toPrecision(dataType.precision, dataType.scale, 
Decimal.ROUND_HALF_UP, nullOnOverflow,
+          queryContext)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val errorContextCode = if (nullOnOverflow) {
+      "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", queryContext)
+    }
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+    }
+
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)
+
+    // scalastyle:off line.size.limit
+    val code =
+      code"""
+         |${eval1.code}
+         |${eval2.code}
+         |boolean ${ev.isNull} = ${eval1.isNull};
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+         |if (${eval1.isNull}) {
+         |  $nullHandling
+         |} else {
+         |  ${ev.value} = 
${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(

Review Comment:
   Got it! Addressed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to