c21 commented on a change in pull request #29277:
URL: https://github.com/apache/spark/pull/29277#discussion_r461912796
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -70,4 +74,69 @@ case class ShuffledHashJoinExec(
join(streamIter, hashed, numOutputRows)
}
}
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ streamedPlan.execute() :: buildPlan.execute() :: Nil
+ }
+
+ override def needCopyResult: Boolean = true
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ // inline mutable state since not many join operations in a task
+ val streamedInput = ctx.addMutableState(
+ "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];",
forceInline = true)
+ val buildInput = ctx.addMutableState(
+ "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];",
forceInline = true)
+ val initRelation = ctx.addMutableState(
+ CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;",
forceInline = true)
+ val streamedRow = ctx.addMutableState(
+ "InternalRow", "streamedRow", forceInline = true)
+
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ val (relationTerm, _) = prepareRelation(ctx)
+ val buildRelation = s"$relationTerm =
$thisPlan.buildHashedRelation($buildInput);"
+ val (streamInputVar, streamInputVarDecl) = createVars(ctx, streamedRow,
streamedPlan.output)
+
+ val join = joinType match {
+ case _: InnerLike => codegenInner(ctx, streamInputVar)
+ case LeftOuter | RightOuter => codegenOuter(ctx, streamInputVar)
+ case LeftSemi => codegenSemi(ctx, streamInputVar)
+ case LeftAnti => codegenAnti(ctx, streamInputVar)
+ case _: ExistenceJoin => codegenExistence(ctx, streamInputVar)
+ case x =>
+ throw new IllegalArgumentException(
+ s"ShuffledHashJoin should not take $x as the JoinType")
+ }
+
+ s"""
+ |// construct hash map for shuffled hash join build side
+ |if (!$initRelation) {
+ | $buildRelation
+ | $initRelation = true;
+ |}
+ |
+ |while ($streamedInput.hasNext()) {
+ | $streamedRow = (InternalRow) $streamedInput.next();
+ | ${streamInputVarDecl.mkString("\n")}
+ | $join
+ |
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Returns a tuple of variable name for HashedRelation,
+ * and boolean false to indicate key not to be known unique in code-gen time.
+ */
+ protected override def prepareRelation(ctx: CodegenContext): (String,
Boolean) = {
+ if (relationTerm == null) {
+ // Inline mutable state since not many join operations in a task
+ relationTerm = ctx.addMutableState(
+ "org.apache.spark.sql.execution.joins.HashedRelation", "relation",
forceInline = true)
Review comment:
@viirya - `BroadcastHashJoinExec` needs to [broadcast build
side](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala#L222-L223)
in `prepareBroadcast()`. I feel it's hard to refactor there. Do you
have any ideas on how to make it cleaner?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -40,4 +43,43 @@ trait ShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(
s"ShuffledJoin should not take $x as the JoinType")
}
+
+ /**
+ * Creates variables and declarations for attributes in row.
+ *
+ * In order to defer the access after condition and also only access once in
the loop,
+ * the variables should be declared separately from accessing the columns,
we can't use the
+ * codegen of BoundReference here.
+ */
+ protected def createVars(
+ ctx: CodegenContext,
+ row: String,
+ attributes: Seq[Attribute]): (Seq[ExprCode], Seq[String]) = {
+ ctx.INPUT_ROW = row
+ attributes.zipWithIndex.map { case (a, i) =>
+ val value = ctx.freshName("value")
+ val valueCode = CodeGenerator.getValue(row, a.dataType, i.toString)
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val defaultValue = CodeGenerator.defaultValue(a.dataType)
+ if (a.nullable) {
+ val isNull = ctx.freshName("isNull")
+ val code =
+ code"""
+ |$isNull = $row.isNullAt($i);
+ |$value = $isNull ? $defaultValue : ($valueCode);
+ """.stripMargin
+ val leftVarsDecl =
Review comment:
@viirya - sorry for missing this, done.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -40,4 +43,43 @@ trait ShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(
s"ShuffledJoin should not take $x as the JoinType")
}
+
+ /**
+ * Creates variables and declarations for attributes in row.
+ *
+ * In order to defer the access after condition and also only access once in
the loop,
+ * the variables should be declared separately from accessing the columns,
we can't use the
+ * codegen of BoundReference here.
+ */
+ protected def createVars(
Review comment:
@viirya - I was actually originally using `BoundReference`, but got a
compilation error due to variable redefinition. E.g., see code branch
`c21:codegen-fail` ([with the change to `BoundReference` compared to this PR](
https://github.com/c21/spark/compare/codegen...codegen-fail)), and [example
query in
`JoinBenchmark`](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala#L153-L167).
The `BoundReference` version generated code like this:
```
while (shj_streamedInput_0.hasNext()) {
shj_streamedRow_0 = (InternalRow) shj_streamedInput_0.next();
// generate join key for stream side
long shj_value_0 = shj_streamedRow_0.getLong(0); // 1st definition here
// find matches from HashRelation
scala.collection.Iterator shj_matches_0 = false ? null :
(scala.collection.Iterator)shj_relation_0.get(shj_value_0);
while (shj_matches_0.hasNext()) {
UnsafeRow shj_matched_0 = (UnsafeRow) shj_matches_0.next();
long shj_value_0 = shj_streamedRow_0.getLong(0); // 2nd definition here
// and compilation error
shj_mutableStateArray_0[0].write(0, shj_value_0);
}
}
```
So basically the variable `shj_value_0` here (the stream-side key) needs to have
its access deferred to
[`HashJoin.consume()`](https://github.com/c21/spark/blob/codegen/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala#L441),
and it was first accessed in
[`HashJoin.genStreamSideJoinKey()`](https://github.com/c21/spark/blob/codegen/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala#L326).
So it seems that `BoundReference` does not work for me out of the box.
Let me know if this makes sense or if there is a better approach,
thanks.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
##########
@@ -233,410 +231,58 @@ case class BroadcastHashJoinExec(
}
/**
- * Returns the code for generating join key for stream side, and expression
of whether the key
- * has any null in it or not.
+ * Returns a tuple of variable name for broadcast HashedRelation,
+ * and a boolean to indicate whether keys of HashedRelation to be unique.
*/
- private def genStreamSideJoinKey(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (ExprCode, String) = {
- ctx.currentVars = input
- if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
- // generate the join key as Long
- val ev = streamedBoundKeys.head.genCode(ctx)
- (ev, ev.isNull)
- } else {
- // generate the join key as UnsafeRow
- val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
- (ev, s"${ev.value}.anyNull()")
- }
- }
-
- /**
- * Generates the code for variable of build side.
- */
- private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
- ctx.currentVars = null
- ctx.INPUT_ROW = matched
- buildPlan.output.zipWithIndex.map { case (a, i) =>
- val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- if (joinType.isInstanceOf[InnerLike]) {
- ev
- } else {
- // the variables are needed even there is no matched rows
- val isNull = ctx.freshName("isNull")
- val value = ctx.freshName("value")
- val javaType = CodeGenerator.javaType(a.dataType)
- val code = code"""
- |boolean $isNull = true;
- |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
- |if ($matched != null) {
- | ${ev.code}
- | $isNull = ${ev.isNull};
- | $value = ${ev.value};
- |}
- """.stripMargin
- ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
- }
- }
- }
-
- /**
- * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
- * and Left Anti joins.
- */
- private def getJoinCondition(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
- val skipRow = s"${ev.isNull} || !${ev.value}"
- s"""
- |$eval
- |${ev.code}
- |if (!($skipRow))
- """.stripMargin
- } else {
- ""
- }
- (matched, checkCondition, buildVars)
- }
-
- /**
- * Generates the code for Inner join.
- */
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched != null) {
- | $checkCondition {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- |}
- """.stripMargin
-
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
- |if ($matches != null) {
- | while ($matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- | }
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left or right outer join.
- */
- private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- // filter the output via condition
- val conditionPassed = ctx.freshName("conditionPassed")
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
- s"""
- |boolean $conditionPassed = true;
- |${eval.trim}
- |if ($matched != null) {
- | ${ev.code}
- | $conditionPassed = !${ev.isNull} && ${ev.value};
- |}
- """.stripMargin
- } else {
- s"final boolean $conditionPassed = true;"
- }
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |${checkCondition.trim}
- |if (!$conditionPassed) {
- | $matched = null;
- | // reset the variables those are already evaluated.
- | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
- |}
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
-
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
- |boolean $found = false;
- |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
- |while ($matches != null && $matches.hasNext() || !$found) {
- | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
- | (UnsafeRow) $matches.next() : null;
- | ${checkCondition.trim}
- | if ($conditionPassed) {
- | $found = true;
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left semi join.
- */
- private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String =
{
+ protected override def prepareRelation(ctx: CodegenContext): (String,
Boolean) = {
Review comment:
A new method `prepareRelation` is added to call `prepareBroadcast()` and
determine whether the key is known to be unique at codegen time.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
Review comment:
All changes in `ShuffledHashJoinExec` here are real changes, not
refactoring.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
##########
@@ -903,6 +904,10 @@ case class CollapseCodegenStages(
// The children of SortMergeJoin should do codegen separately.
j.withNewChildren(j.children.map(
child => InputAdapter(insertWholeStageCodegen(child))))
+ case j: ShuffledHashJoinExec =>
Review comment:
Codegen the children of `ShuffledHashJoinExec` separately, the same as for
`SortMergeJoinExec`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
Review comment:
Added `keyIsKnownUnique` to support the unique-key code path for
`BroadcastHashJoinExec`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | if ($found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched =
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches =
($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for existence join.
+ */
+ protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
Review comment:
Changed to `prepareRelation` to get `keyIsKnownUnique`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
Review comment:
Added `keyIsKnownUnique` to support the unique-key code path for
`BroadcastHashJoinExec`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
##########
@@ -233,410 +231,58 @@ case class BroadcastHashJoinExec(
}
/**
- * Returns the code for generating join key for stream side, and expression
of whether the key
- * has any null in it or not.
+ * Returns a tuple of variable name for broadcast HashedRelation,
+ * and a boolean to indicate whether keys of HashedRelation to be unique.
*/
- private def genStreamSideJoinKey(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (ExprCode, String) = {
- ctx.currentVars = input
- if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
- // generate the join key as Long
- val ev = streamedBoundKeys.head.genCode(ctx)
- (ev, ev.isNull)
- } else {
- // generate the join key as UnsafeRow
- val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
- (ev, s"${ev.value}.anyNull()")
- }
- }
-
- /**
- * Generates the code for variable of build side.
- */
- private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
- ctx.currentVars = null
- ctx.INPUT_ROW = matched
- buildPlan.output.zipWithIndex.map { case (a, i) =>
- val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- if (joinType.isInstanceOf[InnerLike]) {
- ev
- } else {
- // the variables are needed even there is no matched rows
- val isNull = ctx.freshName("isNull")
- val value = ctx.freshName("value")
- val javaType = CodeGenerator.javaType(a.dataType)
- val code = code"""
- |boolean $isNull = true;
- |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
- |if ($matched != null) {
- | ${ev.code}
- | $isNull = ${ev.isNull};
- | $value = ${ev.value};
- |}
- """.stripMargin
- ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
- }
- }
- }
-
- /**
- * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
- * and Left Anti joins.
- */
- private def getJoinCondition(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
- val skipRow = s"${ev.isNull} || !${ev.value}"
- s"""
- |$eval
- |${ev.code}
- |if (!($skipRow))
- """.stripMargin
- } else {
- ""
- }
- (matched, checkCondition, buildVars)
- }
-
- /**
- * Generates the code for Inner join.
- */
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched != null) {
- | $checkCondition {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- |}
- """.stripMargin
-
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
- |if ($matches != null) {
- | while ($matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- | }
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left or right outer join.
- */
- private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- // filter the output via condition
- val conditionPassed = ctx.freshName("conditionPassed")
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
- s"""
- |boolean $conditionPassed = true;
- |${eval.trim}
- |if ($matched != null) {
- | ${ev.code}
- | $conditionPassed = !${ev.isNull} && ${ev.value};
- |}
- """.stripMargin
- } else {
- s"final boolean $conditionPassed = true;"
- }
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |${checkCondition.trim}
- |if (!$conditionPassed) {
- | $matched = null;
- | // reset the variables those are already evaluated.
- | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
- |}
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
-
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
- |boolean $found = false;
- |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
- |while ($matches != null && $matches.hasNext() || !$found) {
- | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
- | (UnsafeRow) $matches.next() : null;
- | ${checkCondition.trim}
- | if ($conditionPassed) {
- | $found = true;
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left semi join.
- */
- private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String =
{
+ protected override def prepareRelation(ctx: CodegenContext): (String,
Boolean) = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, _) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched != null) {
- | $checkCondition {
- | $numOutput.add(1);
- | ${consume(ctx, input)}
- | }
- |}
- """.stripMargin
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
- |if ($matches != null) {
- | boolean $found = false;
- | while (!$found && $matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition {
- | $found = true;
- | }
- | }
- | if ($found) {
- | $numOutput.add(1);
- | ${consume(ctx, input)}
- | }
- |}
- """.stripMargin
- }
+ (relationTerm, broadcastRelation.value.keyIsUnique)
}
/**
* Generates the code for anti join.
+ * Handles NULL-aware anti join (NAAJ) separately here.
*/
- private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String =
{
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, _) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
+ protected override def codegenAnti(ctx: CodegenContext, input:
Seq[ExprCode]): String = {
Review comment:
`codegenAnti` is changed to handle the NULL-aware anti join separately here;
the other logic is moved to `HashJoin.codegenAnti`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
Review comment:
this method is copied from `BroadcastHashJoinExec` without change.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
Review comment:
Changed to use `prepareRelation` (instead of `prepareBroadcast`) to get `keyIsKnownUnique`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
##########
@@ -40,4 +43,43 @@ trait ShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(
s"ShuffledJoin should not take $x as the JoinType")
}
+
+ /**
+ * Creates variables and declarations for attributes in row.
+ *
+ * In order to defer the access after condition and also only access once in
the loop,
+ * the variables should be declared separately from accessing the columns,
we can't use the
+ * codegen of BoundReference here.
+ */
+ protected def createVars(
Review comment:
`createVars` is copied from `SortMergeJoinExec.createLeftVars()` so that both
`SortMergeJoinExec` and `ShuffledHashJoinExec` can use it to generate code
for the stream-side input.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
Review comment:
this method is copied from `BroadcastHashJoinExec` without change.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | if ($found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
Review comment:
Changed to use `prepareRelation` (instead of `prepareBroadcast`) to get `keyIsKnownUnique`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
Review comment:
this method is copied from `BroadcastHashJoinExec` without change.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
Review comment:
Changed to use `prepareRelation` (instead of `prepareBroadcast`) to get `keyIsKnownUnique`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating the join key for the stream side, and an expression of
+ * whether the key has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for the variables of the build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even if there are no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generates the (non-equi) condition used to filter joined rows. This is used in Inner,
+ * Left Semi and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
Review comment:
Changed to call `prepareRelation` in order to get `keyIsKnownUnique`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating the join key for the stream side, and an expression of
+ * whether the key has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for the variables of the build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even if there are no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generates the (non-equi) condition used to filter joined rows. This is used in Inner,
+ * Left Semi and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | if ($found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched =
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches =
($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for existence join.
+ */
+ protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val existsVar = ctx.freshName("exists")
+
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |$eval
+ |${ev.code}
+ |$existsVar = !${ev.isNull} && ${ev.value};
+ """.stripMargin
+ } else {
+ s"$existsVar = true;"
+ }
+
+ val resultVar = input ++ Seq(ExprCode.forNonNullValue(
+ JavaCode.variable(existsVar, BooleanType)))
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |boolean $existsVar = false;
+ |if ($matched != null) {
+ | $checkCondition
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVar)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $existsVar = false;
+ |if ($matches != null) {
+ | while (!$existsVar && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVar)}
+ """.stripMargin
+ }
+ }
+
+ protected def prepareRelation(ctx: CodegenContext): (String, Boolean)
Review comment:
Added an abstract method `prepareRelation`, which is implemented
separately in SHJ and BHJ.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
Review comment:
@cloud-fan - sorry about that. Yes, most of this is moving code around
without change. I have highlighted the actual changes with comments, thanks.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating the join key for the stream side, and an expression of
+ * whether the key has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for the variables of the build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even if there are no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generates the (non-equi) condition used to filter joined rows. This is used in Inner,
+ * Left Semi and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | if ($found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched =
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |boolean $found = false;
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches =
($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | }
+ |}
+ |if (!$found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for existence join.
+ */
+ protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val existsVar = ctx.freshName("exists")
+
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |$eval
+ |${ev.code}
+ |$existsVar = !${ev.isNull} && ${ev.value};
+ """.stripMargin
+ } else {
+ s"$existsVar = true;"
+ }
+
+ val resultVar = input ++ Seq(ExprCode.forNonNullValue(
+ JavaCode.variable(existsVar, BooleanType)))
+
+ if (keyIsKnownUnique) {
Review comment:
Added `keyIsKnownUnique` to support the unique-key code path for
`BroadcastHashJoinExec`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
Review comment:
Added `keyIsKnownUnique` to support the unique-key code path for
`BroadcastHashJoinExec`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
##########
@@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec {
resultProj(r)
}
}
+
+ /**
+ * Returns the code for generating join key for stream side, and expression
of whether the key
+ * has any null in it or not.
+ */
+ protected def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType ==
LongType) {
+ // generate the join key as Long
+ val ev = streamedBoundKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String):
Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType.isInstanceOf[InnerLike]) {
+ ev
+ } else {
+ // the variables are needed even there is no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val javaType = CodeGenerator.javaType(a.dataType)
+ val code = code"""
+ |boolean $isNull = true;
+ |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, JavaCode.isNullVariable(isNull),
JavaCode.variable(value, a.dataType))
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is
used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ protected def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ val skipRow = s"${ev.isNull} || !${ev.value}"
+ s"""
+ |$eval
+ |${ev.code}
+ |if (!($skipRow))
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ?
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]):
String = {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars,
expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++
buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |if ($matched != null) {
+ | ${ev.code}
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code.isEmpty).map(v => s"${v.isNull} =
true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there
is no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if ($conditionPassed) {
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null:
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched != null) {
+ | $checkCondition {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null :
($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches != null) {
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition {
+ | $found = true;
+ | }
+ | }
+ | if ($found) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String
= {
+ val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (keyIsKnownUnique) {
Review comment:
Added `keyIsKnownUnique` to support the unique-key code path for
`BroadcastHashJoinExec`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]