skambha commented on issue #27627: [WIP][SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow
URL: https://github.com/apache/spark/pull/27627#issuecomment-597404017

> For aggregate with GROUP BY, we write out the sum value to the unsafe row (the aggregate hash map) at each step. Do you know what happens there?

Details are here: https://github.com/skambha/notes/blob/master/case3_exception.txt

The scenario is the two ranges unioned together, followed by your query from the previous comment. Below are some brief notes:

1. In this case, we write to an aggregate hash map using `UnsafeRow`. Writing the `UnsafeRow` to the hash map does not go through the `UnsafeRowWriter` that has the issue of writing null for an overflowed decimal value. (Side note: in whole-stage execution we will still write using `UnsafeRowWriter`, but that processing happens later, i.e. after the hash map is built. The error in the scenario mentioned here occurs before that.)

2. When building the hash map, it uses the `UnsafeRow.setDecimal` call: `agg_unsafeRowAggBuffer_0.setDecimal(0, agg_value_2, 38);`
   This `setDecimal` method does not check whether the decimal value overflowed.

3. As it builds the hash map, if the key already exists, it retrieves the current buffer value and adds the new input row's value to it. The method it uses to retrieve the decimal value from the hash map is `UnsafeRow.getDecimal`: `agg_unsafeRowAggBuffer_0.getDecimal(0, 38, 18)`

4. `UnsafeRow.getDecimal` calls `Decimal.set(decimal: BigDecimal, precision: Int, scale: Int)` via `Decimal.apply(javaDecimal, precision, scale)`, which does check whether the value is containable and, if not, throws the exception you see:
   ```
   if (decimalVal.precision > precision) {
     throw new ArithmeticException(
       s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
   }
   ```

5. Here is the example scenario, with 11 rows in one partition, 1 row in the other partition, and the same group-by key for every row:
   ```
   val decimalStr = "1" + "0" * 19
   val df = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
   df.select(expr(s"cast('$decimalStr' as decimal (38, 18)) as d"), lit(1).as("key"))
     .groupBy("key").agg(sum($"d")).show
   ```

6. While building the hash map, it will have successfully summed the first 10 rows; the sum at that point is `100000000000000000000`, which is an overflowed value. When it gets to the next (11th) row and fetches the current value from the aggregate hash map via `agg_unsafeRowAggBuffer_0.getDecimal(0, 38, 18)`, that call throws an error because the value is not containable. (A minimal sketch reproducing this check follows these notes.)
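To make notes 4 and 6 concrete outside of the aggregation path, here is a minimal sketch of the precision check, assuming the `Decimal.apply(javaDecimal, precision, scale)` overload referenced in note 4 can be invoked directly (e.g. from a spark-shell with the catalyst classes on the classpath); the value and the expected message come from notes 4 and 6 above:

```
import org.apache.spark.sql.types.Decimal

// The overflowed sum from note 6: 10 rows * 10^19 = 10^20, which needs 21
// integer digits while decimal(38, 18) only leaves room for 38 - 18 = 20.
val overflowedSum = new java.math.BigDecimal("1" + "0" * 20)

// UnsafeRow.getDecimal(0, 38, 18) goes through this same Decimal.apply /
// Decimal.set path, which rescales the value to 18 fractional digits
// (precision 39) and then fails the check quoted in note 4, along the lines of:
//   ArithmeticException: Decimal precision 39 exceeds max precision 38
val d = Decimal(overflowedSum, 38, 18)
```

This should fail with the same ArithmeticException as the GROUP BY scenario, which is why the error surfaces on the read (`getDecimal`) side even though the overflowed value was written without complaint by `setDecimal`.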
