Repository: spark
Updated Branches:
  refs/heads/master 28b645b1e -> 7486442fe


[SPARK-17073][SQL][FOLLOWUP] generate column-level statistics

## What changes were proposed in this pull request?
This PR adds test cases for statistics (case-sensitive column names, non-ASCII
column names, and refreshing the stats of a cached table) and improves some documentation.

## How was this patch tested?
Added new test cases to `StatisticsSuite`.

Author: wangzhenhua <wangzhen...@huawei.com>

Closes #15360 from wzhfy/colStats2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7486442f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7486442f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7486442f

Branch: refs/heads/master
Commit: 7486442fe0b70f2aea21d569604e71d7ddf19a77
Parents: 28b645b
Author: wangzhenhua <wangzhen...@huawei.com>
Authored: Fri Oct 14 21:18:49 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Oct 14 21:18:49 2016 +0800

----------------------------------------------------------------------
 .../command/AnalyzeColumnCommand.scala          |  53 ++---
 .../org/apache/spark/sql/internal/SQLConf.scala |   3 +-
 .../apache/spark/sql/hive/StatisticsSuite.scala | 198 ++++++++++++++++---
 3 files changed, 197 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7486442f/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 7066378..4881387 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -59,10 +59,12 @@ case class AnalyzeColumnCommand(
 
     def updateStats(catalogTable: CatalogTable, newTotalSize: Long): Unit = {
       val (rowCount, columnStats) = computeColStats(sparkSession, relation)
+      // We also update table-level stats in order to keep them consistent 
with column-level stats.
       val statistics = Statistics(
         sizeInBytes = newTotalSize,
         rowCount = Some(rowCount),
-        colStats = columnStats ++ 
catalogTable.stats.map(_.colStats).getOrElse(Map()))
+        // Newly computed column stats should override the existing ones.
+        colStats = catalogTable.stats.map(_.colStats).getOrElse(Map()) ++ 
columnStats)
       sessionState.catalog.alterTable(catalogTable.copy(stats = 
Some(statistics)))
       // Refresh the cached data source table in the catalog.
       sessionState.catalog.refreshTable(tableIdentWithDB)
@@ -90,8 +92,9 @@ case class AnalyzeColumnCommand(
       }
     }
     if (duplicatedColumns.nonEmpty) {
-      logWarning(s"Duplicated columns ${duplicatedColumns.mkString("(", ", ", 
")")} detected " +
-        s"when analyzing columns ${columnNames.mkString("(", ", ", ")")}, 
ignoring them.")
+      logWarning("Duplicate column names were deduplicated in `ANALYZE TABLE` 
statement. " +
+        s"Input columns: ${columnNames.mkString("(", ", ", ")")}. " +
+        s"Duplicate columns: ${duplicatedColumns.mkString("(", ", ", ")")}.")
     }
 
     // Collect statistics per column.
@@ -116,22 +119,24 @@ case class AnalyzeColumnCommand(
 }
 
 object ColumnStatStruct {
-  val zero = Literal(0, LongType)
-  val one = Literal(1, LongType)
+  private val zero = Literal(0, LongType)
+  private val one = Literal(1, LongType)
 
-  def numNulls(e: Expression): Expression = if (e.nullable) Sum(If(IsNull(e), 
one, zero)) else zero
-  def max(e: Expression): Expression = Max(e)
-  def min(e: Expression): Expression = Min(e)
-  def ndv(e: Expression, relativeSD: Double): Expression = {
+  private def numNulls(e: Expression): Expression = {
+    if (e.nullable) Sum(If(IsNull(e), one, zero)) else zero
+  }
+  private def max(e: Expression): Expression = Max(e)
+  private def min(e: Expression): Expression = Min(e)
+  private def ndv(e: Expression, relativeSD: Double): Expression = {
     // the approximate ndv should never be larger than the number of rows
     Least(Seq(HyperLogLogPlusPlus(e, relativeSD), Count(one)))
   }
-  def avgLength(e: Expression): Expression = Average(Length(e))
-  def maxLength(e: Expression): Expression = Max(Length(e))
-  def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
-  def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
+  private def avgLength(e: Expression): Expression = Average(Length(e))
+  private def maxLength(e: Expression): Expression = Max(Length(e))
+  private def numTrues(e: Expression): Expression = Sum(If(e, one, zero))
+  private def numFalses(e: Expression): Expression = Sum(If(Not(e), one, zero))
 
-  def getStruct(exprs: Seq[Expression]): CreateStruct = {
+  private def getStruct(exprs: Seq[Expression]): CreateStruct = {
     CreateStruct(exprs.map { expr: Expression =>
       expr.transformUp {
         case af: AggregateFunction => af.toAggregateExpression()
@@ -139,19 +144,19 @@ object ColumnStatStruct {
     })
   }
 
-  def numericColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+  private def numericColumnStat(e: Expression, relativeSD: Double): 
Seq[Expression] = {
     Seq(numNulls(e), max(e), min(e), ndv(e, relativeSD))
   }
 
-  def stringColumnStat(e: Expression, relativeSD: Double): Seq[Expression] = {
+  private def stringColumnStat(e: Expression, relativeSD: Double): 
Seq[Expression] = {
     Seq(numNulls(e), avgLength(e), maxLength(e), ndv(e, relativeSD))
   }
 
-  def binaryColumnStat(e: Expression): Seq[Expression] = {
+  private def binaryColumnStat(e: Expression): Seq[Expression] = {
     Seq(numNulls(e), avgLength(e), maxLength(e))
   }
 
-  def booleanColumnStat(e: Expression): Seq[Expression] = {
+  private def booleanColumnStat(e: Expression): Seq[Expression] = {
     Seq(numNulls(e), numTrues(e), numFalses(e))
   }
 
@@ -162,14 +167,14 @@ object ColumnStatStruct {
     }
   }
 
-  def apply(e: Attribute, relativeSD: Double): CreateStruct = e.dataType match 
{
+  def apply(attr: Attribute, relativeSD: Double): CreateStruct = attr.dataType 
match {
     // Use aggregate functions to compute statistics we need.
-    case _: NumericType | TimestampType | DateType => 
getStruct(numericColumnStat(e, relativeSD))
-    case StringType => getStruct(stringColumnStat(e, relativeSD))
-    case BinaryType => getStruct(binaryColumnStat(e))
-    case BooleanType => getStruct(booleanColumnStat(e))
+    case _: NumericType | TimestampType | DateType => 
getStruct(numericColumnStat(attr, relativeSD))
+    case StringType => getStruct(stringColumnStat(attr, relativeSD))
+    case BinaryType => getStruct(binaryColumnStat(attr))
+    case BooleanType => getStruct(booleanColumnStat(attr))
     case otherType =>
       throw new AnalysisException("Analyzing columns is not supported for 
column " +
-        s"${e.name} of data type: ${e.dataType}.")
+        s"${attr.name} of data type: ${attr.dataType}.")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7486442f/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e671604..c844765 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -578,7 +578,8 @@ object SQLConf {
   val NDV_MAX_ERROR =
     SQLConfigBuilder("spark.sql.statistics.ndv.maxError")
       .internal()
-      .doc("The maximum estimation error allowed in HyperLogLog++ algorithm.")
+      .doc("The maximum estimation error allowed in HyperLogLog++ algorithm 
when generating " +
+        "column level statistics.")
       .doubleConf
       .createWithDefault(0.05)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7486442f/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 85228bb..c351063 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,7 @@ import java.io.{File, PrintWriter}
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.{AnalysisException, QueryTest, Row, StatisticsTest}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics}
 import org.apache.spark.sql.execution.command.{AnalyzeTableCommand, DDLUtils}
@@ -358,53 +358,187 @@ class StatisticsSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils
     }
   }
 
-  test("generate column-level statistics and load them from hive metastore") {
+  private def getStatsBeforeAfterUpdate(isAnalyzeColumns: Boolean): 
(Statistics, Statistics) = {
+    val tableName = "tbl"
+    var statsBeforeUpdate: Statistics = null
+    var statsAfterUpdate: Statistics = null
+    withTable(tableName) {
+      val tableIndent = TableIdentifier(tableName, Some("default"))
+      val catalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+      sql(s"CREATE TABLE $tableName (key int) USING PARQUET")
+      sql(s"INSERT INTO $tableName SELECT 1")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      // Table lookup will make the table cached.
+      catalog.lookupRelation(tableIndent)
+      statsBeforeUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+
+      sql(s"INSERT INTO $tableName SELECT 2")
+      if (isAnalyzeColumns) {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS key")
+      } else {
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+      }
+      catalog.lookupRelation(tableIndent)
+      statsAfterUpdate = catalog.getCachedDataSourceTable(tableIndent)
+        .asInstanceOf[LogicalRelation].catalogTable.get.stats.get
+    }
+    (statsBeforeUpdate, statsAfterUpdate)
+  }
+
+  test("test refreshing table stats of cached data source table by `ANALYZE 
TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = 
getStatsBeforeAfterUpdate(isAnalyzeColumns = false)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount == Some(1))
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount == Some(2))
+  }
+
+  test("test refreshing column stats of cached data source table by `ANALYZE 
TABLE` statement") {
+    val (statsBeforeUpdate, statsAfterUpdate) = 
getStatsBeforeAfterUpdate(isAnalyzeColumns = true)
+
+    assert(statsBeforeUpdate.sizeInBytes > 0)
+    assert(statsBeforeUpdate.rowCount == Some(1))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsBeforeUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+
+    assert(statsAfterUpdate.sizeInBytes > statsBeforeUpdate.sizeInBytes)
+    assert(statsAfterUpdate.rowCount == Some(2))
+    StatisticsTest.checkColStat(
+      dataType = IntegerType,
+      colStat = statsAfterUpdate.colStats("key"),
+      expectedColStat = ColumnStat(InternalRow(0L, 2, 1, 2L)),
+      rsd = spark.sessionState.conf.ndvMaxError)
+  }
+
+  private lazy val (testDataFrame, expectedColStatsSeq) = {
     import testImplicits._
 
     val intSeq = Seq(1, 2)
     val stringSeq = Seq("a", "bb")
+    val binarySeq = Seq("a", "bb").map(_.getBytes)
     val booleanSeq = Seq(true, false)
-
     val data = intSeq.indices.map { i =>
-      (intSeq(i), stringSeq(i), booleanSeq(i))
+      (intSeq(i), stringSeq(i), binarySeq(i), booleanSeq(i))
     }
-    val tableName = "table"
-    withTable(tableName) {
-      val df = data.toDF("c1", "c2", "c3")
-      df.write.format("parquet").saveAsTable(tableName)
-      val expectedColStatsSeq = df.schema.map { f =>
-        val colStat = f.dataType match {
-          case IntegerType =>
-            ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, 
intSeq.distinct.length.toLong))
-          case StringType =>
-            ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / 
stringSeq.length.toDouble,
-              stringSeq.map(_.length).max.toInt, 
stringSeq.distinct.length.toLong))
-          case BooleanType =>
-            ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
-              booleanSeq.count(_.equals(false)).toLong))
-        }
-        (f, colStat)
+    val df: DataFrame = data.toDF("c1", "c2", "c3", "c4")
+    val expectedColStatsSeq: Seq[(StructField, ColumnStat)] = df.schema.map { 
f =>
+      val colStat = f.dataType match {
+        case IntegerType =>
+          ColumnStat(InternalRow(0L, intSeq.max, intSeq.min, 
intSeq.distinct.length.toLong))
+        case StringType =>
+          ColumnStat(InternalRow(0L, stringSeq.map(_.length).sum / 
stringSeq.length.toDouble,
+            stringSeq.map(_.length).max.toInt, 
stringSeq.distinct.length.toLong))
+        case BinaryType =>
+          ColumnStat(InternalRow(0L, binarySeq.map(_.length).sum / 
binarySeq.length.toDouble,
+            binarySeq.map(_.length).max.toInt))
+        case BooleanType =>
+          ColumnStat(InternalRow(0L, booleanSeq.count(_.equals(true)).toLong,
+            booleanSeq.count(_.equals(false)).toLong))
       }
+      (f, colStat)
+    }
+    (df, expectedColStatsSeq)
+  }
+
+  private def checkColStats(
+      tableName: String,
+      isDataSourceTable: Boolean,
+      expectedColStatsSeq: Seq[(StructField, ColumnStat)]): Unit = {
+    val readback = spark.table(tableName)
+    val stats = readback.queryExecution.analyzed.collect {
+      case rel: MetastoreRelation =>
+        assert(!isDataSourceTable, "Expected a Hive serde table, but got a 
data source table")
+        rel.catalogTable.stats.get
+      case rel: LogicalRelation =>
+        assert(isDataSourceTable, "Expected a data source table, but got a 
Hive serde table")
+        rel.catalogTable.get.stats.get
+    }
+    assert(stats.length == 1)
+    val columnStats = stats.head.colStats
+    assert(columnStats.size == expectedColStatsSeq.length)
+    expectedColStatsSeq.foreach { case (field, expectedColStat) =>
+      StatisticsTest.checkColStat(
+        dataType = field.dataType,
+        colStat = columnStats(field.name),
+        expectedColStat = expectedColStat,
+        rsd = spark.sessionState.conf.ndvMaxError)
+    }
+  }
 
-      sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1, c2, 
c3")
-      val readback = spark.table(tableName)
-      val relations = readback.queryExecution.analyzed.collect { case rel: 
LogicalRelation =>
-        val columnStats = rel.catalogTable.get.stats.get.colStats
-        expectedColStatsSeq.foreach { case (field, expectedColStat) =>
-          assert(columnStats.contains(field.name))
-          val colStat = columnStats(field.name)
+  test("generate and load column-level stats for data source table") {
+    val dsTable = "dsTable"
+    withTable(dsTable) {
+      testDataFrame.write.format("parquet").saveAsTable(dsTable)
+      sql(s"ANALYZE TABLE $dsTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, 
c4")
+      checkColStats(dsTable, isDataSourceTable = true, expectedColStatsSeq)
+    }
+  }
+
+  test("generate and load column-level stats for hive serde table") {
+    val hTable = "hTable"
+    val tmp = "tmp"
+    withTable(hTable, tmp) {
+      testDataFrame.write.format("parquet").saveAsTable(tmp)
+      sql(s"CREATE TABLE $hTable (c1 int, c2 string, c3 binary, c4 boolean) 
STORED AS TEXTFILE")
+      sql(s"INSERT INTO $hTable SELECT * FROM $tmp")
+      sql(s"ANALYZE TABLE $hTable COMPUTE STATISTICS FOR COLUMNS c1, c2, c3, 
c4")
+      checkColStats(hTable, isDataSourceTable = false, expectedColStatsSeq)
+    }
+  }
+
+  // When caseSensitive is on, for columns with only case difference, they are 
different columns
+  // and we should generate column stats for all of them.
+  private def checkCaseSensitiveColStats(columnName: String): Unit = {
+    val tableName = "tbl"
+    withTable(tableName) {
+      val column1 = columnName.toLowerCase
+      val column2 = columnName.toUpperCase
+      withSQLConf("spark.sql.caseSensitive" -> "true") {
+        sql(s"CREATE TABLE $tableName (`$column1` int, `$column2` double) 
USING PARQUET")
+        sql(s"INSERT INTO $tableName SELECT 1, 3.0")
+        sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS 
`$column1`, `$column2`")
+        val readback = spark.table(tableName)
+        val relations = readback.queryExecution.analyzed.collect { case rel: 
LogicalRelation =>
+          val columnStats = rel.catalogTable.get.stats.get.colStats
+          assert(columnStats.size == 2)
+          StatisticsTest.checkColStat(
+            dataType = IntegerType,
+            colStat = columnStats(column1),
+            expectedColStat = ColumnStat(InternalRow(0L, 1, 1, 1L)),
+            rsd = spark.sessionState.conf.ndvMaxError)
           StatisticsTest.checkColStat(
-            dataType = field.dataType,
-            colStat = colStat,
-            expectedColStat = expectedColStat,
+            dataType = DoubleType,
+            colStat = columnStats(column2),
+            expectedColStat = ColumnStat(InternalRow(0L, 3.0d, 3.0d, 1L)),
             rsd = spark.sessionState.conf.ndvMaxError)
+          rel
         }
-        rel
+        assert(relations.size == 1)
       }
-      assert(relations.size == 1)
     }
   }
 
+  test("check column statistics for case sensitive column names") {
+    checkCaseSensitiveColStats(columnName = "c1")
+  }
+
+  test("check column statistics for case sensitive non-ascii column names") {
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable 
the scalastyle.
+    checkCaseSensitiveColStats(columnName = "列c")
+    // scalastyle:on
+  }
+
   test("estimates the size of a test MetastoreRelation") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect { case mr: 
MetastoreRelation =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to