maropu commented on a change in pull request #20965: [SPARK-21870][SQL] Split
aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r319323818
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
##########
@@ -115,7 +115,8 @@ package object dsl {
def getField(fieldName: String): UnresolvedExtractValue =
UnresolvedExtractValue(expr, Literal(fieldName))
- def cast(to: DataType): Expression = Cast(expr, to)
+ def cast(to: DataType): Expression =
+ if (!expr.dataType.sameType(to)) Cast(expr, to) else expr
Review comment:
Without this fix, the query below fails because the [generated agg
code](https://gist.github.com/maropu/05a11c3b1c9d7b7673772a7a36683502) does not compile:
```
scala> sql("""CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE""")
scala> val df = sql("""SELECT avg(c1), sum(c1), count(c1) FROM dest1""")
scala> df.explain
== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[avg(cast(c1#0 as double)),
sum(cast(c1#0 as double)), count(c1#0)])
+- Exchange SinglePartition, true
+- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(c1#0 as
double)), partial_sum(cast(c1#0 as double)), partial_count(c1#0)])
+- Scan hive default.dest1 [c1#0], HiveTableRelation
`default`.`dest1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#0]
scala> df.show
19/08/30 08:48:52 ERROR CodeGenerator: failed to compile:
org.codehaus.commons.compiler.CompileException:
File 'generated.java', Line 67, Column 6: Expression "agg_isNull_5" is not
an rvalue
```
This is because `getLocalInputVariableValues` [couldn't get a correct input
variable
set](https://gist.github.com/maropu/05a11c3b1c9d7b7673772a7a36683502#file-spark-20965-L41)
from the common subexpression state `subExprs`.
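The diff above avoids building a no-op `Cast` in the first place. As a rough illustration (a hypothetical mini-model, not Spark's real `Expression`/`Cast` classes), the fixed helper returns the child unchanged when the target type already matches, so repeated casts can never nest:

```scala
// Hypothetical mini-model of the DSL fix above: only wrap in a Cast when
// the target type actually differs, so no no-op cast enters the plan.
sealed trait DataType
case object StringType extends DataType
case object DoubleType extends DataType

sealed trait Expression { def dataType: DataType }
case class Attr(name: String, dataType: DataType) extends Expression
case class Cast(child: Expression, dataType: DataType) extends Expression

def cast(expr: Expression, to: DataType): Expression =
  if (expr.dataType != to) Cast(expr, to) else expr  // skip no-op cast

val c1 = Attr("c1", StringType)
val once  = cast(c1, DoubleType)    // Cast(c1, DoubleType)
val twice = cast(once, DoubleType)  // unchanged: no nested cast is built
println(once == twice)  // true
```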
== Root cause ==
[The current logic for common subexpression
elimination](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L272-L276)
performs top-down replacements and, IIUC, it assumes that the evaluation code
flow of expressions does not change between L272 and L275.
The query above has the expression `cast(c1#0 as double)` for sum. Since sum
[internally uses a cast for the aggregation buffer
evals](https://github.com/apache/spark/blob/f8f7c52f1272a11e6549fc2e4f5a308e295d1de9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L76),
the buffer evals actually contain `cast(cast(c1#0 as double) as double)`. For
this query, the code at L272 collects the two common subexpressions below:
```
- cast(cast(c1#0 as double) as double) => `agg_value_5/agg_isNull_5`
- cast(c1#0 as double) => `agg_value_7/agg_isNull_7`
```
Then, the code `boundUpdateExpr.map(_.genCode(ctx))` at L275 replaces
expressions using `subExprs` in a top-down manner. In this case,
`agg_value_5`/`agg_isNull_5` should be used for the split sum function. But in
this nested-cast case,
[Cast.genCode](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala#L793)
skips evaluating the outer cast, so `agg_value_7`/`agg_isNull_7` is wrongly
selected at L275. As a result, the query generates the illegal code.
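To make the selection failure concrete, here is a hypothetical sketch (again, not Spark's real classes or the real `genCode` machinery) of how eliding a no-op outer cast during codegen surfaces the inner subexpression's variable instead of the outer one:

```scala
// Hypothetical sketch of the root cause: top-down replacement picks the
// wrong variable once the no-op outer cast is elided during codegen.
sealed trait Expr { def dataType: String }
case class Col(name: String, dataType: String) extends Expr
case class CastTo(child: Expr, dataType: String) extends Expr

// Common-subexpression state collected at L272: expr -> generated variable.
val subExprs: Map[Expr, String] = Map(
  CastTo(CastTo(Col("c1", "string"), "double"), "double") -> "agg_value_5",
  CastTo(Col("c1", "string"), "double")                   -> "agg_value_7"
)

// genCode mimicking the elision: a cast to the child's own type is skipped
// before the subexpression lookup, so the outer entry is never matched.
def genCode(e: Expr): String = e match {
  case CastTo(child, to) if child.dataType == to => genCode(child)
  case other => subExprs.getOrElse(other, s"eval($other)")
}

// The split sum function should get agg_value_5, but the elision surfaces
// the inner cast's variable instead:
val sumInput = CastTo(CastTo(Col("c1", "string"), "double"), "double")
println(genCode(sumInput))  // agg_value_7, not the expected agg_value_5
```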