Github user wzhfy commented on a diff in the pull request: https://github.com/apache/spark/pull/19531#discussion_r147544011 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala --- @@ -157,64 +154,100 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). 
*/ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def joinCardinality(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]) + : BigInt = { + // If there's no column stats available for join keys, estimate as cartesian product. + var minCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { + while(i < joinKeyPairs.length && minCard != 0) { val (leftKey, rightKey) = joinKeyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val card = joinCardByNdv(leftKey, rightKey, newMin, newMax) + // Return cardinality estimated from the most selective join keys. + if (card < minCard) minCard = card } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + minCard = 0 } i += 1 } + minCard + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. 
- 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) + /** Compute join cardinality using the basic formula, and update column stats for join keys. */ + private def joinCardByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): BigInt = { + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount) + // Compute cardinality by the basic formula. + val card = BigDecimal(leftStats.rowCount.get * rightStats.rowCount.get) / BigDecimal(maxNdv) + + // Update intersected column stats. + val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount) + val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen) + val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2 + + join.joinType match { + case LeftOuter => + keyStatsAfterJoin.put(leftKey, leftKeyStat) + keyStatsAfterJoin.put(rightKey, + ColumnStat(newNdv, newMin, newMax, rightKeyStat.nullCount, newAvgLen, newMaxLen)) + case RightOuter => + keyStatsAfterJoin.put(leftKey, + ColumnStat(newNdv, newMin, newMax, leftKeyStat.nullCount, newAvgLen, newMaxLen)) + keyStatsAfterJoin.put(rightKey, rightKeyStat) + case FullOuter => + keyStatsAfterJoin.put(leftKey, leftKeyStat) + keyStatsAfterJoin.put(rightKey, rightKeyStat) + case _ => + val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) + keyStatsAfterJoin.put(leftKey, newStats) + keyStatsAfterJoin.put(rightKey, newStats) --- End diff -- They are part of the old method [getIntersectedStats](https://github.com/apache/spark/pull/19531/files#diff-6387e7aaeb7d8e0cb1457b9d0fe5cd00L235). I need to revert those lines for outer join cases — thanks for pointing it out.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org