c21 commented on a change in pull request #29277:
URL: https://github.com/apache/spark/pull/29277#discussion_r463403584



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
##########
@@ -903,6 +904,10 @@ case class CollapseCodegenStages(
         // The children of SortMergeJoin should do codegen separately.
         j.withNewChildren(j.children.map(
           child => InputAdapter(insertWholeStageCodegen(child))))
+      case j: ShuffledHashJoinExec =>
+        // The children of ShuffledHashJoin should do codegen separately.
+        j.withNewChildren(j.children.map(

Review comment:
       So here is the problematic query (the same one as above):
   
   ```
     test("ShuffledHashJoin should be included in WholeStageCodegen") {
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "30",
           SQLConf.SHUFFLE_PARTITIONS.key -> "2",
           SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
         val df1 = spark.range(5).select($"id".as("k1"))
         val df2 = spark.range(15).select($"id".as("k2"))
         val df3 = spark.range(6).select($"id".as("k3"))
         val twoJoinsDF = df1.join(df2, $"k1" === $"k2").join(df3, $"k1" === 
$"k3")
       }
     }
   ```
   
   If we run codegen separately only for the build side:
   
   ```
   case j: ShuffledHashJoinExec =>
           // The children of ShuffledHashJoin should do codegen separately.
           val newChildren = j.buildSide match {
             case BuildLeft =>
               val buildChild = InputAdapter(insertWholeStageCodegen(j.left))
               val streamChild = insertInputAdapter(j.right)
               Seq(buildChild, streamChild)
             case BuildRight =>
               val buildChild = InputAdapter(insertWholeStageCodegen(j.right))
               val streamChild = insertInputAdapter(j.left)
               Seq(streamChild, buildChild)
           }
           j.withNewChildren(newChildren)
   ```
   
   then in the generated code both `shj_relation_0` and `shj_relation_1` build their hashed relation from `inputs[1]`. This is wrong: each join's build side should read from its own input iterator.
   
   ```
   == Subtree 4 / 4 (maxMethodCodeSize:190; maxConstantPoolSize:129(0.20% 
used); numInnerClasses:0) ==
   *(4) ShuffledHashJoin [k1#2L], [k3#10L], Inner, BuildRight
   :- *(4) ShuffledHashJoin [k1#2L], [k2#6L], Inner, BuildLeft
   :  :- Exchange hashpartitioning(k1#2L, 2), true, [id=#111]
   :  :  +- *(2) Project [id#0L AS k1#2L]
   :  :     +- *(2) Range (0, 5, step=1, splits=2)
   :  +- Exchange hashpartitioning(k2#6L, 2), true, [id=#114]
   :     +- *(3) Project [id#4L AS k2#6L]
   :        +- *(3) Range (0, 15, step=1, splits=2)
   +- Exchange hashpartitioning(k3#10L, 2), true, [id=#108]
      +- *(1) Project [id#8L AS k3#10L]
         +- *(1) Range (0, 6, step=1, splits=2)
   
   Generated code:
   /* 001 */ public Object generate(Object[] references) {
   /* 002 */   return new GeneratedIteratorForCodegenStage4(references);
   /* 003 */ }
   /* 004 */
   /* 005 */ // codegenStageId=4
   /* 006 */ final class GeneratedIteratorForCodegenStage4 extends 
org.apache.spark.sql.execution.BufferedRowIterator {
   /* 007 */   private Object[] references;
   /* 008 */   private scala.collection.Iterator[] inputs;
   /* 009 */   private scala.collection.Iterator inputadapter_input_0;
   /* 010 */   private org.apache.spark.sql.execution.joins.HashedRelation 
shj_relation_0;
   /* 011 */   private org.apache.spark.sql.execution.joins.HashedRelation 
shj_relation_1;
   /* 012 */   private 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] 
shj_mutableStateArray_0 = new 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
   /* 013 */
   /* 014 */   public GeneratedIteratorForCodegenStage4(Object[] references) {
   /* 015 */     this.references = references;
   /* 016 */   }
   /* 017 */
   /* 018 */   public void init(int index, scala.collection.Iterator[] inputs) {
   /* 019 */     partitionIndex = index;
   /* 020 */     this.inputs = inputs;
   /* 021 */     inputadapter_input_0 = inputs[0];
   /* 022 */     shj_relation_0 = 
((org.apache.spark.sql.execution.joins.ShuffledHashJoinExec) references[0] /* 
plan */).buildHashedRelation(inputs[1]);
   /* 023 */     shj_mutableStateArray_0[0] = new 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
   /* 024 */     shj_relation_1 = 
((org.apache.spark.sql.execution.joins.ShuffledHashJoinExec) references[2] /* 
plan */).buildHashedRelation(inputs[1]);
   /* 025 */     shj_mutableStateArray_0[1] = new 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(3, 0);
   /* 026 */
   /* 027 */   }
   /* 028 */
   /* 029 */   private void shj_doConsume_0(InternalRow inputadapter_row_0, 
long shj_expr_0_0) throws java.io.IOException {
   /* 030 */     // generate join key for stream side
   /* 031 */
   /* 032 */     // find matches from HashRelation
   /* 033 */     scala.collection.Iterator shj_matches_0 = false ?
   /* 034 */     null : 
(scala.collection.Iterator)shj_relation_0.get(shj_expr_0_0);
   /* 035 */     if (shj_matches_0 != null) {
   /* 036 */       while (shj_matches_0.hasNext()) {
   /* 037 */         UnsafeRow shj_matched_0 = (UnsafeRow) shj_matches_0.next();
   /* 038 */         {
   /* 039 */           ((org.apache.spark.sql.execution.metric.SQLMetric) 
references[1] /* numOutputRows */).add(1);
   /* 040 */
   /* 041 */           long shj_value_1 = shj_matched_0.getLong(0);
   /* 042 */
   /* 043 */           // generate join key for stream side
   /* 044 */
   /* 045 */           // find matches from HashRelation
   /* 046 */           scala.collection.Iterator shj_matches_1 = false ?
   /* 047 */           null : 
(scala.collection.Iterator)shj_relation_1.get(shj_value_1);
   /* 048 */           if (shj_matches_1 != null) {
   /* 049 */             while (shj_matches_1.hasNext()) {
   /* 050 */               UnsafeRow shj_matched_1 = (UnsafeRow) 
shj_matches_1.next();
   /* 051 */               {
   /* 052 */                 ((org.apache.spark.sql.execution.metric.SQLMetric) 
references[3] /* numOutputRows */).add(1);
   /* 053 */
   /* 054 */                 long shj_value_5 = shj_matched_1.getLong(0);
   /* 055 */                 shj_mutableStateArray_0[1].reset();
   /* 056 */
   /* 057 */                 shj_mutableStateArray_0[1].write(0, shj_value_1);
   /* 058 */
   /* 059 */                 shj_mutableStateArray_0[1].write(1, shj_expr_0_0);
   /* 060 */
   /* 061 */                 shj_mutableStateArray_0[1].write(2, shj_value_5);
   /* 062 */                 
append((shj_mutableStateArray_0[1].getRow()).copy());
   /* 063 */
   /* 064 */               }
   /* 065 */             }
   /* 066 */           }
   /* 067 */
   /* 068 */         }
   /* 069 */       }
   /* 070 */     }
   /* 071 */
   /* 072 */   }
   /* 073 */
   /* 074 */   protected void processNext() throws java.io.IOException {
   /* 075 */     while ( inputadapter_input_0.hasNext()) {
   /* 076 */       InternalRow inputadapter_row_0 = (InternalRow) 
inputadapter_input_0.next();
   /* 077 */
   /* 078 */       long inputadapter_value_0 = inputadapter_row_0.getLong(0);
   /* 079 */
   /* 080 */       shj_doConsume_0(inputadapter_row_0, inputadapter_value_0);
   /* 081 */       if (shouldStop()) return;
   /* 082 */     }
   /* 083 */   }
   /* 084 */
   /* 085 */ }
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to