c21 commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r467505445



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +88,122 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRowWithBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRowWithBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRowWithBuild(buildNullRow)
+              }
+            } else {
+              joinRowWithBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRowWithBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRowWithBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRowWithBuild(buildNullRow)
+                  found = true

Review comment:
       @viirya - see above comment for explanation.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +88,122 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRowWithBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRowWithBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRowWithBuild(buildNullRow)
+              }
+            } else {
+              joinRowWithBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRowWithBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRowWithBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRowWithBuild(buildNullRow)
+                  found = true
+                  return true
+                }
+                false
+              }
+              override def getRow: InternalRow = joinRow
+            }.toScala
+          }
+        }
+      }
+
+    // Process build side with filtering out rows looked up already
+    val buildResultIter = hashedRelation.values().flatMap { brow =>
+      val unsafebrow = brow.asInstanceOf[UnsafeRow]
+      val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1)
+      if (!isLookup) {
+        val buildRow = buildRowGenerator(unsafebrow)
+        joinRowWithBuild(buildRow)
+        joinRowWithStream(streamNullRow)

Review comment:
       @viirya - given that `JoinedRow` is 
[mutable](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala#L134),
 I am afraid the optimization here would leave a correctness hole if a 
downstream iterator somehow mutates the `JoinedRow` per row elsewhere. So 
how about leaving it as it is, similar to [sort merge 
join](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala#L1104)?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +88,122 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRowWithBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRowWithBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRowWithBuild(buildNullRow)
+              }
+            } else {
+              joinRowWithBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRowWithBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRowWithBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRowWithBuild(buildNullRow)
+                  found = true
+                  return true
+                }
+                false
+              }
+              override def getRow: InternalRow = joinRow
+            }.toScala
+          }
+        }
+      }
+
+    // Process build side with filtering out rows looked up already

Review comment:
       @viirya - I think `boundCondition` should only be used to check for 
matching rows. The same logic for sort merge join is 
[here](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala#L1079-L1113).
 But do let me know if my understanding is not right.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +88,122 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRowWithBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRowWithBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRowWithBuild(buildNullRow)
+              }
+            } else {
+              joinRowWithBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRowWithStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRowWithBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {

Review comment:
       @viirya - we cannot do it as we depend on the state of `found` across 
multiple calls to `advanceNext()` here. Note: this is the same logic as 
[`HashJoin.outerJoin()`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala#L194-L212).
   
   (1) Case 1 - if `buildIter` is not empty, but none of its elements pass 
`boundCondition`: a call to `advanceNext()` will first go through the while 
loop to exhaust `buildIter`, and then rely on `found` not being set to output 
a NULL row.
   
   (2) Case 2 - if `buildIter` is not empty, and some elements pass 
`boundCondition` (suppose `buildIter` has only 1 element): the 1st call to 
`advanceNext()` will go into the `buildIter` while loop, output one matching 
row, and return true. The 2nd call to `advanceNext()` will see that `found` is 
true, output nothing, and return `false`.
   
   (3) Case 3 - if `buildIter` is empty: the 1st call to `advanceNext()` will 
output a NULL row, set `found` to true, and return true. The 2nd call to 
`advanceNext()` will return false to indicate there are no more elements.
   
   So we cannot move `found` into `advanceNext`, and we cannot skip setting 
`found` as suggested in your next comment.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to