Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/19594#discussion_r157524477
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
---
@@ -67,6 +67,222 @@ class JoinEstimationSuite extends
StatsEstimationTestBase {
rowCount = 2,
attributeStats = AttributeMap(Seq("key-1-2",
"key-2-3").map(nameToColInfo)))
/**
 * Builds one child plan per histogram, joins them on their keys in both orders, and asserts
 * that the estimated join statistics equal the expected key range, ndv, row count and size.
 */
private def estimateByHistogram(
    leftHistogram: Histogram,
    rightHistogram: Histogram,
    expectedMin: Double,
    expectedMax: Double,
    expectedNdv: Long,
    expectedRows: Long): Unit = {
  val leftKey = attr("key1")
  val rightKey = attr("key2")
  val leftChild = generateJoinChild(leftKey, leftHistogram, expectedMin, expectedMax)
  val rightChild = generateJoinChild(rightKey, rightHistogram, expectedMin, expectedMax)

  // After the join both keys share the same ndv and [min, max] range; every other field of a
  // key's column stat is copied unchanged from the child that produces that key.
  val joinedKeyStats = Seq(leftChild -> leftKey, rightChild -> rightKey).map {
    case (child, key) =>
      key -> child.stats.attributeStats(key).copy(
        distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax))
  }

  // Bytes per output row: 8 (presumably the estimator's fixed per-row overhead -- confirm
  // against the size-estimation helper) plus avgLen = 4 for each of the two key columns,
  // which is the avgLen that inferColumnStat assigns.
  val expected = Statistics(
    sizeInBytes = expectedRows * (8 + 2 * 4),
    rowCount = Some(expectedRows),
    attributeStats = AttributeMap(joinedKeyStats))

  // Join order should not affect estimation result.
  Seq(
    Join(leftChild, rightChild, Inner, Some(EqualTo(leftKey, rightKey))),
    Join(rightChild, leftChild, Inner, Some(EqualTo(rightKey, leftKey)))
  ).foreach(join => assert(join.stats == expected))
}
+
/**
 * Builds a single-column leaf plan whose column statistics are inferred from `histogram`.
 *
 * @param col         the only output attribute of the plan
 * @param histogram   histogram describing `col`; the plan's row count is
 *                    `height * number of bins`, rounded to the nearest whole row
 * @param expectedMin not used by this helper; kept so existing call sites keep compiling
 * @param expectedMax not used by this helper; kept so existing call sites keep compiling
 * @return a test leaf plan carrying the inferred row count and column statistics
 */
private def generateJoinChild(
    col: Attribute,
    histogram: Histogram,
    expectedMin: Double,
    expectedMax: Double): LogicalPlan = {
  val colStat = inferColumnStat(histogram)
  // height * numBins is a Double product. Round instead of truncating with toLong, so that
  // floating-point error (e.g. (1.0 / 49) * 49 == 0.9999999999999999) cannot drop a row.
  StatsTestPlan(
    outputList = Seq(col),
    rowCount = math.round(histogram.height * histogram.bins.length),
    attributeStats = AttributeMap(Seq(col -> colStat)))
}
+
/**
 * Column statistics should be consistent with histograms in tests.
 *
 * Derives a ColumnStat from `histogram`:
 *  - min is the first bin's lower bound and max is the last bin's upper bound;
 *  - ndv is the sum of the bins' ndv values, except that a bin whose upper bound equals the
 *    previous bin's upper bound is not counted (presumably because it holds a value the
 *    previous bin already covers, e.g. one frequent value spilling over consecutive bins);
 *  - nullCount = 0 and avgLen = maxLen = 4 are fixed test values.
 *
 * @throws IllegalArgumentException if the histogram has no bins
 */
private def inferColumnStat(histogram: Histogram): ColumnStat = {
  val bins = histogram.bins
  require(bins.nonEmpty, "histogram must have at least one bin")
  val ndv = bins.indices.collect {
    case i if i == 0 || bins(i).hi != bins(i - 1).hi => bins(i).ndv
  }.sum
  ColumnStat(distinctCount = ndv, min = Some(bins.head.lo), max = Some(bins.last.hi),
    nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(histogram))
}
+
+ test("equi-height histograms: a bin is contained by another one") {
+ val histogram1 = Histogram(height = 300, Array(
+ HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi =
60, ndv = 30)))
+ val histogram2 = Histogram(height = 100, Array(
+ HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi =
100, ndv = 40)))
+ // test bin trimming
+ val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound =
10, upperBound = 60)
+ assert(t0 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h0 == 80)
+ val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound =
10, upperBound = 60)
+ assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20)
+
+ val expectedRanges = Seq(
+ // histogram1.bins(0) overlaps t0
+ OverlappedRange(10, 30, 10, 40*1/2, 300, 80*1/2),
+ // histogram1.bins(1) overlaps t0
+ OverlappedRange(30, 50, 30*2/3, 40*1/2, 300*2/3, 80*1/2),
+ // histogram1.bins(1) overlaps t1
+ OverlappedRange(50, 60, 30*1/3, 8, 300*1/3, 20)
+ )
+ assert(expectedRanges.equals(
+ getOverlappedRanges(histogram1, histogram2, lowerBound = 10D,
upperBound = 60D)))
--- End diff --
`10D` looks weird; how about `10.0`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]