Github user tejasapatil commented on a diff in the pull request:
https://github.com/apache/spark/pull/19054#discussion_r162768446
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
---
@@ -220,45 +220,76 @@ case class EnsureRequirements(conf: SQLConf) extends
Rule[SparkPlan] {
operator.withNewChildren(children)
}
+ private def isSubset(biggerSet: Seq[Expression], smallerSet:
Seq[Expression]): Boolean =
+ smallerSet.length <= biggerSet.length &&
+ smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
+
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
- expectedOrderOfKeys: Seq[Expression],
- currentOrderOfKeys: Seq[Expression]): (Seq[Expression],
Seq[Expression]) = {
- val leftKeysBuffer = ArrayBuffer[Expression]()
- val rightKeysBuffer = ArrayBuffer[Expression]()
+ expectedOrderOfKeys: Seq[Expression], // comes from child's output
partitioning
+ currentOrderOfKeys: Seq[Expression]): // comes from join predicate
+ (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
+ assert(leftKeys.length == rightKeys.length)
+
+ val allLeftKeys = ArrayBuffer[Expression]()
+ val allRightKeys = ArrayBuffer[Expression]()
+ val reorderedLeftKeys = ArrayBuffer[Expression]()
+ val reorderedRightKeys = ArrayBuffer[Expression]()
+ val processedIndicies = mutable.Set[Int]()
expectedOrderOfKeys.foreach(expression => {
- val index = currentOrderOfKeys.indexWhere(e =>
e.semanticEquals(expression))
- leftKeysBuffer.append(leftKeys(index))
- rightKeysBuffer.append(rightKeys(index))
+ val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i)
=>
+ !processedIndicies.contains(i) &&
currKey.semanticEquals(expression)
+ }.get._2
+ processedIndicies.add(index)
+
+ reorderedLeftKeys.append(leftKeys(index))
+ allLeftKeys.append(leftKeys(index))
+
+ reorderedRightKeys.append(rightKeys(index))
+ allRightKeys.append(rightKeys(index))
})
- (leftKeysBuffer, rightKeysBuffer)
+
+ // If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the
re-ordering won't have
+ // all the keys. Append the remaining keys to the end so that we are
covering all the keys
+ for (i <- leftKeys.indices) {
+ if (!processedIndicies.contains(i)) {
+ allLeftKeys.append(leftKeys(i))
+ allRightKeys.append(rightKeys(i))
+ }
+ }
+
+ assert(allLeftKeys.length == leftKeys.length)
+ assert(allRightKeys.length == rightKeys.length)
+ assert(reorderedLeftKeys.length == reorderedRightKeys.length)
+
+ (allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
}
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
- rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression])
= {
+ rightPartitioning: Partitioning):
+ (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
--- End diff --
Added more documentation. I wasn't sure how to make it easier to understand; I hope
that the example helps with that.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org