c21 commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r470883805



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,215 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin( // Full outer join; dispatches on hashedRelation.keyIsUnique
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator() // join-key projection for stream rows
+    val joinRow = new JoinedRow // single row object shared by all combining functions below
+    val (joinRowWithStream, joinRowWithBuild) = { // put each side on the correct half of joinRow
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length) // NULL padding, build side
+    val streamNullRow = new GenericInternalRow(streamedOutput.length) // NULL padding, stream side
+    lazy val streamNullJoinRowWithBuild = { // NOTE(review): pins joinRow's stream side to NULLs
+      buildSide match { // safe only if callees combine no stream row after first use - confirm
+        case BuildLeft =>
+          joinRow.withRight(streamNullRow)
+          joinRow.withLeft _
+        case BuildRight =>
+          joinRow.withLeft(streamNullRow)
+          joinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) { // matched-row tracking depends on key uniqueness
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output) // project joined rows onto output
+    iter.map { r =>
+      numOutputRows += 1 // counted as each row is pulled from the iterator
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Record which build-side rows were matched.
+   *    A `BitSet` indexed by key index tracks the matched rows.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Skip build-side rows that were already matched,
+   *    by checking their key index in the `BitSet`.
+   *
+   * Each stream row yields exactly one output row: joined with its build row
+   * when the key matches and `boundCondition` holds, otherwise padded with
+   * `buildNullRow`.
+   *
+   * @param streamIter rows of the streamed side
+   * @param hashedRelation build-side relation (caller selects this path when keys are unique)
+   * @param joinKeys projection from a stream row to its join key
+   * @param joinRowWithStream places a stream row into the joined row
+   * @param joinRowWithBuild places a build row into the joined row
+   * @param streamNullJoinRowWithBuild by-name; places a build row into a joined row whose
+   *        stream side is NULL. Only evaluated for the first unmatched build row.
+   * @param buildNullRow NULL padding for the build side
+   * @param streamNullRow NULL padding for the stream side (not read in this method)
+   * @return stream-side results followed by the unmatched build-side rows
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex) // bit per key index
+
+    // Process the stream side by looking up the hash relation
+    val streamResultIter = streamIter.map { srow => // exactly one output row per stream row
+      joinRowWithStream(srow) // stream half set here; build half set in each branch below
+      val keys = joinKeys(srow)
+      if (keys.anyNull) { // NULL join keys are not looked up
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) { // key matched; the match also requires boundCondition
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow) // condition failed: NULL-pad the build side
+          }
+        } else {
+          joinRowWithBuild(buildNullRow) // no build row has this key
+        }
+      }
+    }
+
+    // Process the build side, filtering out rows that were already matched
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap { // lazy, see `++` below
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow)) // first use forces the by-name argument
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter // lazy concat: all stream rows before any build row
+  }
+
+  /**
+   * Full outer shuffled hash join with non-unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Record which build-side rows were matched.
+   *    A `HashSet[Long]` is used to track matched rows, with the
+   *    key index (Int) and value index (Int) packed into one Long.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Skip build-side rows that were already matched,
+   *    by checking their key index and value index in the `HashSet`.
+   *
+   * The "value index" is defined as the index of the tuple in the chain
+   * of tuples having the same key. For example, if certain key is found 
thrice,
+   * the value indices of its tuples will be 0, 1 and 2.
+   * Note that value indices of tuples with different keys are incomparable.
+   */
+  private def fullOuterJoinWithNonUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedRows = new mutable.HashSet[Long]
+
+    def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.add(rowIndex)
+    }
+
+    def isRowMatched(keyIndex: Int, valueIndex: Int): Boolean = {
+      val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
+      matchedRows.contains(rowIndex)
+    }
+
+    // Process stream side with looking up hash relation
+    val streamResultIter = streamIter.flatMap { srow =>
+      val joinRow = joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        Iterator.single(joinRowWithBuild(buildNullRow))
+      } else {
+        val buildIter = hashedRelation.getWithKeyIndex(keys)
+        new RowIterator {
+          private var found = false
+          private var valueIndex = -1
+          override def advanceNext(): Boolean = {
+            while (buildIter != null && buildIter.hasNext) {
+              val buildRowWithKeyIndex = buildIter.next()
+              val keyIndex = buildRowWithKeyIndex.getKeyIndex
+              val buildRow = buildRowWithKeyIndex.getValue
+              valueIndex += 1
+              if (boundCondition(joinRowWithBuild(buildRow))) {
+                markRowMatched(keyIndex, valueIndex)
+                found = true
+                return true
+              }
+            }
+            // When we reach here, it means no match is found for this key.
+            // So we need to return one row with build side NULL row,
+            // to match the full outer join semantic.

Review comment:
       @agrawaldevesh - sure. This is a suggestion from @cloud-fan 
[here](https://github.com/apache/spark/pull/29342#discussion_r469823249)... 
Personally I feel this is a very minor nit, but I will change it in the next iteration anyway.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to