agrawaldevesh commented on a change in pull request #29342:
URL: https://github.com/apache/spark/pull/29342#discussion_r466533520



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -314,7 +338,9 @@ private[joins] object UnsafeHashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,
+      isLookupAware: Boolean = false,
+      value: Option[Seq[Expression]] = None): HashedRelation = {

Review comment:
       Would `isLookedUp` be a better name for `value`? Is there a way to 
force the Expression type to be Boolean?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    abstract class HashJoinedRow extends JoinedRow {
+      /** Updates this JoinedRow by updating its stream side row. Returns 
itself. */
+      def withStream(newStream: InternalRow): JoinedRow
+
+      /** Updates this JoinedRow by updating its build side row. Returns 
itself. */
+      def withBuild(newBuild: InternalRow): JoinedRow
     }
+    val joinRow: HashJoinedRow = buildSide match {
+      case BuildLeft =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withRight(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withLeft(newBuild)
+        }
+      case BuildRight =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withLeft(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withRight(newBuild)
+        }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRow.withBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRow.withBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])

Review comment:
       As my overall review comment states, I believe this marking does not 
have to be stored in the hash table. Have you considered using a position list 
or a bitset (compressed or not) for this? Marked rows are only considered 
within this full outer join RDD closure, so the space for them does not have to 
be allocated inside the hash table.
   
   (If I am understanding correctly.)

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -97,7 +102,9 @@ private[execution] object HashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int = 64,
       taskMemoryManager: TaskMemoryManager = null,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,
+      isLookupAware: Boolean = false,
+      value: Option[Seq[Expression]] = None): HashedRelation = {

Review comment:
       Please add some documentation for 'value' too.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)

Review comment:
       Should we have a config to disable this feature?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -314,7 +338,9 @@ private[joins] object UnsafeHashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int,
       taskMemoryManager: TaskMemoryManager,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,

Review comment:
       If isLookupAware is not implemented to work with isNullAware, then we 
should assert that the two are never enabled at the same time.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    abstract class HashJoinedRow extends JoinedRow {
+      /** Updates this JoinedRow by updating its stream side row. Returns 
itself. */
+      def withStream(newStream: InternalRow): JoinedRow
+
+      /** Updates this JoinedRow by updating its build side row. Returns 
itself. */
+      def withBuild(newBuild: InternalRow): JoinedRow
     }
+    val joinRow: HashJoinedRow = buildSide match {
+      case BuildLeft =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withRight(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withLeft(newBuild)
+        }
+      case BuildRight =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withLeft(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withRight(newBuild)
+        }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRow.withBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRow.withBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRow.withBuild(buildNullRow)
+              }
+            } else {
+              joinRow.withBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRow.withBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRow.withBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRow.withBuild(buildNullRow)
+                  found = true
+                  return true
+                }
+                false
+              }
+              override def getRow: InternalRow = joinRow
+            }.toScala
+          }
+        }
+      }
+
+    // Process build side with filtering out rows looked up already
+    val buildResultIter = hashedRelation.values().flatMap { brow =>
+      val unsafebrow = brow.asInstanceOf[UnsafeRow]
+      val isLookup = unsafebrow.getBoolean(unsafebrow.numFields() - 1)
+      if (!isLookup) {
+        val buildRow = buildRowGenerator(unsafebrow)
+        joinRow.withBuild(buildRow)
+        joinRow.withStream(streamNullRow)
+        Some(joinRow)
+      } else {
+        None
+      }
+    }
+
+    val resultProj = UnsafeProjection.create(output, output)
+    (streamResultIter ++ buildResultIter).map { r =>
+      numOutputRows += 1
+      resultProj(r)
+    }
+  }
+
+  // TODO: support full outer shuffled hash join code-gen

Review comment:
       nit: Link this TODO to a Spark JIRA, if a JIRA to support codegen exists.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -885,6 +936,12 @@ class LongHashedRelation(
    * Returns an iterator for keys of InternalRow type.
    */
   override def keys(): Iterator[InternalRow] = map.keys()
+
+  override def values(): Iterator[InternalRow] = {

Review comment:
       I am a bit confused about the meaning of 'value': is it the "value" of 
the hash table (if we think of a hash table as a dictionary or a key-value 
store), or is it the "extra bit at the end to know if the row was matched or 
not"?
   
   I believe the use of the word 'value' is a bit overloaded in this PR. 
Personally, I think it would be great if you didn't use the word 'value' for the 
second concept of "extra bit at the end to know if the row was matched or not".

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -97,7 +102,9 @@ private[execution] object HashedRelation {
       key: Seq[Expression],
       sizeEstimate: Int = 64,
       taskMemoryManager: TaskMemoryManager = null,
-      isNullAware: Boolean = false): HashedRelation = {
+      isNullAware: Boolean = false,
+      isLookupAware: Boolean = false,

Review comment:
       I feel that we need more documentation around `isLookupAware`. 
I find the name ambiguous because hash tables are typically used for 
"lookup". So what does it mean to be lookup-aware (or not) here?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    abstract class HashJoinedRow extends JoinedRow {
+      /** Updates this JoinedRow by updating its stream side row. Returns 
itself. */
+      def withStream(newStream: InternalRow): JoinedRow
+
+      /** Updates this JoinedRow by updating its build side row. Returns 
itself. */
+      def withBuild(newBuild: InternalRow): JoinedRow
     }
+    val joinRow: HashJoinedRow = buildSide match {
+      case BuildLeft =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withRight(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withLeft(newBuild)
+        }
+      case BuildRight =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withLeft(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withRight(newBuild)
+        }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =

Review comment:
       Now see, `markRowLookedUp` is a better name than `value`.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -71,8 +89,134 @@ case class ShuffledHashJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
       val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
+      joinType match {
+        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
+        case _ => join(streamIter, hashed, numOutputRows)
+      }
+    }
+  }
+
+  /**
+   * Full outer shuffled hash join has three steps:
+   * 1. Construct hash relation from build side,
+   *    with extra boolean value at the end of row to track look up information
+   *    (done in `buildHashedRelation`).
+   * 2. Process rows from stream side by looking up hash relation,
+   *    and mark the matched rows from build side be looked up.
+   * 3. Process rows from build side by iterating hash relation,
+   *    and filter out rows from build side being looked up already.
+   */
+  private def fullOuterJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation,
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
+    abstract class HashJoinedRow extends JoinedRow {
+      /** Updates this JoinedRow by updating its stream side row. Returns 
itself. */
+      def withStream(newStream: InternalRow): JoinedRow
+
+      /** Updates this JoinedRow by updating its build side row. Returns 
itself. */
+      def withBuild(newBuild: InternalRow): JoinedRow
     }
+    val joinRow: HashJoinedRow = buildSide match {
+      case BuildLeft =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withRight(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withLeft(newBuild)
+        }
+      case BuildRight =>
+        new HashJoinedRow {
+          override def withStream(newStream: InternalRow): JoinedRow = 
withLeft(newStream)
+          override def withBuild(newBuild: InternalRow): JoinedRow = 
withRight(newBuild)
+        }
+    }
+    val joinKeys = streamSideKeyGenerator()
+    val buildRowGenerator = UnsafeProjection.create(buildOutput, buildOutput)
+    val buildNullRow = new GenericInternalRow(buildOutput.length)
+    val streamNullRow = new GenericInternalRow(streamedOutput.length)
+
+    def markRowLookedUp(row: UnsafeRow): Unit =
+      row.setBoolean(row.numFields() - 1, true)
+
+    // Process stream side with looking up hash relation
+    val streamResultIter =
+      if (hashedRelation.keyIsUnique) {
+        streamIter.map { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            joinRow.withBuild(buildNullRow)
+          } else {
+            val matched = hashedRelation.getValue(keys)
+            if (matched != null) {
+              val buildRow = buildRowGenerator(matched)
+              if (boundCondition(joinRow.withBuild(buildRow))) {
+                markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                joinRow
+              } else {
+                joinRow.withBuild(buildNullRow)
+              }
+            } else {
+              joinRow.withBuild(buildNullRow)
+            }
+          }
+        }
+      } else {
+        streamIter.flatMap { srow =>
+          joinRow.withStream(srow)
+          val keys = joinKeys(srow)
+          if (keys.anyNull) {
+            Iterator.single(joinRow.withBuild(buildNullRow))
+          } else {
+            val buildIter = hashedRelation.get(keys)
+            new RowIterator {
+              private var found = false
+              override def advanceNext(): Boolean = {
+                while (buildIter != null && buildIter.hasNext) {
+                  val matched = buildIter.next()
+                  val buildRow = buildRowGenerator(matched)
+                  if (boundCondition(joinRow.withBuild(buildRow))) {
+                    markRowLookedUp(matched.asInstanceOf[UnsafeRow])
+                    found = true
+                    return true
+                  }
+                }
+                if (!found) {
+                  joinRow.withBuild(buildNullRow)
+                  found = true
+                  return true
+                }
+                false
+              }
+              override def getRow: InternalRow = joinRow
+            }.toScala
+          }
+        }
+      }
+
+    // Process build side with filtering out rows looked up already
+    val buildResultIter = hashedRelation.values().flatMap { brow =>

Review comment:
       nit: brow -> buildRow?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -327,23 +353,48 @@ private[joins] object UnsafeHashedRelation {
     // Create a mapping of buildKeys -> rows
     val keyGenerator = UnsafeProjection.create(key)
     var numFields = 0
-    while (input.hasNext) {
-      val row = input.next().asInstanceOf[UnsafeRow]
-      numFields = row.numFields()
-      val key = keyGenerator(row)
-      if (!key.anyNull) {
+
+    if (isLookupAware) {
+      // Add one extra boolean value at the end as part of the row,
+      // to track the information that whether the corresponding key
+      // has been looked up or not. See `ShuffledHashJoin.fullOuterJoin` for 
example of usage.
+      val valueGenerator = UnsafeProjection.create(value.get :+ Literal(false))
+
+      while (input.hasNext) {
+        val row = input.next().asInstanceOf[UnsafeRow]
+        numFields = row.numFields() + 1
+        val key = keyGenerator(row)
+        val value = valueGenerator(row)
         val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, 
key.getSizeInBytes)
         val success = loc.append(
           key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
-          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+          value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
         if (!success) {
           binaryMap.free()
           // scalastyle:off throwerror
           throw new SparkOutOfMemoryError("There is not enough memory to build 
hash map")
           // scalastyle:on throwerror
         }
-      } else if (isNullAware) {
-        return EmptyHashedRelationWithAllNullKeys
+      }
+    } else {
+      while (input.hasNext) {
+        val row = input.next().asInstanceOf[UnsafeRow]
+        numFields = row.numFields()
+        val key = keyGenerator(row)
+        if (!key.anyNull) {
+          val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, 
key.getSizeInBytes)
+          val success = loc.append(
+            key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+          if (!success) {
+            binaryMap.free()
+            // scalastyle:off throwerror
+            throw new SparkOutOfMemoryError("There is not enough memory to 
build hash map")
+            // scalastyle:on throwerror
+          }
+        } else if (isNullAware) {
+          return EmptyHashedRelationWithAllNullKeys
+        }

Review comment:
       +1




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to