viirya commented on a change in pull request #29277:
URL: https://github.com/apache/spark/pull/29277#discussion_r461864312
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -70,4 +74,69 @@ case class ShuffledHashJoinExec(
join(streamIter, hashed, numOutputRows)
}
}
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ streamedPlan.execute() :: buildPlan.execute() :: Nil
+ }
+
+ override def needCopyResult: Boolean = true
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ // inline mutable state since not many join operations in a task
+ val streamedInput = ctx.addMutableState(
+ "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];",
forceInline = true)
+ val buildInput = ctx.addMutableState(
+ "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];",
forceInline = true)
+ val initRelation = ctx.addMutableState(
+ CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;",
forceInline = true)
+ val streamedRow = ctx.addMutableState(
+ "InternalRow", "streamedRow", forceInline = true)
+
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ val (relationTerm, _) = prepareRelation(ctx)
+ val buildRelation = s"$relationTerm =
$thisPlan.buildHashedRelation($buildInput);"
+ val (streamInputVar, streamInputVarDecl) = createVars(ctx, streamedRow,
streamedPlan.output)
+
+ val join = joinType match {
+ case _: InnerLike => codegenInner(ctx, streamInputVar)
+ case LeftOuter | RightOuter => codegenOuter(ctx, streamInputVar)
+ case LeftSemi => codegenSemi(ctx, streamInputVar)
+ case LeftAnti => codegenAnti(ctx, streamInputVar)
+ case _: ExistenceJoin => codegenExistence(ctx, streamInputVar)
+ case x =>
+ throw new IllegalArgumentException(
+ s"ShuffledHashJoin should not take $x as the JoinType")
+ }
+
+ s"""
+ |// construct hash map for shuffled hash join build side
+ |if (!$initRelation) {
+ | $buildRelation
+ | $initRelation = true;
+ |}
+ |
+ |while ($streamedInput.hasNext()) {
+ | $streamedRow = (InternalRow) $streamedInput.next();
+ | ${streamInputVarDecl.mkString("\n")}
+ | $join
+ |
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Returns a tuple of variable name for HashedRelation,
+ * and boolean false to indicate key not to be known unique in code-gen time.
+ */
+ protected override def prepareRelation(ctx: CodegenContext): (String,
Boolean) = {
+ if (relationTerm == null) {
+ // Inline mutable state since not many join operations in a task
+ relationTerm = ctx.addMutableState(
+ "org.apache.spark.sql.execution.joins.HashedRelation", "relation",
forceInline = true)
Review comment:
Since you already use mutable state for the hashed relation here, why not
just follow BroadcastHashJoinExec: call buildHashedRelation in
prepareRelation and store the result in that mutable state? Then
BroadcastHashJoinExec and ShuffledHashJoinExec would look more consistent.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -40,4 +43,43 @@ trait ShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(
s"ShuffledJoin should not take $x as the JoinType")
}
+
+ /**
+ * Creates variables and declarations for attributes in row.
+ *
+ * In order to defer the access after condition and also only access once in
the loop,
+ * the variables should be declared separately from accessing the columns,
we can't use the
+ * codegen of BoundReference here.
+ */
+ protected def createVars(
+ ctx: CodegenContext,
+ row: String,
+ attributes: Seq[Attribute]): (Seq[ExprCode], Seq[String]) = {
+ ctx.INPUT_ROW = row
+ attributes.zipWithIndex.map { case (a, i) =>
+ val value = ctx.freshName("value")
+ val valueCode = CodeGenerator.getValue(row, a.dataType, i.toString)
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val defaultValue = CodeGenerator.defaultValue(a.dataType)
+ if (a.nullable) {
+ val isNull = ctx.freshName("isNull")
+ val code =
+ code"""
+ |$isNull = $row.isNullAt($i);
+ |$value = $isNull ? $defaultValue : ($valueCode);
+ """.stripMargin
+ val leftVarsDecl =
Review comment:
Since you removed the "left" concept, we had better clean up these leftXXX
variables (e.g. `leftVarsDecl`) too.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -40,4 +43,43 @@ trait ShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(
s"ShuffledJoin should not take $x as the JoinType")
}
+
+ /**
+ * Creates variables and declarations for attributes in row.
+ *
+ * In order to defer the access after condition and also only access once in
the loop,
+ * the variables should be declared separately from accessing the columns,
we can't use the
+ * codegen of BoundReference here.
+ */
+ protected def createVars(
Review comment:
The original `createLeftVars` was created to defer accessing row fields
until after the join condition is evaluated. But looking at how this `createVars`
is used in `HashJoin`, I don't see any such deferral. If you don't defer the access
there, you can simply use the `BoundReference` codegen, which is much simpler.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]