ulysses-you commented on code in PR #37207:
URL: https://github.com/apache/spark/pull/37207#discussion_r924039193


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala:
##########
@@ -247,3 +247,77 @@ case class DecimalAddNoOverflowCheck(
       newLeft: Expression, newRight: Expression): DecimalAddNoOverflowCheck =
     copy(left = newLeft, right = newRight)
 }
+
+/**
+ * A divide expression for decimal values which is only used internally by Avg.
+ *
+ * It will fail when nullOnOverflow is false follows:
+ *   - left (sum in avg) is null due to over the max precision 38,
+ *     the right (count in avg) should never be null
+ *   - the result of divide is overflow
+ */
+case class DecimalDivideWithOverflowCheck(
+    left: Expression,
+    right: Expression,
+    override val dataType: DecimalType,
+    avgQueryContext: String,
+    nullOnOverflow: Boolean) extends BinaryExpression with SupportQueryContext 
{
+  override def nullable: Boolean = nullOnOverflow
+  override def initQueryContext(): String = avgQueryContext
+  def decimalMethod: String = "$div"
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      if (nullOnOverflow)  {
+        null
+      } else {
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+      }
+    } else {
+      val value2 = right.eval(input)
+      dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, 
value2).asInstanceOf[Decimal]
+        .toPrecision(dataType.precision, dataType.scale, 
Decimal.ROUND_HALF_UP, nullOnOverflow,
+          queryContext)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val errorContextCode = if (nullOnOverflow) {
+      "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", queryContext)
+    }
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"throw 
QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+    }
+
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)
+
+    // scalastyle:off line.size.limit
+    val code =
+      code"""
+         |${eval1.code}
+         |${eval2.code}
+         |boolean ${ev.isNull} = ${eval1.isNull};
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+         |if (${eval1.isNull}) {
+         |  $nullHandling
+         |} else {
+         |  ${ev.value} = 
${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(

Review Comment:
   Sorry, I don't quite follow the suggestion. This expression is used 
internally only by Average, so:
   - no caller will ever invoke this checkInputDataTypes method
   - the only way it can fail is if we introduce a bug in Average itself
   
   But if you think it's necessary, I can add something like:
   ```
     override def checkInputDataTypes(): TypeCheckResult = (left.dataType, 
right.dataType) match {
       case (l: DecimalType, r: DecimalType) if inputType.acceptsType(l) && 
inputType.acceptsType(r) =>
         TypeCheckResult.TypeCheckSuccess
       case _ => throw new IllegalStateException("Both left and right child 
must be decimal type.")
     }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to