Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19531#discussion_r147666360
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala ---
    @@ -157,64 +157,90 @@ case class InnerOuterEstimation(join: Join) extends Logging {
       // scalastyle:off
       /**
        * The number of rows of A inner join B on A.k1 = B.k1 is estimated by this basic formula:
    -   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)), where V is the number of distinct values of
    -   * that column. The underlying assumption for this formula is: each value of the smaller domain
    -   * is included in the larger domain.
    -   * Generally, inner join with multiple join keys can also be estimated based on the above
    -   * formula:
    +   * T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1)),
    +   * where V is the number of distinct values (ndv) of that column. The underlying assumption for
    +   * this formula is: each value of the smaller domain is included in the larger domain.
    +   *
    +   * Generally, an inner join with multiple join keys can be estimated based on the above formula:
        * T(A IJ B) = T(A) * T(B) / (max(V(A.k1), V(B.k1)) * max(V(A.k2), V(B.k2)) * ... * max(V(A.kn), V(B.kn)))
        * However, the denominator can become very large and excessively reduce the result, so we use a
        * conservative strategy to take only the largest max(V(A.ki), V(B.ki)) as the denominator.
    +   *
    +   * That is, join estimation is based on the most selective join keys. We follow this strategy
    +   * when different types of column statistics are available. E.g., if card1 is the cardinality
    +   * estimated using the ndv of join keys A.k1 and B.k1, and card2 is the cardinality estimated
    +   * using histograms of join keys A.k2 and B.k2, then the result cardinality would be
    +   * min(card1, card2).
    +   *
    +   * @param keyPairs pairs of join keys
    +   *
    +   * @return join cardinality, and column stats for the join keys after the join
        */
       // scalastyle:on
    -  def joinSelectivity(joinKeyPairs: Seq[(AttributeReference, AttributeReference)]): BigDecimal = {
    -    var ndvDenom: BigInt = -1
    +  private def computeCardinalityAndStats(keyPairs: Seq[(AttributeReference, AttributeReference)])
    +    : (BigInt, Map[Attribute, ColumnStat]) = {
    +    // If there are no column stats available for the join keys, estimate as a cartesian product.
    +    var cardJoin: BigInt = leftStats.rowCount.get * rightStats.rowCount.get
    +    val keyStatsAfterJoin = new mutable.HashMap[Attribute, ColumnStat]()
         var i = 0
    -    while(i < joinKeyPairs.length && ndvDenom != 0) {
    -      val (leftKey, rightKey) = joinKeyPairs(i)
    +    while (i < keyPairs.length && cardJoin != 0) {
    +      val (leftKey, rightKey) = keyPairs(i)
           // Check if the two sides are disjoint
    -      val leftKeyStats = leftStats.attributeStats(leftKey)
    -      val rightKeyStats = rightStats.attributeStats(rightKey)
    -      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
    -      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
    +      val leftKeyStat = leftStats.attributeStats(leftKey)
    +      val rightKeyStat = rightStats.attributeStats(rightKey)
    +      val lInterval = ValueInterval(leftKeyStat.min, leftKeyStat.max, leftKey.dataType)
    +      val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
           if (ValueInterval.isIntersected(lInterval, rInterval)) {
    -        // Get the largest ndv among pairs of join keys
    -        val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
    -        if (maxNdv > ndvDenom) ndvDenom = maxNdv
    +        val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
    +        val (cardKeyPair, joinStatsKeyPair) = computeByNdv(leftKey, rightKey, newMin, newMax)
    +        keyStatsAfterJoin ++= joinStatsKeyPair
    +        // Keep the cardinality estimated from the most selective join keys.
    +        if (cardKeyPair < cardJoin) cardJoin = cardKeyPair
           } else {
    -        // Set ndvDenom to zero to indicate that this join should have no output
    -        ndvDenom = 0
    +        // One of the join key pairs is disjoint, thus the two sides of the join are disjoint.
    +        cardJoin = 0
           }
           i += 1
         }
    +    (cardJoin, keyStatsAfterJoin.toMap)
    +  }
     
    -    if (ndvDenom < 0) {
    -      // We can't find any join key pairs with column stats, estimate it as cartesian join.
    -      1
    -    } else if (ndvDenom == 0) {
    -      // One of the join key pairs is disjoint, thus the two sides of join is disjoint.
    -      0
    -    } else {
    -      1 / BigDecimal(ndvDenom)
    -    }
    +  /** Compute join cardinality using the basic formula, and update column stats for join keys. */
    +  private def computeByNdv(
    +      leftKey: AttributeReference,
    +      rightKey: AttributeReference,
    +      newMin: Option[Any],
    +      newMax: Option[Any]): (BigInt, Map[Attribute, ColumnStat]) = {
    --- End diff ---
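
    As a side note for readers, the basic formula in the javadoc above can be tried out with a minimal, self-contained sketch (all names here, such as `SideStats` and `estimateInnerJoinRows`, are hypothetical and not Spark APIs):

        object JoinFormulaSketch {
          // Row count and per-key number of distinct values (ndv) for one join side.
          case class SideStats(rowCount: BigInt, ndvByKey: Map[String, BigInt])

          // T(A IJ B) = T(A) * T(B) / denom, where denom is only the largest
          // max(V(A.ki), V(B.ki)) across all key pairs (the conservative strategy).
          def estimateInnerJoinRows(
              left: SideStats,
              right: SideStats,
              keyPairs: Seq[(String, String)]): BigInt = {
            val denom = keyPairs.map { case (lk, rk) =>
              left.ndvByKey(lk).max(right.ndvByKey(rk))
            }.max
            left.rowCount * right.rowCount / denom
          }

          def main(args: Array[String]): Unit = {
            val a = SideStats(BigInt(1000), Map("k1" -> BigInt(100), "k2" -> BigInt(10)))
            val b = SideStats(BigInt(500), Map("k1" -> BigInt(50), "k2" -> BigInt(40)))
            // Denominators per pair: max(100, 50) = 100 and max(10, 40) = 40; the
            // largest (100) is kept, so the estimate is 1000 * 500 / 100 = 5000.
            println(estimateInnerJoinRows(a, b, Seq(("k1", "k1"), ("k2", "k2")))) // 5000
          }
        }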
    
    I think we should return `(BigInt, ColumnStat)`, i.e., the column stats of the join key. Left and right keys must have the same stats.
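
    For illustration, a rough sketch of the suggested shape (hypothetical code, not part of the actual patch): `computeByNdv` would return a single `ColumnStat` for the intersected range, and the caller would register it under both keys, e.g.:

        // Hypothetical sketch of the suggested signature; the body is elided.
        private def computeByNdv(
            leftKey: AttributeReference,
            rightKey: AttributeReference,
            newMin: Option[Any],
            newMax: Option[Any]): (BigInt, ColumnStat) = {
          ??? // compute the pair cardinality and one shared ColumnStat
        }

        // Caller side: the same ColumnStat is valid for both keys after the join.
        val (cardKeyPair, keyStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
        keyStatsAfterJoin += (leftKey -> keyStat)
        keyStatsAfterJoin += (rightKey -> keyStat)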

