c21 commented on a change in pull request #29277:
URL: https://github.com/apache/spark/pull/29277#discussion_r462677816



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
       resultProj(r)
     }
   }
+
+  /**
+   * Returns the code for generating join key for stream side, and expression 
of whether the key
+   * has any null in it or not.
+   */
+  protected def genStreamSideJoinKey(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (ExprCode, String) = {
+    ctx.currentVars = input
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == 
LongType) {
+      // generate the join key as Long
+      val ev = streamedBoundKeys.head.genCode(ctx)
+      (ev, ev.isNull)
+    } else {
+      // generate the join key as UnsafeRow
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+      (ev, s"${ev.value}.anyNull()")
+    }
+  }
+
+  /**
+   * Generates the code for variable of build side.
+   */
+  private def genBuildSideVars(ctx: CodegenContext, matched: String): 
Seq[ExprCode] = {
+    ctx.currentVars = null
+    ctx.INPUT_ROW = matched
+    buildPlan.output.zipWithIndex.map { case (a, i) =>
+      val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+      if (joinType.isInstanceOf[InnerLike]) {
+        ev
+      } else {
+        // the variables are needed even there is no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val javaType = CodeGenerator.javaType(a.dataType)
+        val code = code"""
+          |boolean $isNull = true;
+          |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+          |if ($matched != null) {
+          |  ${ev.code}
+          |  $isNull = ${ev.isNull};
+          |  $value = ${ev.value};
+          |}
+         """.stripMargin
+        ExprCode(code, JavaCode.isNullVariable(isNull), 
JavaCode.variable(value, a.dataType))
+      }
+    }
+  }
+
+  /**
+   * Generate the (non-equi) condition used to filter joined rows. This is 
used in Inner, Left Semi
+   * and Left Anti joins.
+   */
+  protected def getJoinCondition(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, 
expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev =
+        BindReferences.bindReference(expr, streamedPlan.output ++ 
buildPlan.output).genCode(ctx)
+      val skipRow = s"${ev.isNull} || !${ev.value}"
+      s"""
+         |$eval
+         |${ev.code}
+         |if (!($skipRow))
+       """.stripMargin
+    } else {
+      ""
+    }
+    (matched, checkCondition, buildVars)
+  }
+
+  /**
+   * Generates the code for Inner join.
+   */
+  protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
+    val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+
+    if (keyIsKnownUnique) {

Review comment:
       @cloud-fan - updated naming back to be `keyIsUnique`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -70,4 +74,69 @@ case class ShuffledHashJoinExec(
       join(streamIter, hashed, numOutputRows)
     }
   }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    streamedPlan.execute() :: buildPlan.execute() :: Nil
+  }
+
+  override def needCopyResult: Boolean = true
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // inline mutable state since not many join operations in a task
+    val streamedInput = ctx.addMutableState(
+      "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", 
forceInline = true)
+    val buildInput = ctx.addMutableState(
+      "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];", 
forceInline = true)
+    val initRelation = ctx.addMutableState(
+      CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;", 
forceInline = true)
+    val streamedRow = ctx.addMutableState(
+      "InternalRow", "streamedRow", forceInline = true)
+
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val (relationTerm, _) = prepareRelation(ctx)
+    val buildRelation = s"$relationTerm = 
$thisPlan.buildHashedRelation($buildInput);"

Review comment:
       @cloud-fan - yes I got your point. Updated to do initialization in 
`prepareRelation`, thanks.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -70,4 +74,69 @@ case class ShuffledHashJoinExec(
       join(streamIter, hashed, numOutputRows)
     }
   }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    streamedPlan.execute() :: buildPlan.execute() :: Nil
+  }
+
+  override def needCopyResult: Boolean = true
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // inline mutable state since not many join operations in a task
+    val streamedInput = ctx.addMutableState(
+      "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", 
forceInline = true)
+    val buildInput = ctx.addMutableState(
+      "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];", 
forceInline = true)
+    val initRelation = ctx.addMutableState(
+      CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;", 
forceInline = true)
+    val streamedRow = ctx.addMutableState(
+      "InternalRow", "streamedRow", forceInline = true)
+
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val (relationTerm, _) = prepareRelation(ctx)
+    val buildRelation = s"$relationTerm = 
$thisPlan.buildHashedRelation($buildInput);"
+    val (streamInputVar, streamInputVarDecl) = createVars(ctx, streamedRow, 
streamedPlan.output)
+
+    val join = joinType match {
+      case _: InnerLike => codegenInner(ctx, streamInputVar)
+      case LeftOuter | RightOuter => codegenOuter(ctx, streamInputVar)
+      case LeftSemi => codegenSemi(ctx, streamInputVar)
+      case LeftAnti => codegenAnti(ctx, streamInputVar)
+      case _: ExistenceJoin => codegenExistence(ctx, streamInputVar)
+      case x =>
+        throw new IllegalArgumentException(
+          s"ShuffledHashJoin should not take $x as the JoinType")
+    }
+
+    s"""
+       |// construct hash map for shuffled hash join build side
+       |if (!$initRelation) {
+       |  $buildRelation
+       |  $initRelation = true;
+       |}
+       |
+       |while ($streamedInput.hasNext()) {
+       |  $streamedRow = (InternalRow) $streamedInput.next();
+       |  ${streamInputVarDecl.mkString("\n")}
+       |  $join
+       |
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Returns a tuple of variable name for HashedRelation,
+   * and boolean false to indicate key not to be known unique in code-gen time.
+   */
+  protected override def prepareRelation(ctx: CodegenContext): (String, 
Boolean) = {
+    if (relationTerm == null) {
+      // Inline mutable state since not many join operations in a task
+      relationTerm = ctx.addMutableState(
+        "org.apache.spark.sql.execution.joins.HashedRelation", "relation", 
forceInline = true)

Review comment:
       @cloud-fan and @viirya -
   
   Updated code with following change:
    
   (1). `ShuffledHashJoinExec.prepareRelation` will do `buildRelation` to build 
the hash map.
   (2). `BroadcastHashJoinExec.prepareRelation` will do `prepareBroadcast` to 
broadcast the build side and build the hash map.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -70,4 +74,69 @@ case class ShuffledHashJoinExec(
       join(streamIter, hashed, numOutputRows)
     }
   }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    streamedPlan.execute() :: buildPlan.execute() :: Nil
+  }
+
+  override def needCopyResult: Boolean = true
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // inline mutable state since not many join operations in a task
+    val streamedInput = ctx.addMutableState(
+      "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", 
forceInline = true)
+    val buildInput = ctx.addMutableState(
+      "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];", 
forceInline = true)
+    val initRelation = ctx.addMutableState(
+      CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;", 
forceInline = true)
+    val streamedRow = ctx.addMutableState(
+      "InternalRow", "streamedRow", forceInline = true)
+
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val (relationTerm, _) = prepareRelation(ctx)
+    val buildRelation = s"$relationTerm = 
$thisPlan.buildHashedRelation($buildInput);"
+    val (streamInputVar, streamInputVarDecl) = createVars(ctx, streamedRow, 
streamedPlan.output)
+
+    val join = joinType match {
+      case _: InnerLike => codegenInner(ctx, streamInputVar)
+      case LeftOuter | RightOuter => codegenOuter(ctx, streamInputVar)
+      case LeftSemi => codegenSemi(ctx, streamInputVar)
+      case LeftAnti => codegenAnti(ctx, streamInputVar)
+      case _: ExistenceJoin => codegenExistence(ctx, streamInputVar)
+      case x =>
+        throw new IllegalArgumentException(
+          s"ShuffledHashJoin should not take $x as the JoinType")
+    }
+
+    s"""
+       |// construct hash map for shuffled hash join build side
+       |if (!$initRelation) {
+       |  $buildRelation
+       |  $initRelation = true;
+       |}
+       |
+       |while ($streamedInput.hasNext()) {
+       |  $streamedRow = (InternalRow) $streamedInput.next();
+       |  ${streamInputVarDecl.mkString("\n")}
+       |  $join
+       |
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Returns a tuple of variable name for HashedRelation,
+   * and boolean false to indicate key not to be known unique in code-gen time.
+   */
+  protected override def prepareRelation(ctx: CodegenContext): (String, 
Boolean) = {
+    if (relationTerm == null) {
+      // Inline mutable state since not many join operations in a task
+      relationTerm = ctx.addMutableState(
+        "org.apache.spark.sql.execution.joins.HashedRelation", "relation", 
forceInline = true)

Review comment:
       @cloud-fan 
   
   re https://github.com/apache/spark/pull/29277#discussion_r462449882:
   
   Do you think keeping two vars would look much better than the 
current approach of calling `prepareRelation` separately in each 
`codegenInner/codegenOuter/...`? If yes, I can make the change accordingly, 
thanks.
   
   re https://github.com/apache/spark/pull/29277#discussion_r462457807:
   
   By design, `doConsume()` generates code for processing one input row. 
`BroadcastHashJoinExec` can do codegen work in `doConsume()` with only stream 
side input, because it just broadcasts the result of executing its build side 
query plan, and generates per-row processing code for the stream side in 
`doConsume()`.
   
   However, `ShuffledHashJoinExec` cannot do codegen work in `doConsume()` with 
stream and build side input, because it needs to first read all build side 
input rows and build a hash map, before processing each row from stream side 
input. We cannot generate code in `doConsume()` with simply a pair of stream 
and build side input rows. This is similar to `SortMergeJoinExec`, which needs 
to stream only one row at a time on one side, and buffer multiple rows on the 
other side. `ShuffledHashJoinExec` has to generate code in `doProduce()`, and 
its children have to do codegen separately in their own iterator classes.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
##########
@@ -903,6 +904,10 @@ case class CollapseCodegenStages(
         // The children of SortMergeJoin should do codegen separately.
         j.withNewChildren(j.children.map(
           child => InputAdapter(insertWholeStageCodegen(child))))
+      case j: ShuffledHashJoinExec =>
+        // The children of ShuffledHashJoin should do codegen separately.

Review comment:
       @viirya - sure. What kind of wording are you expecting here? 
Does it look better with:
   
   ```
   // The children of ShuffledHashJoin should do codegen separately,
   // because codegen for ShuffledHashJoin depends on more than one row
   // from the build side input.
   ```
   
   ```
   // The children of SortMergeJoin should do codegen separately,
   // because codegen for SortMergeJoin depends on more than one row
   // from the buffer side input.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to