Github user eyalfa commented on a diff in the pull request:
https://github.com/apache/spark/pull/19054#discussion_r165861581
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
---
@@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}
+ private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean =
+ smallerSet.length <= biggerSet.length &&
+ smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
+
+ /**
+ * Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` to be a prefix of
+ * `expectedOrderOfKeys`
+ */
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
- expectedOrderOfKeys: Seq[Expression],
- currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
- val leftKeysBuffer = ArrayBuffer[Expression]()
- val rightKeysBuffer = ArrayBuffer[Expression]()
+ expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning
+ currentOrderOfKeys: Seq[Expression]): // comes from join predicate
+ (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
+ assert(leftKeys.length == rightKeys.length)
+
+ val allLeftKeys = ArrayBuffer[Expression]()
+ val allRightKeys = ArrayBuffer[Expression]()
+ val reorderedLeftKeys = ArrayBuffer[Expression]()
+ val reorderedRightKeys = ArrayBuffer[Expression]()
+
+ // Tracking indicies here to track to which keys are accounted. Using a set based approach
+ // won't work because its possible that some keys are repeated in the join clause
+ // eg. a.key1 = b.key1 AND a.key1 = b.key2
+ val processedIndicies = mutable.Set[Int]()
expectedOrderOfKeys.foreach(expression => {
- val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
- leftKeysBuffer.append(leftKeys(index))
- rightKeysBuffer.append(rightKeys(index))
+ val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) =>
+ !processedIndicies.contains(i) && currKey.semanticEquals(expression)
+ }.get._2
--- End diff --
Is the `find` guaranteed to always succeed?
If so, it is worth a comment on the method's pre/postconditions.
Replacing `.get` with `.getOrElse(sys.error("..."))` might also be a good way of documenting this.
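
To illustrate the `getOrElse` idea, here is a minimal, self-contained sketch. It is not the PR's code: `FindOrFail` and `indexOfUnprocessed` are hypothetical names, plain `String`s stand in for `Expression`, and `==` stands in for `semanticEquals`. The point is only that a failed lookup reports which key was missing instead of surfacing as a bare `NoSuchElementException` from `.get`.

```scala
import scala.collection.mutable

object FindOrFail {
  // Hypothetical stand-in for the lookup inside `reorder`.
  // Assumption: String replaces Expression and == replaces semanticEquals.
  def indexOfUnprocessed(
      currentOrderOfKeys: Seq[String],
      expression: String,
      processedIndices: mutable.Set[Int]): Int = {
    val index = currentOrderOfKeys.zipWithIndex.collectFirst {
      // Skip positions already consumed, so repeated keys map to distinct indices.
      case (currKey, i) if !processedIndices.contains(i) && currKey == expression => i
    }.getOrElse(sys.error(
      s"Key '$expression' has no unprocessed match in [${currentOrderOfKeys.mkString(", ")}]"))
    // Record the position so the next lookup for the same key moves on.
    processedIndices += index
    index
  }
}
```

If the caller already guarantees the precondition (for example via the `isSubset` check), the error branch is unreachable and simply documents that assumption at the point where it matters.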
---