GitHub user davies commented on a diff in the pull request:

    https://github.com/apache/spark/pull/11248#discussion_r53718141
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala ---
    @@ -125,6 +126,246 @@ case class SortMergeJoin(
           }.toScala
         }
       }
    +
    +  override def upstreams(): Seq[RDD[InternalRow]] = {
    +    left.execute() :: right.execute() :: Nil
    +  }
    +
    +  private def createJoinKey(
    +      ctx: CodegenContext,
    +      row: String,
    +      keys: Seq[Expression],
    +      input: Seq[Attribute]): Seq[ExprCode] = {
    +    ctx.INPUT_ROW = row
    +    keys.map(BindReferences.bindReference(_, input).gen(ctx))
    +  }
    +
    +  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): 
Seq[ExprCode] = {
    +    vars.zipWithIndex.map { case (ev, i) =>
    +      val value = ctx.freshName("value")
    +      ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
    +      val code =
    +        s"""
    +           |$value = ${ev.value};
    +         """.stripMargin
    +      ExprCode(code, "false", value)
    +    }
    +  }
    +
    +  private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: 
Seq[ExprCode]): String = {
    +    val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
    +      s"""
    +         |if (comp == 0) {
    +         |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
    +         |}
    +       """.stripMargin.trim
    +    }
    +    s"""
    +       |comp = 0;
    +       |${comparisons.mkString("\n")}
    +     """.stripMargin
    +  }
    +
    +  /**
    +    * Generate a function to scan both left and right to find a match, 
returns the term for
    +    * matched one row from left side and buffered rows from right side.
    +    */
    +  private def genScanner(ctx: CodegenContext): (String, String) = {
    +    // Create class member for next row from both sides.
    +    val leftRow = ctx.freshName("leftRow")
    +    ctx.addMutableState("InternalRow", leftRow, "")
    +    val rightRow = ctx.freshName("rightRow")
    +    ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
    +
    +    // Create variables for join keys from both sides.
    +    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
    +    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
    +    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, 
right.output)
    +    val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
    +    // Copy the right key as class members so they could be used in next 
function call.
    +    val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
    +
    +    // A list to hold all matched rows from right side.
    +    val matches = ctx.freshName("matches")
    +    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
    +    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
    +    // Copy the left keys as class members so they could be used in next 
function call.
    +    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
    +
    +    ctx.addNewFunction("findNextInnerJoinRows",
    +      s"""
    +         |private boolean findNextInnerJoinRows(
    +         |    scala.collection.Iterator leftIter,
    +         |    scala.collection.Iterator rightIter) {
    +         |  $leftRow = null;
    +         |  int comp = 0;
    +         |  while ($leftRow == null) {
    +         |    if (!leftIter.hasNext()) return false;
    +         |    $leftRow = (InternalRow) leftIter.next();
    +         |    ${leftKeyVars.map(_.code).mkString("\n")}
    +         |    if ($leftAnyNull) {
    +         |      $leftRow = null;
    +         |      continue;
    +         |    }
    +         |    if (!$matches.isEmpty()) {
    +         |      ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
    +         |      if (comp == 0) {
    +         |        return true;
    +         |      }
    +         |      $matches.clear();
    +         |    }
    +         |
    +         |    do {
    +         |      if ($rightRow == null) {
    +         |        if (!rightIter.hasNext()) {
    +         |          ${matchedKeyVars.map(_.code).mkString("\n")}
    +         |          return !$matches.isEmpty();
    +         |        }
    +         |        $rightRow = (InternalRow) rightIter.next();
    +         |        ${rightKeyTmpVars.map(_.code).mkString("\n")}
    +         |        if ($rightAnyNull) {
    +         |          $rightRow = null;
    +         |          continue;
    +         |        }
    +         |        ${rightKeyVars.map(_.code).mkString("\n")}
    +         |      }
    +         |      ${genComparision(ctx, leftKeyVars, rightKeyVars)}
    +         |      if (comp > 0) {
    +         |        $rightRow = null;
    +         |      } else if (comp < 0) {
    +         |        if (!$matches.isEmpty()) {
    +         |          ${matchedKeyVars.map(_.code).mkString("\n")}
    +         |          return true;
    +         |        }
    +         |        $leftRow = null;
    +         |      } else {
    +         |        $matches.add($rightRow.copy());
    +         |        $rightRow = null;;
    +         |      }
    +         |    } while ($leftRow != null);
    +         |  }
    +         |  return false; // unreachable
    +         |}
    +       """.stripMargin)
    +
    +    (leftRow, matches)
    +  }
    +
    +  /**
    +    * Creates variables for left part of result row.
    +    *
    +    * In order to defer the access after condition and also only access 
once in the loop,
    +    * the variables should be declared separately from accessing the 
columns, we can't use the
    +    * codegen of BoundReference here.
    +    */
    +  private def createLeftVars(ctx: CodegenContext, leftRow: String): 
Seq[ExprCode] = {
    +    ctx.INPUT_ROW = leftRow
    +    left.output.zipWithIndex.map { case (a, i) =>
    +      val value = ctx.freshName("value")
    +      val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
    +      // declare it as class member, so we can access the column before or 
in the loop.
    +      ctx.addMutableState(ctx.javaType(a.dataType), value, "")
    +      if (a.nullable) {
    +        val isNull = ctx.freshName("isNull")
    +        ctx.addMutableState("boolean", isNull, "")
    +        val code =
    +          s"""
    +             |$isNull = $leftRow.isNullAt($i);
    +             |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : 
($valueCode);
    +           """.stripMargin
    +        ExprCode(code, isNull, value)
    +      } else {
    +        ExprCode(s"$value = $valueCode;", "false", value)
    +      }
    +    }
    +  }
    +
    +  /**
    +    * Creates the variables for right part of result row, using 
BoundReference, since the right
    +    * part are accessed inside the loop.
    +    */
    +  private def createRightVar(ctx: CodegenContext, rightRow: String): 
Seq[ExprCode] = {
    +    ctx.INPUT_ROW = rightRow
    +    right.output.zipWithIndex.map { case (a, i) =>
    +      BoundReference(i, a.dataType, a.nullable).gen(ctx)
    +    }
    +  }
    +
    +  /**
    +    * Splits variables based on whether it's used by condition or not, 
returns the code to create
    +    * these variables before the condition and after the condition.
    +    *
    +    * Only a few columns are used by condition, then we can skip the 
accessing of those columns
    +    * that are not used by condition also filtered out by condition.
    +    */
    +  private def splitVarsByCondition(
    +      attributes: Seq[Attribute],
    +      variables: Seq[ExprCode]): (String, String) = {
    +    if (condition.isDefined) {
    +      val condRefs = condition.get.references
    +      val (used, notUsed) = attributes.zip(variables).partition{ case (a, 
ev) =>
    +        condRefs.contains(a)
    +      }
    +      val beforeCond = used.map(_._2.code).mkString("\n")
    +      val afterCond = notUsed.map(_._2.code).mkString("\n")
    +      (beforeCond, afterCond)
    +    } else {
    +      (variables.map(_.code).mkString("\n"), "")
    +    }
    +  }
    +
    +  override def doProduce(ctx: CodegenContext): String = {
    +    val leftInput = ctx.freshName("leftInput")
    +    ctx.addMutableState("scala.collection.Iterator", leftInput, 
s"$leftInput = inputs[0];")
    +    val rightInput = ctx.freshName("rightInput")
    +    ctx.addMutableState("scala.collection.Iterator", rightInput, 
s"$rightInput = inputs[1];")
    +
    +    val (leftRow, matches) = genScanner(ctx)
    +
    +    // Create variables for row from both sides.
    +    val leftVars = createLeftVars(ctx, leftRow)
    +    val rightRow = ctx.freshName("rightRow")
    +    val rightVars = createRightVar(ctx, rightRow)
    +    val resultVars = leftVars ++ rightVars
    +
    +    // Check condition
    +    ctx.currentVars = resultVars
    +    val cond = if (condition.isDefined) {
    +      BindReferences.bindReference(condition.get, output).gen(ctx)
    +    } else {
    +      ExprCode("", "false", "true")
    +    }
    +    // Split the code of creating variables based on whether it's used by 
condition or not.
    +    val loaded = ctx.freshName("loaded")
    +    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, 
leftVars)
    +    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, 
rightVars)
    +
    +
    +    val size = ctx.freshName("size")
    +    val i = ctx.freshName("i")
    +    val numOutput = metricTerm(ctx, "numOutputRows")
    +    s"""
    +       |while (findNextInnerJoinRows($leftInput, $rightInput)) {
    +       |  int $size = $matches.size();
    +       |  boolean $loaded = false;
    +       |  $leftBefore
    +       |  for (int $i = 0; $i < $size; $i ++) {
    +       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
    +       |    $rightBefore
    +       |    ${cond.code}
    +       |    if (${cond.isNull} || !${cond.value}) continue;
    +       |    if (!$loaded) {
    +       |      $loaded = true;
    +       |      $leftAfter
    +       |    }
    +       |    $rightAfter
    +       |    $numOutput.add(1);
    --- End diff --
    
    See https://github.com/apache/spark/pull/11170.
    
    It (the `$numOutput.add(1)` metric update on the commented line) costs roughly 0.1–0.2 nanoseconds per row.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to