c21 commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r470908654



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +85,215 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
+    val joinRow = new JoinedRow
+    val (joinRowWithStream, joinRowWithBuild) = {
+      buildSide match {
+        case BuildLeft => (joinRow.withRight _, joinRow.withLeft _)
+        case BuildRight => (joinRow.withLeft _, joinRow.withRight _)
+      }
+    }
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+    lazy val streamNullJoinRowWithBuild = {
+      buildSide match {
+        case BuildLeft =>
+          joinRow.withRight(streamNullRow)
+          joinRow.withLeft _
+        case BuildRight =>
+          joinRow.withLeft(streamNullRow)
+          joinRow.withRight _
+      }
+    }
+
+    val iter = if (hashedRelation.keyIsUnique) {
+      fullOuterJoinWithUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
+    } else {
+      fullOuterJoinWithNonUniqueKey(streamIter, hashedRelation, joinKeys, 
joinRowWithStream,
+        joinRowWithBuild, streamNullJoinRowWithBuild, buildNullRow, 
streamNullRow)
     }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    iter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join with unique join keys:
+   * 1. Process rows from stream side by looking up hash relation.
+   *    Mark the matched rows from build side as looked up.
+   *    A `BitSet` is used to track matched rows with key index.
+   * 2. Process rows from build side by iterating hash relation.
+   *    Filter out rows from build side being matched already,
+   *    by checking key index from `BitSet`.
+   */
+  private def fullOuterJoinWithUniqueKey(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      joinKeys: UnsafeProjection,
+      joinRowWithStream: InternalRow => JoinedRow,
+      joinRowWithBuild: InternalRow => JoinedRow,
+      streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
+      buildNullRow: GenericInternalRow,
+      streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
+    val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+
+    // Process stream side by looking up the hash relation
+    val streamResultIter = streamIter.map { srow =>
+      joinRowWithStream(srow)
+      val keys = joinKeys(srow)
+      if (keys.anyNull) {
+        joinRowWithBuild(buildNullRow)
+      } else {
+        val matched = hashedRelation.getValueWithKeyIndex(keys)
+        if (matched != null) {
+          val keyIndex = matched.getKeyIndex
+          val buildRow = matched.getValue
+          val joinRow = joinRowWithBuild(buildRow)
+          if (boundCondition(joinRow)) {
+            matchedKeys.set(keyIndex)
+            joinRow
+          } else {
+            joinRowWithBuild(buildNullRow)
+          }
+        } else {
+          joinRowWithBuild(buildNullRow)
+        }
+      }
+    }
+
+    // Process build side with filtering out the matched rows
+    val buildResultIter = hashedRelation.valuesWithKeyIndex().flatMap {
+      valueRowWithKeyIndex =>
+        val keyIndex = valueRowWithKeyIndex.getKeyIndex
+        val isMatched = matchedKeys.get(keyIndex)
+        if (!isMatched) {
+          val buildRow = valueRowWithKeyIndex.getValue
+          Some(streamNullJoinRowWithBuild(buildRow))
+        } else {
+          None
+        }
+    }
+
+    streamResultIter ++ buildResultIter
+  }
+
+  /**
+   * Full outer shuffled hash join with non-unique join keys:

Review comment:
       @cloud-fan - my bad, updated.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -66,6 +66,30 @@ private[execution] sealed trait HashedRelation extends 
KnownSizeEstimation {
     throw new UnsupportedOperationException
   }
 
+  /**
+   * Returns an iterator for key index and matched rows.
+   *
+   * Returns null if there are no matched rows.
+   */
+  def getWithKeyIndex(key: InternalRow): Iterator[ValueRowWithKeyIndex]
+
+  /**
+   * Returns key index and matched single row.
+   *
+   * Returns null if there are no matched rows.
+   */
+  def getValueWithKeyIndex(key: InternalRow): ValueRowWithKeyIndex

Review comment:
       @viirya - sure, added.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to