Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/19531#discussion_r147666360 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala --- @@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging { // scalastyle:off /** * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula: - * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of - * that column. The underlying assumption for this formula is: each value of the smaller domain - * is included in the larger domain. - * Generally, inner join with multiple join keys can also be estimated based on the above - * formula: + * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), + * where V is the number of distinct values (ndv) of that column. The underlying assumption for + * this formula is: each value of the smaller domain is included in the larger domain. + * + * Generally, inner join with multiple join keys can be estimated based on the above formula: * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn))) * However, the denominator can become very large and excessively reduce the result, so we use a * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator. + * + * That is, join estimation is based on the most selective join keys. We follow this strategy + * when different types of column statistics are available. E.g., if card1 is the cardinality + * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality estimated by histograms + * of join key A.k2 and B.k2, then the result cardinality would be min(card1, card2). 
+ * + * @param keyPairs pairs of join keys + * + * @return join cardinality, and column stats for join keys after the join */ // scalastyle:on - def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = { - var ndvDenom: BigInt = -1 + private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)]) + : (BigInt, Map[Attribute, ColumnStat]) = { + // If there's no column stats available for join keys, estimate as cartesian product. + var cardJoin: BigInt = leftStats.rowCount.get * rightStats.rowCount.get + val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]() var i = 0 - while(i < joinKeyPairs.length && ndvDenom != 0) { - val (leftKey, rightKey) = joinKeyPairs(i) + while(i < keyPairs.length && cardJoin != 0) { + val (leftKey, rightKey) = keyPairs(i) // Check if the two sides are disjoint - val leftKeyStats = leftStats.attributeStats(leftKey) - val rightKeyStats = rightStats.attributeStats(rightKey) - val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) - val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) + val leftKeyStat = leftStats.attributeStats(leftKey) + val rightKeyStat = rightStats.attributeStats(rightKey) + val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType) + val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType) if (ValueInterval.isIntersected(lInterval, rInterval)) { - // Get the largest ndv among pairs of join keys - val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) - if (maxNdv > ndvDenom) ndvDenom = maxNdv + val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType) + val (cardKeyPair, joinStatsKeyPair) = computeByNdv(leftKey, rightKey, newMin, newMax) + keyStatsAfterJoin ++= joinStatsKeyPair + // Return cardinality estimated from the most selective join keys. 
+ if (cardKeyPair < cardJoin) cardJoin = cardKeyPair } else { - // Set ndvDenom to zero to indicate that this join should have no output - ndvDenom = 0 + // One of the join key pairs is disjoint, thus the two sides of join is disjoint. + cardJoin = 0 } i += 1 } + (cardJoin, keyStatsAfterJoin.toMap) + } - if (ndvDenom < 0) { - // We can't find any join key pairs with column stats, estimate it as cartesian join. - 1 - } else if (ndvDenom == 0) { - // One of the join key pairs is disjoint, thus the two sides of join is disjoint. - 0 - } else { - 1 / BigDecimal(ndvDenom) - } + /** Compute join cardinality using the basic formula, and update column stats for join keys. */ + private def computeByNdv( + leftKey: AttributeReference, + rightKey: AttributeReference, + newMin: Option[Any], + newMax: Option[Any]): (BigInt, Map[Attribute, ColumnStat]) = { --- End diff -- I think we should return `(BigInt, ColumnStat)`, which means the column stats of the join key. Left and right keys must have the same stats.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org