Github user wzhfy commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19531#discussion_r147544011
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala ---
    @@ -157,64 +154,100 @@ case class InnerOuterEstimation(join: Join) extends Logging {
       // scalastyle:off
       /**
        * The number of rows of A inner join B on A.k1 = B.k1 is estimated by 
this basic formula:
    -   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the 
number of distinct values of
    -   * that column. The underlying assumption for this formula is: each 
value of the smaller domain
    -   * is included in the larger domain.
    -   * Generally, inner join with multiple join keys can also be estimated 
based on the above
    -   * formula:
    +   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)),
    +   * where V is the number of distinct values (ndv) of that column. The 
underlying assumption for
    +   * this formula is: each value of the smaller domain is included in the 
larger domain.
    +   *
    +   * Generally, inner join with multiple join keys can be estimated based 
on the above formula:
        * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), 
V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
        * However, the denominator can become very large and excessively reduce 
the result, so we use a
        * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) 
as the denominator.
    +   *
    +   * That is, join estimation is based on the most selective join keys. We 
follow this strategy
    +   * when different types of column statistics are available. E.g., if 
card1 is the cardinality
    +   * estimated by ndv of join key A.k1 and B.k1, card2 is the cardinality 
estimated by histograms
    +   * of join key A.k2 and B.k2, then the result cardinality would be 
min(card1, card2).
        */
       // scalastyle:on
    -  def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, 
AttributeReference)]): BigDecimal = {
    -    var ndvDenom: BigInt = -1
    +  private def joinCardinality(joinKeyPairs: Seq[(AttributeReference, 
AttributeReference)])
    +    : BigInt = {
    +    // If there's no column stats available for join keys, estimate as 
cartesian product.
    +    var minCard: BigInt = leftStats.rowCount.get * rightStats.rowCount.get
         var i = 0
    -    while(i < joinKeyPairs.length && ndvDenom != 0) {
    +    while(i < joinKeyPairs.length && minCard != 0) {
           val (leftKey, rightKey) = joinKeyPairs(i)
           // Check if the two sides are disjoint
    -      val leftKeyStats = leftStats.attributeStats(leftKey)
    -      val rightKeyStats = rightStats.attributeStats(rightKey)
    -      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, 
leftKey.dataType)
    -      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, 
rightKey.dataType)
    +      val leftKeyStat = leftStats.attributeStats(leftKey)
    +      val rightKeyStat = rightStats.attributeStats(rightKey)
    +      val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, 
leftKey.dataType)
    +      val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, 
rightKey.dataType)
           if (ValueInterval.isIntersected(lInterval, rInterval)) {
    -        // Get the largest ndv among pairs of join keys
    -        val maxNdv = 
leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
    -        if (maxNdv > ndvDenom) ndvDenom = maxNdv
    +        val (newMin, newMax) = ValueInterval.intersect(lInterval, 
rInterval, leftKey.dataType)
    +        val card = joinCardByNdv(leftKey, rightKey, newMin, newMax)
    +        // Return cardinality estimated from the most selective join keys.
    +        if (card < minCard) minCard = card
           } else {
    -        // Set ndvDenom to zero to indicate that this join should have no 
output
    -        ndvDenom = 0
    +        // One of the join key pairs is disjoint, thus the two sides of 
join is disjoint.
    +        minCard = 0
           }
           i += 1
         }
    +    minCard
    +  }
     
    -    if (ndvDenom < 0) {
    -      // We can't find any join key pairs with column stats, estimate it 
as cartesian join.
    -      1
    -    } else if (ndvDenom == 0) {
    -      // One of the join key pairs is disjoint, thus the two sides of join 
is disjoint.
    -      0
    -    } else {
    -      1 / BigDecimal(ndvDenom)
    +  /** Compute join cardinality using the basic formula, and update column 
stats for join keys. */
    +  private def joinCardByNdv(
    +      leftKey: AttributeReference,
    +      rightKey: AttributeReference,
    +      newMin: Option[Any],
    +      newMax: Option[Any]): BigInt = {
    +    val leftKeyStat = leftStats.attributeStats(leftKey)
    +    val rightKeyStat = rightStats.attributeStats(rightKey)
    +    val maxNdv = leftKeyStat.distinctCount.max(rightKeyStat.distinctCount)
    +    // Compute cardinality by the basic formula.
    +    val card = BigDecimal(leftStats.rowCount.get * 
rightStats.rowCount.get) / BigDecimal(maxNdv)
    +
    +    // Update intersected column stats.
    +    val newNdv = leftKeyStat.distinctCount.min(rightKeyStat.distinctCount)
    +    val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
    +    val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
    +
    +    join.joinType match {
    +      case LeftOuter =>
    +        keyStatsAfterJoin.put(leftKey, leftKeyStat)
    +        keyStatsAfterJoin.put(rightKey,
    +          ColumnStat(newNdv, newMin, newMax, rightKeyStat.nullCount, 
newAvgLen, newMaxLen))
    +      case RightOuter =>
    +        keyStatsAfterJoin.put(leftKey,
    +          ColumnStat(newNdv, newMin, newMax, leftKeyStat.nullCount, 
newAvgLen, newMaxLen))
    +        keyStatsAfterJoin.put(rightKey, rightKeyStat)
    +      case FullOuter =>
    +        keyStatsAfterJoin.put(leftKey, leftKeyStat)
    +        keyStatsAfterJoin.put(rightKey, rightKeyStat)
    +      case _ =>
    +        val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, 
newMaxLen)
    +        keyStatsAfterJoin.put(leftKey, newStats)
    +        keyStatsAfterJoin.put(rightKey, newStats)
    --- End diff --
    
    These lines are part of the old method [getIntersectedStats](https://github.com/apache/spark/pull/19531/files#diff-6387e7aaeb7d8e0cb1457b9d0fe5cd00L235).
    I need to revert them for the outer join cases. Thanks for pointing it out.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to