agrawaldevesh commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r467372338



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1188,4 +1188,42 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
         classOf[BroadcastNestedLoopJoinExec]))
     }
   }
+
+  test("SPARK-32399: Full outer shuffled hash join") {

Review comment:
       I am curious whether it might be possible to unconditionally enable this 
optimization for _all_ full outer joins, to make sure that it always works? 
A similar approach was taken in PR #29104: for example, search for 
`CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=true` in the `.sql` files, 
such as `group-by-filter.sql`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -97,7 +102,9 @@ private[execution] object HashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int = 64,
       taskMemoryManager: TaskMemoryManager = null,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,
+      isLookupAware: Boolean = false,

Review comment:
       How about `canMarkMatchedRow` or `canMarkLookedupRow`?
   
   These suggestions stem from the use of `markRowLookedUp` below.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
##########
@@ -580,4 +580,28 @@ class HashedRelationSuite extends SharedSparkSession {
       assert(proj(packedKeys).get(0, dt) == -i - 1)
     }
   }
+
+  test("SPARK-32399: test values() method for HashedRelation") {
+    val key = Seq(BoundReference(0, LongType, false))
+    val value = Seq(BoundReference(0, IntegerType, true))
+    val unsafeProj = UnsafeProjection.create(value)
+    val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy())
+
+    // test LongHashedRelation
+    val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
+    var values = longRelation.values()
+    assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 
100).map(i => i + 1))

Review comment:
       This is quite a mouthful to understand (on this line and the sorts on 
lines 598 and 604 below). Can you add some comments on what this test is trying 
to assert?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    abstract class HashJoinedRow extends JoinedRow {
+      /** Updates this JoinedRow by updating its stream side row. Returns 
itself. */
+      def withStream(newStream: InternalRow): JoinedRow
+
+      /** Updates this JoinedRow by updating its build side row. Returns 
itself. */
+      def withBuild(newBuild: InternalRow): JoinedRow
     }
+    val joinRow: HashJoinedRow = buildSide match {
+      case BuildLeft =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withRight(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withLeft(newBuild)
+        }
+      case BuildRight =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withLeft(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withRight(newBuild)
+        }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRow.withBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRow.withBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRow.withBuild(buildNullRow)
+              }
+            } else {
+              joinRow.withBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRow.withBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRow.withBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRow.withBuild(buildNullRow)
+                  found = true
+                  return true
+                }
+                false
+              }
+              override def getRow: InternalRow = joinRow
+            }.toScala
+          }
+        }
+      }
+
+    // Process build side with filtering out rows looked up already
+    val buildResultIter = hashedRelation.values().flatMap { brow =>

Review comment:
       Sure! Thanks for considering it.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to