maropu commented on a change in pull request #20965: [SPARK-21870][SQL] Split 
aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r319323818
 
 

 ##########
 File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
 ##########
 @@ -115,7 +115,8 @@ package object dsl {
     def getField(fieldName: String): UnresolvedExtractValue =
       UnresolvedExtractValue(expr, Literal(fieldName))
 
-    def cast(to: DataType): Expression = Cast(expr, to)
+    def cast(to: DataType): Expression =
+      if (!expr.dataType.sameType(to)) Cast(expr, to) else expr
 
 Review comment:
   Without this fix, the query below fails to compile the [generated agg code](https://gist.github.com/maropu/05a11c3b1c9d7b7673772a7a36683502):
   ```
   scala> sql("""CREATE TABLE dest1(c1 STRING) STORED AS TEXTFILE""")
   scala> val df = sql("""SELECT avg(c1), sum(c1), count(c1) FROM dest1""")
   scala> df.explain
   == Physical Plan ==
   *(2) HashAggregate(keys=[], functions=[avg(cast(c1#0 as double)), 
sum(cast(c1#0 as double)), count(c1#0)])
   +- Exchange SinglePartition, true
      +- *(1) HashAggregate(keys=[], functions=[partial_avg(cast(c1#0 as 
double)), partial_sum(cast(c1#0 as double)), partial_count(c1#0)])
         +- Scan hive default.dest1 [c1#0], HiveTableRelation 
`default`.`dest1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#0]
   
   scala> df.show
   19/08/30 08:48:52 ERROR CodeGenerator: failed to compile: 
org.codehaus.commons.compiler.CompileException:
    File 'generated.java', Line 67, Column 6: Expression "agg_isNull_5" is not 
an rvalue
   ```
   
   This is because `getLocalInputVariableValues` [couldn't extract a correct input variable set](https://gist.github.com/maropu/05a11c3b1c9d7b7673772a7a36683502#file-spark-20965-L41) from the common-subexpression state `subExprs`.
   
   == Root cause ==
   [The current logic for common subexpression elimination](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L272-L276) performs top-down replacements and, IIUC, it assumes that the evaluation code flow of expressions does not change between L272 and L275.
   
   The query above has the expression `cast(c1#0 as double)` for sum. Since [sum internally uses a cast for the aggregation buffer evals](https://github.com/apache/spark/blob/f8f7c52f1272a11e6549fc2e4f5a308e295d1de9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L76), the buffer evals actually contain `cast(cast(c1#0 as double) as double)`. For this query, the code at L272 collects the two common subexpressions below:
   ```
    - cast(cast(c1#0 as double) as double) =>  `agg_value_5/agg_isNull_5`
    - cast(c1#0 as double) => `agg_value_7/agg_isNull_7`
   ```
   Then, the code `boundUpdateExpr.map(_.genCode(ctx))` at L275 replaces expressions by using `subExprs` in a top-down manner. In this case, `agg_value_5/agg_isNull_5` should be used for the split sum function. But, in this nested-cast case, [Cast.genCode](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala#L793) skips the outer cast's evaluation, so `agg_value_7/agg_isNull_7` is wrongly selected at L275. As a result, the query generates the illegal code above.
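   The wrong-variable lookup can be sketched with a toy model in plain Scala (the classes and the `genCode` logic below are hypothetical simplifications for illustration, not Spark's real codegen):

   ```scala
   sealed trait DataType
   case object StringType extends DataType
   case object DoubleType extends DataType

   sealed trait Expr { def dataType: DataType }
   case class Attr(name: String, dataType: DataType) extends Expr
   case class Cast(child: Expr, dataType: DataType) extends Expr

   val c1    = Attr("c1", StringType)
   val inner = Cast(c1, DoubleType)      // cast(c1#0 as double)
   val outer = Cast(inner, DoubleType)   // cast(cast(c1#0 as double) as double)

   // Common-subexpression state collected in the first pass (L272):
   val subExprs: Map[Expr, String] =
     Map(outer -> "agg_value_5", inner -> "agg_value_7")

   // Simplified genCode: like Cast.genCode, a cast whose child already
   // has the target type emits no code and reuses the child's result.
   def genCode(e: Expr): String = e match {
     case Cast(child, to) if child.dataType == to => genCode(child)
     case other => subExprs.getOrElse(other, s"eval_$other")
   }

   // The split sum function expects agg_value_5, but the second pass
   // resolves the same expression to agg_value_7:
   assert(genCode(outer) == "agg_value_7")
   ```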
   
   To fix this issue in a simple way, I added this change in this PR.
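
   As a toy sketch of what the one-line fix changes (plain Scala with hypothetical stand-in classes, not Spark's actual `dsl`/`Cast` code):

   ```scala
   sealed trait DataType
   case object StringType extends DataType
   case object DoubleType extends DataType

   sealed trait Expr { def dataType: DataType }
   case class Attr(name: String, dataType: DataType) extends Expr
   case class Cast(child: Expr, dataType: DataType) extends Expr

   // Old behaviour: always wrap in a Cast, even when it is a no-op.
   def castOld(e: Expr, to: DataType): Expr = Cast(e, to)

   // New behaviour (this PR): skip same-type casts entirely, so no
   // redundant nested cast is ever built.
   def castNew(e: Expr, to: DataType): Expr =
     if (e.dataType != to) Cast(e, to) else e

   val c1   = Attr("c1", StringType)
   val once = castNew(c1, DoubleType)

   // Before the fix a second cast produced a nested no-op cast; after
   // the fix the expression is returned unchanged:
   assert(castOld(once, DoubleType) == Cast(Cast(c1, DoubleType), DoubleType))
   assert(castNew(once, DoubleType) == once)
   ```

   With the nested cast gone, there is only one cast expression (and one subexpression key) left, so the mismatch described above cannot arise.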

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
