c21 opened a new pull request #29277:
URL: https://github.com/apache/spark/pull/29277


   
   ### What changes were proposed in this pull request?
   
   Add codegen for shuffled hash join. Shuffled hash join codegen is very similar to broadcast hash join codegen, so most of the code change is refactoring the existing codegen in `BroadcastHashJoinExec` up into the shared `HashJoin` trait.
   
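   The rough shape of the refactoring is sketched below (illustrative names and signatures only, not the exact diff):
   
   ```scala
   // Simplified sketch: probe-side codegen that previously lived in
   // BroadcastHashJoinExec moves into the shared HashJoin trait, so that
   // ShuffledHashJoinExec can participate in whole-stage codegen as well.
   import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
   import org.apache.spark.sql.execution.CodegenSupport
   
   trait HashJoin extends CodegenSupport {
     // Shared: generate the join key from the streamed row, look it up in the
     // HashedRelation, and emit joined rows (one variant per join type).
     protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String
   
     // Operator-specific: how the generated code obtains the HashedRelation.
     // BroadcastHashJoinExec reads it from a broadcast variable in init(), while
     // ShuffledHashJoinExec builds it from the shuffled build-side input iterator
     // at the start of processNext() (see the generated code below).
     protected def prepareRelation(ctx: CodegenContext): String
   }
   ```
   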
   Example codegen for the query in [`JoinBenchmark`](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala#L153):
   
   ```
     def shuffleHashJoin(): Unit = {
       val N: Long = 4 << 20
       withSQLConf(
         SQLConf.SHUFFLE_PARTITIONS.key -> "2",
         SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10000000",
         SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
         codegenBenchmark("shuffle hash join", N) {
           val df1 = spark.range(N).selectExpr(s"id as k1")
           val df2 = spark.range(N / 3).selectExpr(s"id * 3 as k2")
       val df = df1.join(df2, col("k1") === col("k2"))
       assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined)
           df.noop()
         }
       }
     }
   ```
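   
   For context, the same plan can be reproduced outside the benchmark harness with the equivalent session settings (a minimal standalone sketch using the values above):
   
   ```scala
   import org.apache.spark.sql.SparkSession
   import org.apache.spark.sql.functions.col
   
   // Minimal standalone version of the benchmark query above; with these settings
   // the planner picks ShuffledHashJoinExec rather than sort merge or broadcast join.
   val spark = SparkSession.builder().master("local[2]").appName("shj-codegen-demo").getOrCreate()
   spark.conf.set("spark.sql.shuffle.partitions", "2")
   spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10000000")
   spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")
   
   val N: Long = 4L << 20
   val df1 = spark.range(N).selectExpr("id as k1")
   val df2 = spark.range(N / 3).selectExpr("id * 3 as k2")
   val df = df1.join(df2, col("k1") === col("k2"))
   df.explain("codegen")  // prints generated code like the dump below
   ```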
   
   Shuffled hash join codegen (note that the generated `processNext()` first builds the `HashedRelation` from the build-side input iterator, then probes it for each streamed row):
   
   ```
   == Subtree 3 / 3 (maxMethodCodeSize:206; maxConstantPoolSize:135(0.21% used); numInnerClasses:0) ==
   *(3) ShuffledHashJoin [k1#2L], [k2#6L], Inner, BuildRight
   :- *(1) Project [id#0L AS k1#2L]
   :  +- *(1) Range (0, 4194304, step=1, splits=1)
   +- *(2) Project [(id#4L * 3) AS k2#6L]
      +- *(2) Range (0, 1398101, step=1, splits=1)
   
   Generated code:
   /* 001 */ public Object generate(Object[] references) {
   /* 002 */   return new GeneratedIteratorForCodegenStage3(references);
   /* 003 */ }
   /* 004 */
   /* 005 */ // codegenStageId=3
   /* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
   /* 007 */   private Object[] references;
   /* 008 */   private scala.collection.Iterator[] inputs;
   /* 009 */   private scala.collection.Iterator shj_streamedInput_0;
   /* 010 */   private scala.collection.Iterator shj_buildInput_0;
   /* 011 */   private boolean shj_initRelation_0;
   /* 012 */   private InternalRow shj_streamedRow_0;
   /* 013 */   private org.apache.spark.sql.execution.joins.HashedRelation shj_relation_0;
   /* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] shj_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
   /* 015 */
   /* 016 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
   /* 017 */     this.references = references;
   /* 018 */   }
   /* 019 */
   /* 020 */   public void init(int index, scala.collection.Iterator[] inputs) {
   /* 021 */     partitionIndex = index;
   /* 022 */     this.inputs = inputs;
   /* 023 */     shj_streamedInput_0 = inputs[0];
   /* 024 */     shj_buildInput_0 = inputs[1];
   /* 025 */     shj_initRelation_0 = false;
   /* 026 */
   /* 027 */     shj_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
   /* 028 */
   /* 029 */   }
   /* 030 */
   /* 031 */   protected void processNext() throws java.io.IOException {
   /* 032 */     // construct hash map for shuffled hash join build side
   /* 033 */     if (!shj_initRelation_0) {
   /* 034 */       shj_relation_0 = ((org.apache.spark.sql.execution.joins.ShuffledHashJoinExec) references[0] /* plan */).buildHashedRelation(shj_buildInput_0);
   /* 035 */       shj_initRelation_0 = true;
   /* 036 */     }
   /* 037 */
   /* 038 */     while (shj_streamedInput_0.hasNext()) {
   /* 039 */       shj_streamedRow_0 = (InternalRow) shj_streamedInput_0.next();
   /* 040 */       long shj_value_0 = -1L;
   /* 041 */
   /* 042 */       // generate join key for stream side
   /* 043 */       shj_value_0 = shj_streamedRow_0.getLong(0);
   /* 044 */       // find matches from HashRelation
   /* 045 */       scala.collection.Iterator shj_matches_0 = false ?
   /* 046 */       null : (scala.collection.Iterator)shj_relation_0.get(shj_value_0);
   /* 047 */       if (shj_matches_0 != null) {
   /* 048 */         while (shj_matches_0.hasNext()) {
   /* 049 */           UnsafeRow shj_matched_0 = (UnsafeRow) shj_matches_0.next();
   /* 050 */           {
   /* 051 */             ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* numOutputRows */).add(1);
   /* 052 */
   /* 053 */             shj_value_0 = shj_streamedRow_0.getLong(0);
   /* 054 */             long shj_value_2 = shj_matched_0.getLong(0);
   /* 055 */             shj_mutableStateArray_0[0].reset();
   /* 056 */
   /* 057 */             shj_mutableStateArray_0[0].write(0, shj_value_0);
   /* 058 */
   /* 059 */             shj_mutableStateArray_0[0].write(1, shj_value_2);
   /* 060 */             append((shj_mutableStateArray_0[0].getRow()).copy());
   /* 061 */
   /* 062 */           }
   /* 063 */         }
   /* 064 */       }
   /* 065 */
   /* 066 */       if (shouldStop()) return;
   /* 067 */     }
   /* 068 */   }
   /* 069 */
   /* 070 */ }
   ```
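   
   The `buildHashedRelation` call on line `/* 034 */` above is a callback into the operator that builds the hash map once per task. Roughly, as a hedged sketch only (metric updates and task-completion cleanup omitted; details may differ from the actual code):
   
   ```scala
   // Rough sketch (inside ShuffledHashJoinExec) of buildHashedRelation, as invoked
   // from the generated code above: consume the build-side iterator into a
   // HashedRelation keyed on the build-side join keys. buildBoundKeys is an
   // assumed name for the bound build-side key expressions.
   def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
     HashedRelation(iter, buildBoundKeys,
       taskMemoryManager = TaskContext.get().taskMemoryManager())
   }
   ```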
   
   Broadcast hash join codegen for the same query, for reference (here the `HashedRelation` comes from a broadcast variable read in `init()`; the per-row probe logic is essentially the same):
   
   ```
   == Subtree 2 / 2 (maxMethodCodeSize:280; maxConstantPoolSize:218(0.33% used); numInnerClasses:0) ==
   *(2) BroadcastHashJoin [k1#2L], [k2#6L], Inner, BuildRight, false
   :- *(2) Project [id#0L AS k1#2L]
   :  +- *(2) Range (0, 4194304, step=1, splits=1)
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [id=#22]
      +- *(1) Project [(id#4L * 3) AS k2#6L]
         +- *(1) Range (0, 1398101, step=1, splits=1)
   
   Generated code:
   /* 001 */ public Object generate(Object[] references) {
   /* 002 */   return new GeneratedIteratorForCodegenStage2(references);
   /* 003 */ }
   /* 004 */
   /* 005 */ // codegenStageId=2
   /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
   /* 007 */   private Object[] references;
   /* 008 */   private scala.collection.Iterator[] inputs;
   /* 009 */   private boolean range_initRange_0;
   /* 010 */   private long range_nextIndex_0;
   /* 011 */   private TaskContext range_taskContext_0;
   /* 012 */   private InputMetrics range_inputMetrics_0;
   /* 013 */   private long range_batchEnd_0;
   /* 014 */   private long range_numElementsTodo_0;
   /* 015 */   private org.apache.spark.sql.execution.joins.LongHashedRelation bhj_relation_0;
   /* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
   /* 017 */
   /* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
   /* 019 */     this.references = references;
   /* 020 */   }
   /* 021 */
   /* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
   /* 023 */     partitionIndex = index;
   /* 024 */     this.inputs = inputs;
   /* 025 */
   /* 026 */     range_taskContext_0 = TaskContext.get();
   /* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
   /* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
   /* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
   /* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
   /* 031 */
   /* 032 */     bhj_relation_0 = ((org.apache.spark.sql.execution.joins.LongHashedRelation) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcast */).value()).asReadOnlyCopy();
   /* 033 */     incPeakExecutionMemory(bhj_relation_0.estimatedSize());
   /* 034 */
   /* 035 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
   /* 036 */
   /* 037 */   }
   /* 038 */
   /* 039 */   private void initRange(int idx) {
   /* 040 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
   /* 041 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(1L);
   /* 042 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4194304L);
   /* 043 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
   /* 044 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
   /* 045 */     long partitionEnd;
   /* 046 */
   /* 047 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
   /* 048 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
   /* 049 */       range_nextIndex_0 = Long.MAX_VALUE;
   /* 050 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
   /* 051 */       range_nextIndex_0 = Long.MIN_VALUE;
   /* 052 */     } else {
   /* 053 */       range_nextIndex_0 = st.longValue();
   /* 054 */     }
   /* 055 */     range_batchEnd_0 = range_nextIndex_0;
   /* 056 */
   /* 057 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
   /* 058 */     .multiply(step).add(start);
   /* 059 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
   /* 060 */       partitionEnd = Long.MAX_VALUE;
   /* 061 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
   /* 062 */       partitionEnd = Long.MIN_VALUE;
   /* 063 */     } else {
   /* 064 */       partitionEnd = end.longValue();
   /* 065 */     }
   /* 066 */
   /* 067 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
   /* 068 */       java.math.BigInteger.valueOf(range_nextIndex_0));
   /* 069 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
   /* 070 */     if (range_numElementsTodo_0 < 0) {
   /* 071 */       range_numElementsTodo_0 = 0;
   /* 072 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
   /* 073 */       range_numElementsTodo_0++;
   /* 074 */     }
   /* 075 */   }
   /* 076 */
   /* 077 */   private void bhj_doConsume_0(long bhj_expr_0_0) throws java.io.IOException {
   /* 078 */     // generate join key for stream side
   /* 079 */
   /* 080 */     // find matches from HashedRelation
   /* 081 */     UnsafeRow bhj_matched_0 = false ? null: (UnsafeRow)bhj_relation_0.getValue(bhj_expr_0_0);
   /* 082 */     if (bhj_matched_0 != null) {
   /* 083 */       {
   /* 084 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
   /* 085 */
   /* 086 */         long bhj_value_2 = bhj_matched_0.getLong(0);
   /* 087 */         range_mutableStateArray_0[3].reset();
   /* 088 */
   /* 089 */         range_mutableStateArray_0[3].write(0, bhj_expr_0_0);
   /* 090 */
   /* 091 */         range_mutableStateArray_0[3].write(1, bhj_value_2);
   /* 092 */         append((range_mutableStateArray_0[3].getRow()));
   /* 093 */
   /* 094 */       }
   /* 095 */     }
   /* 096 */
   /* 097 */   }
   /* 098 */
   /* 099 */   protected void processNext() throws java.io.IOException {
   /* 100 */     // initialize Range
   /* 101 */     if (!range_initRange_0) {
   /* 102 */       range_initRange_0 = true;
   /* 103 */       initRange(partitionIndex);
   /* 104 */     }
   /* 105 */
   /* 106 */     while (true) {
   /* 107 */       if (range_nextIndex_0 == range_batchEnd_0) {
   /* 108 */         long range_nextBatchTodo_0;
   /* 109 */         if (range_numElementsTodo_0 > 1000L) {
   /* 110 */           range_nextBatchTodo_0 = 1000L;
   /* 111 */           range_numElementsTodo_0 -= 1000L;
   /* 112 */         } else {
   /* 113 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
   /* 114 */           range_numElementsTodo_0 = 0;
   /* 115 */           if (range_nextBatchTodo_0 == 0) break;
   /* 116 */         }
   /* 117 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
   /* 118 */       }
   /* 119 */
   /* 120 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
   /* 121 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
   /* 122 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
   /* 123 */
   /* 124 */         bhj_doConsume_0(range_value_0);
   /* 125 */
   /* 126 */         if (shouldStop()) {
   /* 127 */           range_nextIndex_0 = range_value_0 + 1L;
   /* 128 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
   /* 129 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
   /* 130 */           return;
   /* 131 */         }
   /* 132 */
   /* 133 */       }
   /* 134 */       range_nextIndex_0 = range_batchEnd_0;
   /* 135 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
   /* 136 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
   /* 137 */       range_taskContext_0.killTaskIfInterrupted();
   /* 138 */     }
   /* 139 */   }
   /* 140 */
   /* 141 */ }
   ```
   
   ### Why are the changes needed?
   
   Codegen for shuffled hash join helps save CPU cost. We added shuffled hash join codegen internally in our fork, and saw a clear improvement in benchmarks compared to the current non-codegen code path.
   
   Testing the example query in [`JoinBenchmark`](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala#L153), we see a 30% wall clock time improvement compared to the existing non-codegen code path:
   
   With shuffled hash join codegen enabled:
   
   ```
   Running benchmark: shuffle hash join
     Running case: shuffle hash join wholestage off
     Stopped after 2 iterations, 1358 ms
     Running case: shuffle hash join wholestage on
     Stopped after 5 iterations, 2323 ms
   
   Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.4
   Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
   shuffle hash join:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
   ------------------------------------------------------------------------------------------------------------------------
   shuffle hash join wholestage off                    649            679          43          6.5         154.7       1.0X
   shuffle hash join wholestage on                     436            465          45          9.6         103.9       1.5X
   ```
   
   With shuffled hash join codegen disabled:
   
   ```
   Running benchmark: shuffle hash join
     Running case: shuffle hash join wholestage off
     Stopped after 2 iterations, 1345 ms
     Running case: shuffle hash join wholestage on
     Stopped after 5 iterations, 2967 ms
   
   Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.4
   Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
   shuffle hash join:                        Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
   ------------------------------------------------------------------------------------------------------------------------
   shuffle hash join wholestage off                    646            673          37          6.5         154.1       1.0X
   shuffle hash join wholestage on                     549            594          47          7.6         130.9       1.2X
   ```
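   
   For reference, both tables come from `JoinBenchmark` on the same machine; assuming the standard Spark benchmark workflow, they can be reproduced with `build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.JoinBenchmark"`.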
   
   ### Does this PR introduce _any_ user-facing change?
   
   No. 
   
   ### How was this patch tested?
   
   Added a unit test in `WholeStageCodegenSuite`; a sketch of its shape follows.
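   
   The test is along these lines (an illustrative sketch, not the exact code; it assumes the suite's usual `SharedSparkSession`/`QueryTest` helpers such as `checkAnswer`):
   
   ```scala
   // Illustrative sketch: verify ShuffledHashJoinExec is wrapped by
   // WholeStageCodegenExec and still returns correct results.
   test("ShuffledHashJoin should be included in WholeStageCodegen") {
     val df1 = spark.range(5).select($"id".as("k1"))
     val df2 = spark.range(15).select(($"id" * 3).as("k2"))
     // The SHUFFLE_HASH join hint forces ShuffledHashJoinExec regardless of stats.
     val df = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegenExec] &&
         p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[ShuffledHashJoinExec]).isDefined)
     checkAnswer(df, Seq(Row(0L, 0L), Row(3L, 3L)))
   }
   ```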

