Github user ron8hu commented on a diff in the pull request: https://github.com/apache/spark/pull/19594#discussion_r153665547 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala --- @@ -67,6 +68,205 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = 2, attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo))) + private def estimateByHistogram( + histogram1: Histogram, + histogram2: Histogram, + expectedMin: Double, + expectedMax: Double, + expectedNdv: Long, + expectedRows: Long): Unit = { + val col1 = attr("key1") + val col2 = attr("key2") + val c1 = generateJoinChild(col1, histogram1, expectedMin, expectedMax) + val c2 = generateJoinChild(col2, histogram2, expectedMin, expectedMax) + + val c1JoinC2 = Join(c1, c2, Inner, Some(EqualTo(col1, col2))) + val c2JoinC1 = Join(c2, c1, Inner, Some(EqualTo(col2, col1))) + val expectedStatsAfterJoin = Statistics( + sizeInBytes = expectedRows * (8 + 2 * 4), + rowCount = Some(expectedRows), + attributeStats = AttributeMap(Seq( + col1 -> c1.stats.attributeStats(col1).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + col2 -> c2.stats.attributeStats(col2).copy( + distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + ) + + // Join order should not affect estimation result. 
+ Seq(c1JoinC2, c2JoinC1).foreach { join => + assert(join.stats == expectedStatsAfterJoin) + } + } + + private def generateJoinChild( + col: Attribute, + histogram: Histogram, + expectedMin: Double, + expectedMax: Double): LogicalPlan = { + val colStat = inferColumnStat(histogram) + val t = StatsTestPlan( + outputList = Seq(col), + rowCount = (histogram.height * histogram.bins.length).toLong, + attributeStats = AttributeMap(Seq(col -> colStat))) + + val filterCondition = new ArrayBuffer[Expression]() + if (expectedMin > colStat.min.get.toString.toDouble) { + filterCondition += GreaterThanOrEqual(col, Literal(expectedMin)) + } + if (expectedMax < colStat.max.get.toString.toDouble) { + filterCondition += LessThanOrEqual(col, Literal(expectedMax)) + } + if (filterCondition.isEmpty) t else Filter(filterCondition.reduce(And), t) + } + + private def inferColumnStat(histogram: Histogram): ColumnStat = { + var ndv = 0L + for (i <- histogram.bins.indices) { + val bin = histogram.bins(i) + if (i == 0 || bin.hi != histogram.bins(i - 1).hi) { + ndv += bin.ndv + } + } + ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), + max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + histogram = Some(histogram)) + } + + test("equi-height histograms: a bin is contained by another one") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t1, h1) = trimBin(histogram2.bins(0), height = 100, min = 10, max = 60) + assert(t1 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h1 == 80) + val (t2, h2) = trimBin(histogram2.bins(1), height = 100, min = 10, max = 60) + assert(t2 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h2 == 20) + + val expectedRanges = Seq( + OverlappedRange(10, 30, math.min(10, 40*1/2), 
math.max(10, 40*1/2), 300, 80*1/2), + OverlappedRange(30, 50, math.min(30*2/3, 40*1/2), math.max(30*2/3, 40*1/2), 300*2/3, 80*1/2), + OverlappedRange(50, 60, math.min(30*1/3, 8), math.max(30*1/3, 8), 300*1/3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, newMin = 10D, newMax = 60D))) + + estimateByHistogram( + histogram1 = histogram1, + histogram2 = histogram2, + expectedMin = 10D, + expectedMax = 60D, + // 10 + 20 + 8 + expectedNdv = 38L, + // 300*40/20 + 200*40/20 + 100*20/10 + expectedRows = 1200L) + } + + test("equi-height histograms: a bin has only one value") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30))) + val histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) + // test bin trimming + val (t1, h1) = trimBin(histogram2.bins(0), height = 100, min = 30, max = 60) + assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40) + val (t2, h2) = trimBin(histogram2.bins(1), height = 100, min = 30, max = 60) + assert(t2 ==HistogramBin(lo = 50, hi = 60, ndv = 8) && h2 == 20) + + val expectedRanges = Seq( + OverlappedRange(30, 30, 1, 1, 300, 40/20), + OverlappedRange(30, 50, math.min(30*2/3, 20), math.max(30*2/3, 20), 300*2/3, 40), + OverlappedRange(50, 60, math.min(30*1/3, 8), math.max(30*1/3, 8), 300*1/3, 20) + ) + assert(expectedRanges.equals( + getOverlappedRanges(histogram1, histogram2, newMin = 30D, newMax = 60D))) + + estimateByHistogram( + histogram1 = histogram1, + histogram2 = histogram2, + expectedMin = 30D, + expectedMax = 60D, + // 1 + 20 + 8 + expectedNdv = 29L, + // 300*20/1 + 200*40/20 + 100*20/10 + expectedRows = 1200L) + } + + test("equi-height histograms: a bin has only one value after trimming") { + val histogram1 = Histogram(height = 300, Array( + HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 75, ndv = 3))) + val 
histogram2 = Histogram(height = 100, Array( + HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40))) --- End diff -- For very skewed cases, multiple bins in a histogram may have the same distinct value. We may add one more test case to cover this situation.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org