skambha commented on issue #27627: [WIP][SPARK-28067][SQL] Fix incorrect 
results for decimal aggregate sum by returning null on decimal overflow
URL: https://github.com/apache/spark/pull/27627#issuecomment-597404017
 
 
   > For aggregate with GROUP BY, we write out the sum value to the unsafe row (the aggregate hash map) at each step. Do you know what happens there?
   
   Details are here: https://github.com/skambha/notes/blob/master/case3_exception.txt (the scenario has the two ranges and a union, followed by your query from the previous comment).
   
   Below are some brief notes:  
   1. In this case, we write to an aggregate hash map using UnsafeRow. Writing the UnsafeRow into the hash map does not go through the UnsafeRowWriter that has the issue of writing null for an overflowed decimal value.
   (Side note: in whole-stage execution we do still write using UnsafeRowWriter, but that processing happens later, i.e. after the hash map is built. The error in the scenario mentioned actually occurs prior to that.)
   
   2. When building the hash map, it uses the _UnsafeRow.setDecimal_ call:
   `agg_unsafeRowAggBuffer_0.setDecimal(0, agg_value_2, 38);`
   This setDecimal method does not check whether the decimal value has overflowed.
   
   3. As it builds the hash map, if the key already exists in the hash map, it retrieves the existing buffer value and adds the new input row's value to it. In this case, the method it uses to retrieve the decimal value from the hash map is `UnsafeRow.getDecimal`:
   `agg_unsafeRowAggBuffer_0.getDecimal(0, 38, 18)`
   
   4. UnsafeRow.getDecimal calls `Decimal.set(decimal: BigDecimal, precision: Int, scale: Int)` via `Decimal.apply(javaDecimal, precision, scale)`, which does check whether the value is containable at the declared precision and, if not, throws the exception that you see:
   ```
       if (decimalVal.precision > precision) {
         throw new ArithmeticException(
           s"Decimal precision ${decimalVal.precision} exceeds max precision 
$precision")
       }
   ```
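   As a standalone illustration (my own sketch, not code from this PR or from the generated plan), the same check can be triggered directly through the `Decimal` API: a value like `100000000000000000000` has 21 integer digits, and at scale 18 it needs precision 39, which exceeds the declared 38.
    ```scala
    import org.apache.spark.sql.types.Decimal

    // 10^20 written out: 21 integer digits.
    val overflowedSum = new java.math.BigDecimal("1" + "0" * 20)

    // Decimal.apply(javaDecimal, precision, scale) goes through Decimal.set,
    // which runs the precision check shown above and throws an
    // ArithmeticException, since 21 + 18 = 39 > 38.
    Decimal(overflowedSum, 38, 18)
    ```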
   5. In the example scenario there are 11 rows in one partition and 1 row in the other, and the group-by key is the same for every row:
    ```
    val decimalStr = "1" + "0" * 19
    val df = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
    df.select(expr(s"cast('$decimalStr' as decimal (38, 18)) as d"), lit(1).as("key"))
      .groupBy("key").agg(sum($"d")).show
    ```
   6. While building the hash map, it successfully adds up the first 10 rows, and the sum at that point is `100000000000000000000`, which is already an overflowed value for decimal(38, 18). When it gets to the next (11th) row and fetches the current sum back out of the aggregate hash map, the `UnsafeRow.getDecimal(..)` call below throws, because the stored value is not containable at that precision.
   
   `agg_unsafeRowAggBuffer_0.getDecimal(0, 38, 18)`
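   The rough arithmetic behind this (my own sketch, not from the PR): decimal(38, 18) leaves 38 - 18 = 20 digits for the integer part. Each input row is 10^19 (20 integer digits, i.e. precision 38 at scale 18, which still fits), but after 10 rows the buffer holds 10^20 (21 integer digits, precision 39 at scale 18), which can no longer be read back as decimal(38, 18).
    ```scala
    // Per-row value from the repro above: 1 followed by 19 zeros, i.e. 10^19.
    val perRow = BigDecimal("1" + "0" * 19)
    assert(perRow.setScale(18).precision == 38)        // 20 integer digits + 18 scale digits: just fits

    // After summing 10 such rows the aggregation buffer holds 10^20.
    val afterTenRows = perRow * 10
    assert(afterTenRows.setScale(18).precision == 39)  // 39 > 38, so getDecimal throws on the 11th row
    ```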
