Github user ron8hu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19594#discussion_r153665547
  
    --- Diff: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
 ---
    @@ -67,6 +68,205 @@ class JoinEstimationSuite extends 
StatsEstimationTestBase {
         rowCount = 2,
         attributeStats = AttributeMap(Seq("key-1-2", 
"key-2-3").map(nameToColInfo)))
     
    +  private def estimateByHistogram(
    +      histogram1: Histogram,
    +      histogram2: Histogram,
    +      expectedMin: Double,
    +      expectedMax: Double,
    +      expectedNdv: Long,
    +      expectedRows: Long): Unit = {
    +    val col1 = attr("key1")
    +    val col2 = attr("key2")
    +    val c1 = generateJoinChild(col1, histogram1, expectedMin, expectedMax)
    +    val c2 = generateJoinChild(col2, histogram2, expectedMin, expectedMax)
    +
    +    val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2)))
    +    val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1)))
    +    val expectedStatsAfterJoin = Statistics(
    +      sizeInBytes = expectedRows * (8 + 2 * 4),
    +      rowCount = Some(expectedRows),
    +      attributeStats = AttributeMap(Seq(
    +        col1 -> c1.stats.attributeStats(col1).copy(
    +          distinctCount = expectedNdv, min = Some(expectedMin), max = 
Some(expectedMax)),
    +        col2 -> c2.stats.attributeStats(col2).copy(
    +          distinctCount = expectedNdv, min = Some(expectedMin), max = 
Some(expectedMax))))
    +    )
    +
    +    // Join order should not affect estimation result.
    +    Seq(c1JoinC2, c2JoinC1).foreach { join =>
    +      assert(join.stats == expectedStatsAfterJoin)
    +    }
    +  }
    +
    +  private def generateJoinChild(
    +      col: Attribute,
    +      histogram: Histogram,
    +      expectedMin: Double,
    +      expectedMax: Double): LogicalPlan = {
    +    val colStat = inferColumnStat(histogram)
    +    val t = StatsTestPlan(
    +      outputList = Seq(col),
    +      rowCount = (histogram.height * histogram.bins.length).toLong,
    +      attributeStats = AttributeMap(Seq(col -> colStat)))
    +
    +    val filterCondition = new ArrayBuffer[Expression]()
    +    if (expectedMin > colStat.min.get.toString.toDouble) {
    +      filterCondition += GreaterThanOrEqual(col, Literal(expectedMin))
    +    }
    +    if (expectedMax < colStat.max.get.toString.toDouble) {
    +      filterCondition += LessThanOrEqual(col, Literal(expectedMax))
    +    }
    +    if (filterCondition.isEmpty) t else 
Filter(filterCondition.reduce(And), t)
    +  }
    +
    +  private def inferColumnStat(histogram: Histogram): ColumnStat = {
    +    var ndv = 0L
    +    for (i <- histogram.bins.indices) {
    +      val bin = histogram.bins(i)
    +      if (i == 0 || bin.hi != histogram.bins(i - 1).hi) {
    +        ndv += bin.ndv
    +      }
    +    }
    +    ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo),
    +      max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, 
maxLen = 4,
    +      histogram = Some(histogram))
    +  }
    +
    +  test("equi-height histograms: a bin is contained by another one") {
    +    val histogram1 = Histogram(height = 300, Array(
    +      HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 
60, ndv = 30)))
    +    val histogram2 = Histogram(height = 100, Array(
    +      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 
100, ndv = 40)))
    +    // test bin trimming
    +    val (t1, h1) = trimBin(histogram2.bins(0), height = 100, min = 10, max 
= 60)
    +    assert(t1 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h1 == 80)
    +    val (t2, h2) = trimBin(histogram2.bins(1), height = 100, min = 10, max 
= 60)
    +    assert(t2 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h2 == 20)
    +
    +    val expectedRanges = Seq(
    +      OverlappedRange(10, 30, math.min(10, 40*1/2), math.max(10, 40*1/2), 
300, 80*1/2),
    +      OverlappedRange(30, 50, math.min(30*2/3, 40*1/2), math.max(30*2/3, 
40*1/2), 300*2/3, 80*1/2),
    +      OverlappedRange(50, 60, math.min(30*1/3, 8), math.max(30*1/3, 8), 
300*1/3, 20)
    +    )
    +    assert(expectedRanges.equals(
    +      getOverlappedRanges(histogram1, histogram2, newMin = 10D, newMax = 
60D)))
    +
    +    estimateByHistogram(
    +      histogram1 = histogram1,
    +      histogram2 = histogram2,
    +      expectedMin = 10D,
    +      expectedMax = 60D,
    +      // 10 + 20 + 8
    +      expectedNdv = 38L,
    +      // 300*40/20 + 200*40/20 + 100*20/10
    +      expectedRows = 1200L)
    +  }
    +
    +  test("equi-height histograms: a bin has only one value") {
    +    val histogram1 = Histogram(height = 300, Array(
    +      HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 
60, ndv = 30)))
    +    val histogram2 = Histogram(height = 100, Array(
    +      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 
100, ndv = 40)))
    +    // test bin trimming
    +    val (t1, h1) = trimBin(histogram2.bins(0), height = 100, min = 30, max 
= 60)
    +    assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40)
    +    val (t2, h2) = trimBin(histogram2.bins(1), height = 100, min = 30, max 
= 60)
    +    assert(t2 ==HistogramBin(lo = 50, hi = 60, ndv = 8) && h2 == 20)
    +
    +    val expectedRanges = Seq(
    +      OverlappedRange(30, 30, 1, 1, 300, 40/20),
    +      OverlappedRange(30, 50, math.min(30*2/3, 20), math.max(30*2/3, 20), 
300*2/3, 40),
    +      OverlappedRange(50, 60, math.min(30*1/3, 8), math.max(30*1/3, 8), 
300*1/3, 20)
    +    )
    +    assert(expectedRanges.equals(
    +      getOverlappedRanges(histogram1, histogram2, newMin = 30D, newMax = 
60D)))
    +
    +    estimateByHistogram(
    +      histogram1 = histogram1,
    +      histogram2 = histogram2,
    +      expectedMin = 30D,
    +      expectedMax = 60D,
    +      // 1 + 20 + 8
    +      expectedNdv = 29L,
    +      // 300*20/1 + 200*40/20 + 100*20/10
    +      expectedRows = 1200L)
    +  }
    +
    +  test("equi-height histograms: a bin has only one value after trimming") {
    +    val histogram1 = Histogram(height = 300, Array(
    +      HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 
75, ndv = 3)))
    +    val histogram2 = Histogram(height = 100, Array(
    +      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 
100, ndv = 40)))
    --- End diff --
    
    For very skewed cases, multiple bins in a histogram may have the same 
distinct value. We may add one more test case to cover this situation.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to