sunchao commented on code in PR #43629:
URL: https://github.com/apache/spark/pull/43629#discussion_r1387500112
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala:
##########
@@ -101,56 +98,19 @@ case class AnalyzePartitionCommand(
if (noscan) {
Map.empty
} else {
- calculateRowCountsPerPartition(sparkSession, tableMeta,
partitionValueSpec)
+ CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta,
partitionValueSpec)
}
// Update the metastore if newly computed statistics are different from
those
// recorded in the metastore.
-
- val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession,
tableMeta.identifier,
- partitions.map(_.storage.locationUri))
- val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
- val newRowCount = rowCounts.get(p.spec)
- val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx),
newRowCount)
- newStats.map(_ => p.copy(stats = newStats))
- }
-
+ val (_, newPartitions) = CommandUtils.calculatePartitionStats(
+ sparkSession, tableMeta, partitions, Some(rowCounts))
if (newPartitions.nonEmpty) {
sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
}
Seq.empty[Row]
}
- private def calculateRowCountsPerPartition(
Review Comment:
It's now used in `CommandUtils`, so I'm moving it there and switching to the qualified
`CommandUtils.calculateRowCountsPerPartition` in this class.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala:
##########
@@ -86,19 +91,31 @@ object CommandUtils extends Logging {
// Calculate table size as a sum of the visible partitions. See
SPARK-21079
val partitions =
sessionState.catalog.listPartitions(catalogTable.identifier)
logInfo(s"Starting to calculate sizes for ${partitions.length}
partitions.")
- val paths = partitions.map(_.storage.locationUri)
- val sizes = calculateMultipleLocationSizes(spark,
catalogTable.identifier, paths)
- val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
- val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx),
None)
- newStats.map(_ => p.copy(stats = newStats))
- }
+ val (sizes, newPartitions) = calculatePartitionStats(spark,
catalogTable, partitions,
+ partitionRowCount)
(sizes.sum, newPartitions)
Review Comment:
Yeah, I think we can use `sizes.sum` and save a line here.
##########
sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala:
##########
@@ -363,6 +363,85 @@ class StatisticsSuite extends StatisticsCollectionTestBase
with TestHiveSingleto
}
}
+ test("SPARK-45731: update partition stats with ANALYZE TABLE") {
+ val tableName = "analyzeTable_part"
+
+ def queryStats(ds: String): Option[CatalogStatistics] = {
+ val partition =
+ spark.sessionState.catalog.getPartition(TableIdentifier(tableName),
Map("ds" -> ds))
+ partition.stats
+ }
+
+ Seq(true, false).foreach { partitionStatsEnabled =>
+ withSQLConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED.key ->
+ partitionStatsEnabled.toString) {
+ withTable(tableName) {
+ withTempPath { path =>
+ // Create a table with 3 partitions all located under a directory
'path'
+ sql(
+ s"""
+ |CREATE TABLE $tableName (key STRING, value STRING)
+ |USING hive
+ |PARTITIONED BY (ds STRING)
+ |LOCATION '${path.toURI}'
+ """.stripMargin)
+
+ val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
+
+ partitionDates.foreach { ds =>
+ sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds') LOCATION
'$path/ds=$ds'")
+ sql("SELECT * FROM src").write.mode(SaveMode.Overwrite)
+ .format("parquet").save(s"$path/ds=$ds")
+ }
+
+ assert(getCatalogTable(tableName).stats.isEmpty)
+ partitionDates.foreach { ds =>
+ assert(queryStats(ds).isEmpty)
+ }
+
+ sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS NOSCAN")
+
+ val expectedRowCount = 25
+
+ // Table size should also have been updated
+ assert(getTableStats(tableName).sizeInBytes > 0)
+ // Row count should NOT be updated with the `NOSCAN` option
+ assert(getTableStats(tableName).rowCount.isEmpty)
+
+ partitionDates.foreach { ds =>
+ val partStats = queryStats(ds)
+ if (partitionStatsEnabled) {
+ assert(partStats.nonEmpty)
+ assert(partStats.get.sizeInBytes > 0)
+ assert(partStats.get.rowCount.isEmpty)
+ } else {
+ assert(partStats.isEmpty)
+ }
+ }
Review Comment:
Hmm, I actually like the current way better since it has less duplicated
code, e.g., `partitionDates.foreach` and `val partStats = queryStats(ds)`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]