sunchao commented on code in PR #43629:
URL: https://github.com/apache/spark/pull/43629#discussion_r1387500112


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala:
##########
@@ -101,56 +98,19 @@ case class AnalyzePartitionCommand(
       if (noscan) {
         Map.empty
       } else {
-        calculateRowCountsPerPartition(sparkSession, tableMeta, 
partitionValueSpec)
+        CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, 
partitionValueSpec)
       }
 
     // Update the metastore if newly computed statistics are different from 
those
     // recorded in the metastore.
-
-    val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, 
tableMeta.identifier,
-      partitions.map(_.storage.locationUri))
-    val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-      val newRowCount = rowCounts.get(p.spec)
-      val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), 
newRowCount)
-      newStats.map(_ => p.copy(stats = newStats))
-    }
-
+    val (_, newPartitions) = CommandUtils.calculatePartitionStats(
+      sparkSession, tableMeta, partitions, Some(rowCounts))
     if (newPartitions.nonEmpty) {
       sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
     }
 
     Seq.empty[Row]
   }
 
-  private def calculateRowCountsPerPartition(

Review Comment:
   It's now used in `CommandUtils`, so I moved it there and switched this class to call the qualified
`CommandUtils.calculateRowCountsPerPartition` instead.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala:
##########
@@ -86,19 +91,31 @@ object CommandUtils extends Logging {
       // Calculate table size as a sum of the visible partitions. See 
SPARK-21079
       val partitions = 
sessionState.catalog.listPartitions(catalogTable.identifier)
       logInfo(s"Starting to calculate sizes for ${partitions.length} 
partitions.")
-      val paths = partitions.map(_.storage.locationUri)
-      val sizes = calculateMultipleLocationSizes(spark, 
catalogTable.identifier, paths)
-      val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-        val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), 
None)
-        newStats.map(_ => p.copy(stats = newStats))
-      }
+      val (sizes, newPartitions) = calculatePartitionStats(spark, 
catalogTable, partitions,
+        partitionRowCount)
       (sizes.sum, newPartitions)

Review Comment:
   Yeah, I think we can use `sizes.sum` directly here and save a line.



##########
sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala:
##########
@@ -363,6 +363,85 @@ class StatisticsSuite extends StatisticsCollectionTestBase 
with TestHiveSingleto
     }
   }
 
+  test("SPARK-45731: update partition stats with ANALYZE TABLE") {
+    val tableName = "analyzeTable_part"
+
+    def queryStats(ds: String): Option[CatalogStatistics] = {
+      val partition =
+        spark.sessionState.catalog.getPartition(TableIdentifier(tableName), 
Map("ds" -> ds))
+      partition.stats
+    }
+
+    Seq(true, false).foreach { partitionStatsEnabled =>
+      withSQLConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED.key ->
+          partitionStatsEnabled.toString) {
+        withTable(tableName) {
+          withTempPath { path =>
+            // Create a table with 3 partitions all located under a directory 
'path'
+            sql(
+              s"""
+                 |CREATE TABLE $tableName (key STRING, value STRING)
+                 |USING hive
+                 |PARTITIONED BY (ds STRING)
+                 |LOCATION '${path.toURI}'
+              """.stripMargin)
+
+            val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
+
+            partitionDates.foreach { ds =>
+              sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds') LOCATION 
'$path/ds=$ds'")
+              sql("SELECT * FROM src").write.mode(SaveMode.Overwrite)
+                  .format("parquet").save(s"$path/ds=$ds")
+            }
+
+            assert(getCatalogTable(tableName).stats.isEmpty)
+            partitionDates.foreach { ds =>
+              assert(queryStats(ds).isEmpty)
+            }
+
+            sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS NOSCAN")
+
+            val expectedRowCount = 25
+
+            // Table size should also have been updated
+            assert(getTableStats(tableName).sizeInBytes > 0)
+            // Row count should NOT be updated with the `NOSCAN` option
+            assert(getTableStats(tableName).rowCount.isEmpty)
+
+            partitionDates.foreach { ds =>
+              val partStats = queryStats(ds)
+              if (partitionStatsEnabled) {
+                assert(partStats.nonEmpty)
+                assert(partStats.get.sizeInBytes > 0)
+                assert(partStats.get.rowCount.isEmpty)
+              } else {
+                assert(partStats.isEmpty)
+              }
+            }

Review Comment:
   Hmm, I actually prefer the current approach since it has less code duplication,
e.g., `partitionDates.foreach` and `val partStats = queryStats(ds)`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to