http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 2b1fe98..43440d5 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -37,59 +37,61 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 val attrInt = AttributeReference("cint", IntegerType)() - val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cbool has only 2 distinct values val attrBool = AttributeReference("cbool", BooleanType)() - val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1) + val colStatBool = ColumnStat(distinctCount = Some(2), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)) // column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() - val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatDate = ColumnStat(distinctCount = Some(10), + min = Some(dMin), max = Some(dMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. val decMin = Decimal("0.200000000000000000") val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDecimal = ColumnStat(distinctCount = Some(4), + min = Some(decMin), max = Some(decMax), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 val attrDouble = AttributeReference("cdouble", DoubleType)() - val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), - nullCount = 0, avgLen = 8, maxLen = 8) + val colStatDouble = ColumnStat(distinctCount = Some(10), min = Some(1.0), max = Some(10.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) // column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" val attrString = AttributeReference("cstring", StringType)() - val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2) + val colStatString = ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)) // column cint2 has values: 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 // Hence, distinctCount:10, min:7, max:16, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint < cint2 val attrInt2 = AttributeReference("cint2", IntegerType)() - val colStatInt2 = ColumnStat(distinctCount = 10, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt2 = ColumnStat(distinctCount = Some(10), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint3 has values: 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 // Hence, distinctCount:10, min:30, max:39, nullCount:0, avgLen:4, maxLen:4 // This column is created to test "cint = cint3 without overlap at all. val attrInt3 = AttributeReference("cint3", IntegerType)() - val colStatInt3 = ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt3 = ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cint4 has values in the range from 1 to 10 // distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 // This column is created to test complete overlap val attrInt4 = AttributeReference("cint4", IntegerType)() - val colStatInt4 = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4) + val colStatInt4 = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // column cintHgm has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 with histogram. // Note that cintHgm has an even distribution with histogram information built. @@ -98,8 +100,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmInt = Histogram(2.0, Array(HistogramBin(1.0, 2.0, 2), HistogramBin(2.0, 4.0, 2), HistogramBin(4.0, 6.0, 2), HistogramBin(6.0, 8.0, 2), HistogramBin(8.0, 10.0, 2))) - val colStatIntHgm = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt)) + val colStatIntHgm = ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt)) // column cintSkewHgm has values: 1, 4, 4, 5, 5, 5, 5, 6, 6, 10 with histogram. // Note that cintSkewHgm has a skewed distribution with histogram information built. @@ -108,8 +110,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val hgmIntSkew = Histogram(2.0, Array(HistogramBin(1.0, 4.0, 2), HistogramBin(4.0, 5.0, 2), HistogramBin(5.0, 5.0, 1), HistogramBin(5.0, 6.0, 2), HistogramBin(6.0, 10.0, 2))) - val colStatIntSkewHgm = ColumnStat(distinctCount = 5, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew)) + val colStatIntSkewHgm = ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew)) val attributeMap = AttributeMap(Seq( attrInt -> colStatInt, @@ -172,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 3)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -180,7 +182,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } @@ -196,23 +198,23 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 8)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(8))), expectedRowCount = 8) } test("cint = 2") { validateEstimatedStats( Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(1), min = Some(2), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } @@ -227,8 +229,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -243,16 +245,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -267,8 +269,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 5) } @@ -282,8 +284,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint IS NOT NULL") { validateEstimatedStats( Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -301,8 +303,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -310,7 +312,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 2)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(2))), expectedRowCount = 2) } @@ -318,7 +320,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 6)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(6))), expectedRowCount = 6) } @@ -326,7 +328,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 5)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -342,47 +344,47 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 9), - attrString -> colStatString.copy(distinctCount = 9)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(9)), + attrString -> colStatString.copy(distinctCount = Some(9))), expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt.copy(distinctCount = 7)), + Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cbool IN (true)") { validateEstimatedStats( Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } test("cbool > false") { validateEstimatedStats( Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), - Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1)), + Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))), expectedRowCount = 5) } @@ -391,18 +393,21 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(1), + min = Some(d20170102), max = Some(d20170102), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { + val d20170101 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170101), max = Some(d20170103), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -414,8 +419,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrDate -> ColumnStat(distinctCount = Some(3), + min = Some(d20170103), max = Some(d20170105), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 3) } @@ -424,42 +430,45 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(1), + min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 1) } test("cdecimal < 0.60 ") { + val dec_0_20 = Decimal("0.200000000000000000") val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), - Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDecimal -> ColumnStat(distinctCount = Some(3), + min = Some(dec_0_20), max = Some(dec_0_60), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8)), + Seq(attrDouble -> ColumnStat(distinctCount = Some(3), min = Some(1.0), max = Some(3.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))), expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(1), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 1) } test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), - Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2)), + Seq(attrString -> ColumnStat(distinctCount = Some(10), min = None, max = None, + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2))), expectedRowCount = 10) } @@ -468,8 +477,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // valid values in IN clause is greater than the number of distinct values for a given column. // For example, column has only 2 distinct values 1 and 6. // The predicate is: column IN (1, 2, 3, 4, 5). - val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4) + val cornerChildColStatInt = ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val cornerChildStatsTestplan = StatsTestPlan( outputList = Seq(attrInt), rowCount = 2L, @@ -477,16 +487,17 @@ class FilterEstimationSuite extends StatsEstimationTestBase { ) validateEstimatedStats( Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 2) } // This is a limitation test. We should remove it after the limitation is removed. test("don't estimate IsNull or IsNotNull if the child is a non-leaf node") { val attrIntLargerRange = AttributeReference("c1", IntegerType)() - val colStatIntLargerRange = ColumnStat(distinctCount = 20, min = Some(1), max = Some(20), - nullCount = 10, avgLen = 4, maxLen = 4) + val colStatIntLargerRange = ColumnStat(distinctCount = Some(20), + min = Some(1), max = Some(20), + nullCount = Some(10), avgLen = Some(4), maxLen = Some(4)) val smallerTable = childStatsTestPlan(Seq(attrInt), 10L) val largerTable = StatsTestPlan( outputList = Seq(attrIntLargerRange), @@ -508,10 +519,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -519,10 +530,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -530,10 +541,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt2 -> ColumnStat(distinctCount = Some(4), min = Some(7), max = Some(16), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -541,10 +552,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // complete overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -552,10 +563,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(4), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 4) } @@ -571,10 +582,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // all table records qualify. validateEstimatedStats( Filter(LessThan(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt3 -> ColumnStat(distinctCount = 10, min = Some(30), max = Some(39), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> ColumnStat(distinctCount = Some(10), min = Some(1), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt3 -> ColumnStat(distinctCount = Some(10), min = Some(30), max = Some(39), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))), expectedRowCount = 10) } @@ -592,11 +603,11 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), Seq( - attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - attrString -> colStatString.copy(distinctCount = 5)), + attrInt -> ColumnStat(distinctCount = Some(5), min = Some(3), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrInt4 -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attrString -> colStatString.copy(distinctCount = Some(5))), expectedRowCount = 5) } @@ -606,15 +617,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 7)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(7))), expectedRowCount = 7) } test("cintHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntHgm, Literal(5)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 1) } @@ -629,8 +640,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } @@ -645,16 +656,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrIntHgm, Literal(3)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(3), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 3) } test("cintHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -669,8 +680,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintHgm >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntHgm, Literal(6)), childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(5), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 5) } @@ -679,8 +690,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmInt))), + Seq(attrIntHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmInt))), expectedRowCount = 4) } @@ -688,7 +699,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntHgm, Literal(3)), EqualTo(attrIntHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntHgm), 10L)), - Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = 3)), + Seq(attrIntHgm -> colStatIntHgm.copy(distinctCount = Some(3))), expectedRowCount = 3) } @@ -698,15 +709,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrIntSkewHgm, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 5)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(5))), expectedRowCount = 9) } test("cintSkewHgm = 5") { validateEstimatedStats( Filter(EqualTo(attrIntSkewHgm, Literal(5)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(5), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(5), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 4) } @@ -721,8 +732,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cintSkewHgm < 3") { validateEstimatedStats( Filter(LessThan(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -738,16 +749,16 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThanOrEqual(attrIntSkewHgm, Literal(3)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } test("cintSkewHgm > 6") { validateEstimatedStats( Filter(GreaterThan(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 1, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(1), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 2) } @@ -764,8 +775,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(GreaterThanOrEqual(attrIntSkewHgm, Literal(6)), childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 2, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(2), min = Some(6), max = Some(10), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 3) } @@ -774,8 +785,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { Literal(3)), LessThanOrEqual(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4, histogram = Some(hgmIntSkew))), + Seq(attrIntSkewHgm -> ColumnStat(distinctCount = Some(4), min = Some(3), max = Some(6), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(hgmIntSkew))), expectedRowCount = 8) } @@ -783,7 +794,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrIntSkewHgm, Literal(3)), EqualTo(attrIntSkewHgm, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrIntSkewHgm), 10L)), - Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = 2)), + Seq(attrIntSkewHgm -> colStatIntSkewHgm.copy(distinctCount = Some(2))), expectedRowCount = 3) }
http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index 26139d8..12c0a7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -33,16 +33,16 @@ class JoinEstimationSuite extends StatsEstimationTestBase { /** Set up tables and its columns for testing */ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - attr("key-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-5-9") -> ColumnStat(distinctCount = 5, min = Some(5), max = Some(9), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-4") -> ColumnStat(distinctCount = 3, min = Some(2), max = Some(4), nullCount = 0, - avgLen = 4, maxLen = 4), - attr("key-2-3") -> ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + attr("key-1-5") -> ColumnStat(distinctCount = Some(5), min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-5-9") -> ColumnStat(distinctCount = Some(5), min = Some(5), max = Some(9), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-1-2") -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-4") -> ColumnStat(distinctCount = Some(3), min = Some(2), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + attr("key-2-3") -> ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) )) private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) @@ -70,8 +70,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def estimateByHistogram( leftHistogram: Histogram, rightHistogram: Histogram, - expectedMin: Double, - expectedMax: Double, + expectedMin: Any, + expectedMax: Any, expectedNdv: Long, expectedRows: Long): Unit = { val col1 = attr("key1") @@ -86,9 +86,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRows), attributeStats = AttributeMap(Seq( col1 -> c1.stats.attributeStats(col1).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)), + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)), col2 -> c2.stats.attributeStats(col2).copy( - distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax)))) + distinctCount = Some(expectedNdv), + min = Some(expectedMin), max = Some(expectedMax)))) ) // Join order should not affect estimation result. @@ -100,9 +102,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { private def generateJoinChild( col: Attribute, histogram: Histogram, - expectedMin: Double, - expectedMax: Double): LogicalPlan = { - val colStat = inferColumnStat(histogram) + expectedMin: Any, + expectedMax: Any): LogicalPlan = { + val colStat = inferColumnStat(histogram, expectedMin, expectedMax) StatsTestPlan( outputList = Seq(col), rowCount = (histogram.height * histogram.bins.length).toLong, @@ -110,7 +112,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { } /** Column statistics should be consistent with histograms in tests. */ - private def inferColumnStat(histogram: Histogram): ColumnStat = { + private def inferColumnStat( + histogram: Histogram, + expectedMin: Any, + expectedMax: Any): ColumnStat = { + var ndv = 0L for (i <- histogram.bins.indices) { val bin = histogram.bins(i) @@ -118,8 +124,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase { ndv += bin.ndv } } - ColumnStat(distinctCount = ndv, min = Some(histogram.bins.head.lo), - max = Some(histogram.bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4, + ColumnStat(distinctCount = Some(ndv), + min = Some(expectedMin), max = Some(expectedMax), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4), histogram = Some(histogram)) } @@ -343,10 +350,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { rowCount = Some(5 + 3), attributeStats = AttributeMap( // Update null count in column stats. - Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = 3), - nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3), - nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5), - nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5)))) + Seq(nameToAttr("key-1-5") -> columnInfo(nameToAttr("key-1-5")).copy(nullCount = Some(3)), + nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = Some(3)), + nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = Some(5)), + nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = Some(5))))) assert(join.stats == expectedStats) } @@ -356,11 +363,11 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val join = Join(table1, table2, Inner, Some(EqualTo(nameToAttr("key-1-5"), nameToAttr("key-1-2")))) // Update column stats for equi-join keys (key-1-5 and key-1-2). - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) // Update column stat for other column if #outputRow / #sideRow < 1 (key-5-9), or keep it // unchanged (key-2-4). - val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = 5 * 3 / 5) + val colStatForkey59 = nameToColInfo("key-5-9")._2.copy(distinctCount = Some(5 * 3 / 5)) val expectedStats = Statistics( sizeInBytes = 3 * (8 + 4 * 4), @@ -379,10 +386,10 @@ class JoinEstimationSuite extends StatsEstimationTestBase { EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3"))))) // Update column stats for join keys. - val joinedColStat1 = ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, - avgLen = 4, maxLen = 4) - val joinedColStat2 = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat1 = ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) + val joinedColStat2 = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -398,8 +405,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table3, table2, LeftOuter, Some(EqualTo(nameToAttr("key-2-3"), nameToAttr("key-2-4")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -416,8 +423,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { // table3 (key-1-2 int, key-2-3 int): (1, 2), (2, 3) val join = Join(table2, table3, RightOuter, Some(EqualTo(nameToAttr("key-2-4"), nameToAttr("key-2-3")))) - val joinedColStat = ColumnStat(distinctCount = 2, min = Some(2), max = Some(3), nullCount = 0, - avgLen = 4, maxLen = 4) + val joinedColStat = ColumnStat(distinctCount = Some(2), min = Some(2), max = Some(3), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) val expectedStats = Statistics( sizeInBytes = 2 * (8 + 4 * 4), @@ -466,30 +473,40 @@ class JoinEstimationSuite extends StatsEstimationTestBase { val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, - min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, - min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 1, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 1, - min = Some(date), max = Some(date), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 1, - min = Some(timestamp), max = Some(timestamp), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1), + min = Some(false), max = Some(false), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toByte), max = Some(1.toByte), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.toShort), max = Some(1.toShort), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1), max = Some(1), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1L), max = Some(1L), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0), max = Some(1.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(1), + min = Some(1.0f), max = Some(1.0f), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(1), min = Some(dec), max = Some(dec), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(1), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(1), + min = Some(date), max = Some(date), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1), + min = Some(timestamp), max = Some(timestamp), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) ) } @@ -520,7 +537,8 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("join with null column") { val (nullColumn, nullColStat) = (attr("cnull"), - ColumnStat(distinctCount = 0, min = None, max = None, nullCount = 1, avgLen = 4, maxLen = 4)) + ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4))) val nullTable = StatsTestPlan( outputList = Seq(nullColumn), rowCount = 1, http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index cda54fa..dcb3701 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.types._ class ProjectEstimationSuite extends StatsEstimationTestBase { test("project with alias") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 2, min = Some(1), - max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4)) - val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = 1, min = Some(10), - max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1), + max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) + val (ar2, colStat2) = (attr("key2"), ColumnStat(distinctCount = Some(1), min = Some(10), + max = Some(10), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1, ar2), @@ -49,8 +49,8 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("project on empty table") { - val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4)) + val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(0), min = None, max = None, + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))) val child = StatsTestPlan( outputList = Seq(ar1), rowCount = 0, @@ -71,30 +71,40 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( - AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, - min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), - AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), - AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), - AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, - min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), - AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = 2, - min = None, max = None, nullCount = 0, avgLen = 3, maxLen = 3), - AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = 2, - min = Some(d1), max = Some(d2), nullCount = 0, avgLen = 4, maxLen = 4), - AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = 2, - min = Some(t1), max = Some(t2), nullCount = 0, avgLen = 8, maxLen = 8) + AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(2), + min = Some(false), max = Some(true), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(2), + nullCount = Some(0), avgLen = Some(1), maxLen = Some(1)), + AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(3), + nullCount = Some(0), avgLen = Some(2), maxLen = Some(2)), + AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(4), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1), max = Some(5), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(6.0), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)), + AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = Some(2), + min = Some(1.0), max = Some(7.0), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat( + distinctCount = Some(2), min = Some(dec1), max = Some(dec2), + nullCount = Some(0), avgLen = Some(16), maxLen = Some(16)), + AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cbinary", BinaryType)() -> ColumnStat(distinctCount = Some(2), + min = None, max = None, nullCount = Some(0), avgLen = Some(3), maxLen = Some(3)), + AttributeReference("cdate", DateType)() -> ColumnStat(distinctCount = Some(2), + min = Some(d1), max = Some(d2), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)), + AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(2), + min = Some(t1), max = Some(t2), + nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)) )) val columnSizes: Map[Attribute, Long] = columnInfo.map(kv => (kv._1, getColSize(kv._1, kv._2))) val child = StatsTestPlan( http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 31dea2e..9dceca5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -42,8 +42,8 @@ trait StatsEstimationTestBase extends SparkFunSuite { def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match { // For UTF8String: base + offset + numBytes - case StringType => colStat.avgLen + 8 + 4 - case _ => colStat.avgLen + case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4 + case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize) } def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() @@ -54,6 +54,12 @@ trait StatsEstimationTestBase extends SparkFunSuite { val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2)) } + + /** Get a test ColumnStat with given distinctCount and nullCount */ + def rangeColumnStat(distinctCount: Int, nullCount: Int): ColumnStat = + ColumnStat(distinctCount = Some(distinctCount), + min = Some(1), max = Some(distinctCount), + nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 1122522..640e013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.command import scala.collection.mutable import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** @@ -64,12 +66,12 @@ case class AnalyzeColumnCommand( /** * Compute stats for the given columns. - * @return (row count, map from column name to ColumnStats) + * @return (row count, map from column name to CatalogColumnStats) */ private def computeColumnStats( sparkSession: SparkSession, tableIdent: TableIdentifier, - columnNames: Seq[String]): (Long, Map[String, ColumnStat]) = { + columnNames: Seq[String]): (Long, Map[String, CatalogColumnStat]) = { val conf = sparkSession.sessionState.conf val relation = sparkSession.table(tableIdent).logicalPlan @@ -81,7 +83,7 @@ case class AnalyzeColumnCommand( // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => - if (!ColumnStat.supportsType(attr.dataType)) { + if (!supportsType(attr.dataType)) { throw new AnalysisException( s"Column ${attr.name} in table $tableIdent is of type ${attr.dataType}, " + "and Spark does not support statistics collection on this column type.") @@ -103,7 +105,7 @@ case class AnalyzeColumnCommand( // will be structs containing all column stats. // The layout of each struct follows the layout of the ColumnStats. val expressions = Count(Literal(1)).toAggregateExpression() +: - attributesToAnalyze.map(ColumnStat.statExprs(_, conf, attributePercentiles)) + attributesToAnalyze.map(statExprs(_, conf, attributePercentiles)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) @@ -111,9 +113,9 @@ case class AnalyzeColumnCommand( val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - // according to `ColumnStat.statExprs`, the stats struct always have 7 fields. - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, - attributePercentiles.get(attr))) + // according to `statExprs`, the stats struct always have 7 fields. + (attr.name, rowToColumnStat(statsRow.getStruct(i + 1, 7), attr, rowCount, + attributePercentiles.get(attr)).toCatalogColumnStat(attr.name, attr.dataType)) }.toMap (rowCount, columnStats) } @@ -124,7 +126,7 @@ case class AnalyzeColumnCommand( sparkSession: SparkSession, relation: LogicalPlan): AttributeMap[ArrayData] = { val attrsToGenHistogram = if (conf.histogramEnabled) { - attributesToAnalyze.filter(a => ColumnStat.supportsHistogram(a.dataType)) + attributesToAnalyze.filter(a => supportsHistogram(a.dataType)) } else { Nil } @@ -154,4 +156,120 @@ case class AnalyzeColumnCommand( AttributeMap(attributePercentiles.toSeq) } + /** Returns true iff the we support gathering column statistics on column of the given type. */ + private def supportsType(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case BooleanType => true + case DateType => true + case TimestampType => true + case BinaryType | StringType => true + case _ => false + } + + /** Returns true iff the we support gathering histogram on column of the given type. */ + private def supportsHistogram(dataType: DataType): Boolean = dataType match { + case _: IntegralType => true + case _: DecimalType => true + case DoubleType | FloatType => true + case DateType => true + case TimestampType => true + case _ => false + } + + /** + * Constructs an expression to compute column statistics for a given column. + * + * The expression should create a single struct column with the following schema: + * distinctCount: Long, min: T, max: T, nullCount: Long, avgLen: Long, maxLen: Long, + * distinctCountsForIntervals: Array[Long] + * + * Together with [[rowToColumnStat]], this function is used to create [[ColumnStat]] and + * as a result should stay in sync with it. + */ + private def statExprs( + col: Attribute, + conf: SQLConf, + colPercentiles: AttributeMap[ArrayData]): CreateNamedStruct = { + def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr => + expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() } + }) + val one = Literal(1, LongType) + + // the approximate ndv (num distinct value) should never be larger than the number of rows + val numNonNulls = if (col.nullable) Count(col) else Count(one) + val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls)) + val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) + val nullArray = Literal(null, ArrayType(LongType)) + + def fixedLenTypeStruct: CreateNamedStruct = { + val genHistogram = + supportsHistogram(col.dataType) && colPercentiles.contains(col) + val intervalNdvsExpr = if (genHistogram) { + ApproxCountDistinctForIntervals(col, + Literal(colPercentiles(col), ArrayType(col.dataType)), conf.ndvMaxError) + } else { + nullArray + } + // For fixed width types, avg size should be the same as max size. + struct(ndv, Cast(Min(col), col.dataType), Cast(Max(col), col.dataType), numNulls, + defaultSize, defaultSize, intervalNdvsExpr) + } + + col.dataType match { + case _: IntegralType => fixedLenTypeStruct + case _: DecimalType => fixedLenTypeStruct + case DoubleType | FloatType => fixedLenTypeStruct + case BooleanType => fixedLenTypeStruct + case DateType => fixedLenTypeStruct + case TimestampType => fixedLenTypeStruct + case BinaryType | StringType => + // For string and binary type, we don't compute min, max or histogram + val nullLit = Literal(null, col.dataType) + struct( + ndv, nullLit, nullLit, numNulls, + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)), + nullArray) + case _ => + throw new AnalysisException("Analyzing column statistics is not supported for column " + + s"${col.name} of data type: ${col.dataType}.") + } + } + + /** Convert a struct for column stats (defined in `statExprs`) into [[ColumnStat]]. */ + private def rowToColumnStat( + row: InternalRow, + attr: Attribute, + rowCount: Long, + percentiles: Option[ArrayData]): ColumnStat = { + // The first 6 fields are basic column stats, the 7th is ndvs for histogram bins. + val cs = ColumnStat( + distinctCount = Option(BigInt(row.getLong(0))), + // for string/binary min/max, get should return null + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), + nullCount = Option(BigInt(row.getLong(3))), + avgLen = Option(row.getLong(4)), + maxLen = Option(row.getLong(5)) + ) + if (row.isNullAt(6) || cs.nullCount.isEmpty) { + cs + } else { + val ndvs = row.getArray(6).toLongArray() + assert(percentiles.get.numElements() == ndvs.length + 1) + val endpoints = percentiles.get.toArray[Any](attr.dataType).map(_.toString.toDouble) + // Construct equi-height histogram + val bins = ndvs.zipWithIndex.map { case (ndv, i) => + HistogramBin(endpoints(i), endpoints(i + 1), ndv) + } + val nonNullRows = rowCount - cs.nullCount.get + val histogram = Histogram(nonNullRows.toDouble / ndvs.length, bins) + cs.copy(histogram = Some(histogram)) + } + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e400975..4474919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -695,10 +695,11 @@ case class DescribeColumnCommand( // Show column stats when EXTENDED or FORMATTED is specified. buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL")) buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL")) - buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL")) - buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL")) - buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL")) - buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL")) + buffer += Row("num_nulls", cs.flatMap(_.nullCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("distinct_count", + cs.flatMap(_.distinctCount.map(_.toString)).getOrElse("NULL")) + buffer += Row("avg_col_len", cs.flatMap(_.avgLen.map(_.toString)).getOrElse("NULL")) + buffer += Row("max_col_len", cs.flatMap(_.maxLen.map(_.toString)).getOrElse("NULL")) val histDesc = for { c <- cs hist <- c.histogram http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e798..ed4ea02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -95,7 +96,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared assert(fetchedStats2.get.sizeInBytes == 0) val expectedColStat = - "key" -> ColumnStat(0, None, None, 0, IntegerType.defaultSize, IntegerType.defaultSize) + "key" -> CatalogColumnStat(Some(0), None, None, Some(0), + Some(IntegerType.defaultSize), Some(IntegerType.defaultSize)) // There won't be histogram for empty column. Seq("true", "false").foreach { histogramEnabled => @@ -156,7 +158,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(stats, statsWithHgms).foreach { s => s.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) + val roundtrip = CatalogColumnStat.fromMap("table_is_foo", field.name, v.toMap(k)) assert(roundtrip == Some(v)) } } @@ -187,7 +189,8 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared }.mkString(", ")) val expectedColStats = dataTypes.map { case (tpe, idx) => - (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + (s"col$idx", CatalogColumnStat(Some(0), None, None, Some(1), + Some(tpe.defaultSize.toLong), Some(tpe.defaultSize.toLong))) } // There won't be histograms for null columns. http://git-wip-us.apache.org/repos/asf/spark/blob/8077bb04/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 65ccc19..bf4abb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, LogicalPlan} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -67,18 +67,21 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( - "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), - "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), - "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), - "cstring" -> ColumnStat(2, None, None, 1, 3, 3), - "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1Internal), Some(d2Internal), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1Internal), Some(t2Internal), 1, 8, 8) + "cbool" -> CatalogColumnStat(Some(2), Some("false"), Some("true"), Some(1), Some(1), Some(1)), + "cbyte" -> CatalogColumnStat(Some(2), Some("1"), Some("2"), Some(1), Some(1), Some(1)), + "cshort" -> CatalogColumnStat(Some(2), Some("1"), Some("3"), Some(1), Some(2), Some(2)), + "cint" -> CatalogColumnStat(Some(2), Some("1"), Some("4"), Some(1), Some(4), Some(4)), + "clong" -> CatalogColumnStat(Some(2), Some("1"), Some("5"), Some(1), Some(8), Some(8)), + "cdouble" -> CatalogColumnStat(Some(2), Some("1.0"), Some("6.0"), Some(1), Some(8), Some(8)), + "cfloat" -> CatalogColumnStat(Some(2), Some("1.0"), Some("7.0"), Some(1), Some(4), Some(4)), + "cdecimal" -> CatalogColumnStat(Some(2), Some(dec1.toString), Some(dec2.toString), Some(1), + Some(16), Some(16)), + "cstring" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cbinary" -> CatalogColumnStat(Some(2), None, None, Some(1), Some(3), Some(3)), + "cdate" -> CatalogColumnStat(Some(2), Some(d1.toString), Some(d2.toString), Some(1), Some(4), + Some(4)), + "ctimestamp" -> CatalogColumnStat(Some(2), Some(t1.toString), Some(t2.toString), Some(1), + Some(8), Some(8)) ) /** @@ -110,6 +113,110 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats } + val expectedSerializedColStats = Map( + "spark.sql.statistics.colStats.cbinary.avgLen" -> "3", + "spark.sql.statistics.colStats.cbinary.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbinary.maxLen" -> "3", + "spark.sql.statistics.colStats.cbinary.nullCount" -> "1", + "spark.sql.statistics.colStats.cbinary.version" -> "1", + "spark.sql.statistics.colStats.cbool.avgLen" -> "1", + "spark.sql.statistics.colStats.cbool.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbool.max" -> "true", + "spark.sql.statistics.colStats.cbool.maxLen" -> "1", + "spark.sql.statistics.colStats.cbool.min" -> "false", + "spark.sql.statistics.colStats.cbool.nullCount" -> "1", + "spark.sql.statistics.colStats.cbool.version" -> "1", + "spark.sql.statistics.colStats.cbyte.avgLen" -> "1", + "spark.sql.statistics.colStats.cbyte.distinctCount" -> "2", + "spark.sql.statistics.colStats.cbyte.max" -> "2", + "spark.sql.statistics.colStats.cbyte.maxLen" -> "1", + "spark.sql.statistics.colStats.cbyte.min" -> "1", + "spark.sql.statistics.colStats.cbyte.nullCount" -> "1", + "spark.sql.statistics.colStats.cbyte.version" -> "1", + "spark.sql.statistics.colStats.cdate.avgLen" -> "4", + "spark.sql.statistics.colStats.cdate.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdate.max" -> "2016-05-09", + "spark.sql.statistics.colStats.cdate.maxLen" -> "4", + "spark.sql.statistics.colStats.cdate.min" -> "2016-05-08", + "spark.sql.statistics.colStats.cdate.nullCount" -> "1", + "spark.sql.statistics.colStats.cdate.version" -> "1", + "spark.sql.statistics.colStats.cdecimal.avgLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdecimal.max" -> "8.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.maxLen" -> "16", + "spark.sql.statistics.colStats.cdecimal.min" -> "1.000000000000000000", + "spark.sql.statistics.colStats.cdecimal.nullCount" -> "1", + "spark.sql.statistics.colStats.cdecimal.version" -> "1", + "spark.sql.statistics.colStats.cdouble.avgLen" -> "8", + "spark.sql.statistics.colStats.cdouble.distinctCount" -> "2", + "spark.sql.statistics.colStats.cdouble.max" -> "6.0", + "spark.sql.statistics.colStats.cdouble.maxLen" -> "8", + "spark.sql.statistics.colStats.cdouble.min" -> "1.0", + "spark.sql.statistics.colStats.cdouble.nullCount" -> "1", + "spark.sql.statistics.colStats.cdouble.version" -> "1", + "spark.sql.statistics.colStats.cfloat.avgLen" -> "4", + "spark.sql.statistics.colStats.cfloat.distinctCount" -> "2", + "spark.sql.statistics.colStats.cfloat.max" -> "7.0", + "spark.sql.statistics.colStats.cfloat.maxLen" -> "4", + "spark.sql.statistics.colStats.cfloat.min" -> "1.0", + "spark.sql.statistics.colStats.cfloat.nullCount" -> "1", + "spark.sql.statistics.colStats.cfloat.version" -> "1", + "spark.sql.statistics.colStats.cint.avgLen" -> "4", + "spark.sql.statistics.colStats.cint.distinctCount" -> "2", + "spark.sql.statistics.colStats.cint.max" -> "4", + "spark.sql.statistics.colStats.cint.maxLen" -> "4", + "spark.sql.statistics.colStats.cint.min" -> "1", + "spark.sql.statistics.colStats.cint.nullCount" -> "1", + "spark.sql.statistics.colStats.cint.version" -> "1", + "spark.sql.statistics.colStats.clong.avgLen" -> "8", + "spark.sql.statistics.colStats.clong.distinctCount" -> "2", + "spark.sql.statistics.colStats.clong.max" -> "5", + "spark.sql.statistics.colStats.clong.maxLen" -> "8", + "spark.sql.statistics.colStats.clong.min" -> "1", + "spark.sql.statistics.colStats.clong.nullCount" -> "1", + "spark.sql.statistics.colStats.clong.version" -> "1", + "spark.sql.statistics.colStats.cshort.avgLen" -> "2", + "spark.sql.statistics.colStats.cshort.distinctCount" -> "2", + "spark.sql.statistics.colStats.cshort.max" -> "3", + "spark.sql.statistics.colStats.cshort.maxLen" -> "2", + "spark.sql.statistics.colStats.cshort.min" -> "1", + "spark.sql.statistics.colStats.cshort.nullCount" -> "1", + "spark.sql.statistics.colStats.cshort.version" -> "1", + "spark.sql.statistics.colStats.cstring.avgLen" -> "3", + "spark.sql.statistics.colStats.cstring.distinctCount" -> "2", + "spark.sql.statistics.colStats.cstring.maxLen" -> "3", + "spark.sql.statistics.colStats.cstring.nullCount" -> "1", + "spark.sql.statistics.colStats.cstring.version" -> "1", + "spark.sql.statistics.colStats.ctimestamp.avgLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.distinctCount" -> "2", + "spark.sql.statistics.colStats.ctimestamp.max" -> "2016-05-09 00:00:02.0", + "spark.sql.statistics.colStats.ctimestamp.maxLen" -> "8", + "spark.sql.statistics.colStats.ctimestamp.min" -> "2016-05-08 00:00:01.0", + "spark.sql.statistics.colStats.ctimestamp.nullCount" -> "1", + "spark.sql.statistics.colStats.ctimestamp.version" -> "1" + ) + + val expectedSerializedHistograms = Map( + "spark.sql.statistics.colStats.cbyte.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cbyte").histogram.get), + "spark.sql.statistics.colStats.cshort.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cshort").histogram.get), + "spark.sql.statistics.colStats.cint.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cint").histogram.get), + "spark.sql.statistics.colStats.clong.histogram" -> + HistogramSerializer.serialize(statsWithHgms("clong").histogram.get), + "spark.sql.statistics.colStats.cdouble.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdouble").histogram.get), + "spark.sql.statistics.colStats.cfloat.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cfloat").histogram.get), + "spark.sql.statistics.colStats.cdecimal.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdecimal").histogram.get), + "spark.sql.statistics.colStats.cdate.histogram" -> + HistogramSerializer.serialize(statsWithHgms("cdate").histogram.get), + "spark.sql.statistics.colStats.ctimestamp.histogram" -> + HistogramSerializer.serialize(statsWithHgms("ctimestamp").histogram.get) + ) + private val randomName = new Random(31) def getCatalogTable(tableName: String): CatalogTable = { @@ -151,7 +258,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils */ def checkColStats( df: DataFrame, - colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) @@ -161,14 +268,24 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils colStats.keys.mkString(", ")) // Validate statistics - val table = getCatalogTable(tableName) - assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == colStats.size) - - colStats.foreach { case (k, v) => - withClue(s"column $k") { - assert(table.stats.get.colStats(k) == v) - } + validateColStats(tableName, colStats) + } + } + + /** + * Validate if the given catalog table has the provided statistics. + */ + def validateColStats( + tableName: String, + colStats: mutable.LinkedHashMap[String, CatalogColumnStat]): Unit = { + + val table = getCatalogTable(tableName) + assert(table.stats.isDefined) + assert(table.stats.get.colStats.size == colStats.size) + + colStats.foreach { case (k, v) => + withClue(s"column $k") { + assert(table.stats.get.colStats(k) == v) } } } @@ -215,12 +332,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head - val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) + val emptyColStat = ColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) + val emptyCatalogColStat = CatalogColumnStat(Some(0), None, None, Some(0), Some(4), Some(4)) // Check catalog statistics assert(catalogTable.stats.isDefined) assert(catalogTable.stats.get.sizeInBytes == 0) assert(catalogTable.stats.get.rowCount == Some(0)) - assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) + assert(catalogTable.stats.get.colStats == Map("c1" -> emptyCatalogColStat)) // Check relation statistics withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org