This is an automated email from the ASF dual-hosted git repository.

codope pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 17279589c16 [HUDI-6649] Fix column stat based data filtering for MOR (#9345)
17279589c16 is described below

commit 17279589c1678a750b09b5e29cda7ce57178d57b
Author: Lokesh Jain <[email protected]>
AuthorDate: Sun Aug 6 22:16:30 2023 +0530

    [HUDI-6649] Fix column stat based data filtering for MOR (#9345)
    
    Currently the MOR snapshot relation does not use the column stats index to
    prune the files in its queries. This PR adds support for pruning file
    slices based on column stats for MOR tables.
    
    The approach is similar to the one used in
    org.apache.hudi.HoodieFileIndex#listFiles. The logic for filtering the file
    slices of each partition path has been moved to the function
    org.apache.hudi.HoodieFileIndex#filterFileSlices, which the various
    relations call to prune their file slices.
    
    ---------
    
    Co-authored-by: Sagar Sumit <[email protected]>
---
 .../scala/org/apache/hudi/HoodieBaseRelation.scala |  12 +-
 .../scala/org/apache/hudi/HoodieFileIndex.scala    | 196 ++++++----
 .../apache/hudi/MergeOnReadSnapshotRelation.scala  |   8 +-
 .../org/apache/hudi/TestHoodieFileIndex.scala      |   4 +-
 .../hudi/functional/ColumnStatIndexTestBase.scala  | 283 +++++++++++++++
 .../hudi/functional/TestColumnStatsIndex.scala     | 241 +------------
 .../functional/TestColumnStatsIndexWithSQL.scala   | 398 +++++++++++++++++++++
 .../hudi/procedure/TestClusteringProcedure.scala   |   4 +-
 8 files changed, 840 insertions(+), 306 deletions(-)

diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
index a67d4463bf5..2f9579d629e 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala
@@ -254,7 +254,7 @@ abstract class HoodieBaseRelation(val sqlContext: 
SQLContext,
    */
   lazy val fileIndex: HoodieFileIndex =
     HoodieFileIndex(sparkSession, metaClient, Some(tableStructSchema), 
optParams,
-      FileStatusCache.getOrCreate(sparkSession))
+      FileStatusCache.getOrCreate(sparkSession), shouldIncludeLogFiles())
 
   lazy val tableState: HoodieTableState = {
     val recordMergerImpls = 
ConfigUtils.split2List(getConfigValue(HoodieWriteConfig.RECORD_MERGER_IMPLS)).asScala.toList
@@ -343,7 +343,7 @@ abstract class HoodieBaseRelation(val sqlContext: 
SQLContext,
    */
   override final def needConversion: Boolean = false
 
-  override def inputFiles: Array[String] = 
fileIndex.allFiles.map(_.getPath.toUri.toString).toArray
+  override def inputFiles: Array[String] = 
fileIndex.allBaseFiles.map(_.getPath.toUri.toString).toArray
 
   /**
    * NOTE: DO NOT OVERRIDE THIS METHOD
@@ -644,6 +644,14 @@ abstract class HoodieBaseRelation(val sqlContext: 
SQLContext,
     optParams.getOrElse(config.key(),
       sqlContext.getConf(config.key(), 
defaultValueOption.getOrElse(config.defaultValue())))
   }
+
+  /**
+   * Determines if fileIndex should consider log files when filtering file 
slices. Defaults to false.
+   * The subclass can have their own implementation based on the table or 
relation type.
+   */
+  protected def shouldIncludeLogFiles(): Boolean = {
+    false
+  }
 }
 
 object HoodieBaseRelation extends SparkAdapterSupport {
diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
index 66c4ae13b38..9791d39e280 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala
@@ -22,12 +22,13 @@ import 
org.apache.hudi.HoodieFileIndex.{DataSkippingFailureMode, collectReferenc
 import org.apache.hudi.HoodieSparkConfUtils.getConfigValue
 import 
org.apache.hudi.common.config.TimestampKeyGeneratorConfig.{TIMESTAMP_INPUT_DATE_FORMAT,
 TIMESTAMP_OUTPUT_DATE_FORMAT}
 import org.apache.hudi.common.config.{HoodieMetadataConfig, TypedProperties}
-import org.apache.hudi.common.model.HoodieBaseFile
+import org.apache.hudi.common.model.{FileSlice, HoodieBaseFile, HoodieLogFile}
 import org.apache.hudi.common.table.HoodieTableMetaClient
 import org.apache.hudi.common.util.StringUtils
 import org.apache.hudi.exception.HoodieException
 import org.apache.hudi.keygen.{TimestampBasedAvroKeyGenerator, 
TimestampBasedKeyGenerator}
 import org.apache.hudi.metadata.HoodieMetadataPayload
+import org.apache.hudi.util.JFunction
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, Literal}
@@ -40,9 +41,9 @@ import org.apache.spark.sql.{Column, SparkSession}
 import org.apache.spark.unsafe.types.UTF8String
 
 import java.text.SimpleDateFormat
+import java.util.stream.Collectors
 import javax.annotation.concurrent.NotThreadSafe
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.util.control.NonFatal
 import scala.util.{Failure, Success, Try}
 
@@ -74,7 +75,8 @@ case class HoodieFileIndex(spark: SparkSession,
                            metaClient: HoodieTableMetaClient,
                            schemaSpec: Option[StructType],
                            options: Map[String, String],
-                           @transient fileStatusCache: FileStatusCache = 
NoopCache)
+                           @transient fileStatusCache: FileStatusCache = 
NoopCache,
+                           includeLogFiles: Boolean = false)
   extends SparkHoodieTableFileIndex(
     spark = spark,
     metaClient = metaClient,
@@ -106,7 +108,7 @@ case class HoodieFileIndex(spark: SparkSession,
    *
    * @return List of FileStatus for base files
    */
-  def allFiles: Seq[FileStatus] = {
+  def allBaseFiles: Seq[FileStatus] = {
     getAllInputFileSlices.values.asScala.flatMap(_.asScala)
       .map(fs => fs.getBaseFile.orElse(null))
       .filter(_ != null)
@@ -114,6 +116,22 @@ case class HoodieFileIndex(spark: SparkSession,
       .toSeq
   }
 
+  /**
+   * Returns the FileStatus for all the base files and log files.
+   *
+   * @return List of FileStatus for base files and log files
+   */
+  private def allBaseFilesAndLogFiles: Seq[FileStatus] = {
+    getAllInputFileSlices.values.asScala.flatMap(_.asScala)
+      .flatMap(fs => {
+        val baseFileStatusOpt = 
getBaseFileStatus(Option.apply(fs.getBaseFile.orElse(null)))
+        val logFilesStatus = 
fs.getLogFiles.map[FileStatus](JFunction.toJavaFunction[HoodieLogFile, 
FileStatus](lf => lf.getFileStatus))
+        val files = 
logFilesStatus.collect(Collectors.toList[FileStatus]).asScala
+        baseFileStatusOpt.foreach(f => files.append(f))
+        files
+      }).toSeq
+  }
+
   /**
    * Invoked by Spark to fetch list of latest base files per partition.
    *
@@ -122,11 +140,55 @@ case class HoodieFileIndex(spark: SparkSession,
    * @return list of PartitionDirectory containing partition to base files 
mapping
    */
   override def listFiles(partitionFilters: Seq[Expression], dataFilters: 
Seq[Expression]): Seq[PartitionDirectory] = {
-    // Look up candidate files names in the col-stats index, if all of the 
following conditions are true
-    //    - Data-skipping is enabled
-    //    - Col-Stats Index is present
-    //    - List of predicates (filters) is present
-    val candidateFilesNamesOpt: Option[Set[String]] =
+    val prunedPartitionsAndFilteredFileSlices = filterFileSlices(dataFilters, 
partitionFilters).map {
+      case (partitionOpt, fileSlices) =>
+        val allCandidateFiles: Seq[FileStatus] = fileSlices.flatMap(fs => {
+          val baseFileStatusOpt = 
getBaseFileStatus(Option.apply(fs.getBaseFile.orElse(null)))
+          val logFilesStatus = if (includeLogFiles) {
+            
fs.getLogFiles.map[FileStatus](JFunction.toJavaFunction[HoodieLogFile, 
FileStatus](lf => lf.getFileStatus))
+          } else {
+            java.util.stream.Stream.empty()
+          }
+          val files = 
logFilesStatus.collect(Collectors.toList[FileStatus]).asScala
+          baseFileStatusOpt.foreach(f => files.append(f))
+          files
+        })
+
+        PartitionDirectory(InternalRow.fromSeq(partitionOpt.get.values), 
allCandidateFiles)
+    }
+
+    hasPushedDownPartitionPredicates = true
+
+    if (shouldReadAsPartitionedTable()) {
+      prunedPartitionsAndFilteredFileSlices
+    } else {
+      Seq(PartitionDirectory(InternalRow.empty, 
prunedPartitionsAndFilteredFileSlices.flatMap(_.files)))
+    }
+  }
+
+  /**
+   * The functions prunes the partition paths based on the input partition 
filters. For every partition path, the file
+   * slices are further filtered after querying metadata table based on the 
data filters.
+   *
+   * @param dataFilters data columns filters
+   * @param partitionFilters partition column filters
+   * @return A sequence of pruned partitions and corresponding filtered file 
slices
+   */
+  def filterFileSlices(dataFilters: Seq[Expression], partitionFilters: 
Seq[Expression])
+  : Seq[(Option[BaseHoodieTableFileIndex.PartitionPath], Seq[FileSlice])] = {
+
+    val prunedPartitionsAndFileSlices = 
getFileSlicesForPrunedPartitions(partitionFilters)
+
+    // If there are no data filters, return all the file slices.
+    // If there are no file slices, return empty list.
+    if (prunedPartitionsAndFileSlices.isEmpty || dataFilters.isEmpty) {
+      prunedPartitionsAndFileSlices
+    } else {
+      // Look up candidate files names in the col-stats index, if all of the 
following conditions are true
+      //    - Data-skipping is enabled
+      //    - Col-Stats Index is present
+      //    - List of predicates (filters) is present
+      val candidateFilesNamesOpt: Option[Set[String]] =
       lookupCandidateFilesInMetadataTable(dataFilters) match {
         case Success(opt) => opt
         case Failure(e) =>
@@ -134,74 +196,80 @@ case class HoodieFileIndex(spark: SparkSession,
 
           spark.sqlContext.getConf(DataSkippingFailureMode.configName, 
DataSkippingFailureMode.Fallback.value) match {
             case DataSkippingFailureMode.Fallback.value => Option.empty
-            case DataSkippingFailureMode.Strict.value   => throw new 
HoodieException(e);
+            case DataSkippingFailureMode.Strict.value => throw new 
HoodieException(e);
           }
       }
 
-    logDebug(s"Overlapping candidate files from Column Stats Index: 
${candidateFilesNamesOpt.getOrElse(Set.empty)}")
-
-    var totalFileSize = 0
-    var candidateFileSize = 0
+      logDebug(s"Overlapping candidate files from Column Stats Index: 
${candidateFilesNamesOpt.getOrElse(Set.empty)}")
+
+      var totalFileSliceSize = 0
+      var candidateFileSliceSize = 0
+
+      val prunedPartitionsAndFilteredFileSlices = 
prunedPartitionsAndFileSlices.map {
+        case (partitionOpt, fileSlices) =>
+          // Filter in candidate files based on the col-stats index lookup
+          val candidateFileSlices: Seq[FileSlice] = {
+            fileSlices.filter(fs => {
+              val fileSliceFiles = 
fs.getLogFiles.map[String](JFunction.toJavaFunction[HoodieLogFile, String](lf 
=> lf.getPath.getName))
+                .collect(Collectors.toSet[String])
+              val baseFileStatusOpt = 
getBaseFileStatus(Option.apply(fs.getBaseFile.orElse(null)))
+              baseFileStatusOpt.exists(f => 
fileSliceFiles.add(f.getPath.getName))
+              // NOTE: This predicate is true when {@code Option} is empty
+              candidateFilesNamesOpt.forall(files => files.exists(elem => 
fileSliceFiles.contains(elem)))
+            })
+          }
 
-    // Prune the partition path by the partition filters
-    // NOTE: Non-partitioned tables are assumed to consist from a single 
partition
-    //       encompassing the whole table
-    val prunedPartitions = listMatchingPartitionPaths(partitionFilters)
-    val listedPartitions = getInputFileSlices(prunedPartitions: 
_*).asScala.toSeq.map {
-      case (partition, fileSlices) =>
-        val baseFileStatuses: Seq[FileStatus] = getBaseFileStatus(fileSlices
-          .asScala
-          .map(fs => fs.getBaseFile.orElse(null))
-          .filter(_ != null))
-
-        // Filter in candidate files based on the col-stats index lookup
-        val candidateFiles = baseFileStatuses.filter(fs =>
-          // NOTE: This predicate is true when {@code Option} is empty
-          candidateFilesNamesOpt.forall(_.contains(fs.getPath.getName)))
-
-        totalFileSize += baseFileStatuses.size
-        candidateFileSize += candidateFiles.size
-        PartitionDirectory(InternalRow.fromSeq(partition.values), 
candidateFiles)
-    }
+          totalFileSliceSize += fileSlices.size
+          candidateFileSliceSize += candidateFileSlices.size
+          (partitionOpt, candidateFileSlices)
+      }
 
-    val skippingRatio =
-      if (!areAllFileSlicesCached) -1
-      else if (allFiles.nonEmpty && totalFileSize > 0) (totalFileSize - 
candidateFileSize) / totalFileSize.toDouble
-      else 0
+      val skippingRatio =
+        if (!areAllFileSlicesCached) -1
+        else if (getAllFiles().nonEmpty && totalFileSliceSize > 0)
+          (totalFileSliceSize - candidateFileSliceSize) / 
totalFileSliceSize.toDouble
+        else 0
 
-    logInfo(s"Total base files: $totalFileSize; " +
-      s"candidate files after data skipping: $candidateFileSize; " +
-      s"skipping percentage $skippingRatio")
+      logInfo(s"Total file slices: $totalFileSliceSize; " +
+        s"candidate file slices after data skipping: $candidateFileSliceSize; 
" +
+        s"skipping percentage $skippingRatio")
 
-    hasPushedDownPartitionPredicates = true
+      hasPushedDownPartitionPredicates = true
 
-    if (shouldReadAsPartitionedTable()) {
-      listedPartitions
-    } else {
-      Seq(PartitionDirectory(InternalRow.empty, 
listedPartitions.flatMap(_.files)))
+      prunedPartitionsAndFilteredFileSlices
     }
   }
 
+  def getFileSlicesForPrunedPartitions(partitionFilters: Seq[Expression]) : 
Seq[(Option[BaseHoodieTableFileIndex.PartitionPath], Seq[FileSlice])] = {
+    // Prune the partition path by the partition filters
+    // NOTE: Non-partitioned tables are assumed to consist from a single 
partition
+    //       encompassing the whole table
+    val prunedPartitions = listMatchingPartitionPaths (partitionFilters)
+    getInputFileSlices(prunedPartitions: _*).asScala.toSeq.map(
+      { case (partition, fileSlices) => (Option.apply(partition), 
fileSlices.asScala) })
+  }
+
   /**
-   * In the fast bootstrap read code path, it gets the file status for the 
bootstrap base files instead of
-   * skeleton files.
+   * In the fast bootstrap read code path, it gets the file status for the 
bootstrap base file instead of
+   * skeleton file. Returns file status for the base file if available.
    */
-  private def getBaseFileStatus(baseFiles: mutable.Buffer[HoodieBaseFile]): 
mutable.Buffer[FileStatus] = {
-    if (shouldFastBootstrap) {
-     baseFiles.map(f =>
-        if (f.getBootstrapBaseFile.isPresent) {
-         f.getBootstrapBaseFile.get().getFileStatus
+  private def getBaseFileStatus(baseFileOpt: Option[HoodieBaseFile]): 
Option[FileStatus] = {
+    baseFileOpt.map(baseFile => {
+      if (shouldFastBootstrap) {
+        if (baseFile.getBootstrapBaseFile.isPresent) {
+          baseFile.getBootstrapBaseFile.get().getFileStatus
         } else {
-          f.getFileStatus
-        })
-    } else {
-      baseFiles.map(_.getFileStatus)
-    }
+          baseFile.getFileStatus
+        }
+      } else {
+        baseFile.getFileStatus
+      }
+    })
   }
 
   private def lookupFileNamesMissingFromIndex(allIndexedFileNames: 
Set[String]) = {
-    val allBaseFileNames = allFiles.map(f => f.getPath.getName).toSet
-    allBaseFileNames -- allIndexedFileNames
+    val allFileNames = getAllFiles().map(f => f.getPath.getName).toSet
+    allFileNames -- allIndexedFileNames
   }
 
   /**
@@ -229,7 +297,7 @@ case class HoodieFileIndex(spark: SparkSession,
       validateConfig()
       Option.empty
     } else if (recordLevelIndex.isIndexApplicable(queryFilters)) {
-      Option.apply(recordLevelIndex.getCandidateFiles(allFiles, queryFilters))
+      Option.apply(recordLevelIndex.getCandidateFiles(getAllFiles(), 
queryFilters))
     } else if (!columnStatsIndex.isIndexAvailable || queryFilters.isEmpty || 
queryReferencedColumns.isEmpty) {
       validateConfig()
       Option.empty
@@ -281,8 +349,12 @@ case class HoodieFileIndex(spark: SparkSession,
     hasPushedDownPartitionPredicates = false
   }
 
+  private def getAllFiles(): Seq[FileStatus] = {
+    if (includeLogFiles) allBaseFilesAndLogFiles else allBaseFiles
+  }
+
   override def inputFiles: Array[String] =
-    allFiles.map(_.getPath.toString).toArray
+    getAllFiles().map(_.getPath.toString).toArray
 
   override def sizeInBytes: Long = getTotalCachedFilesSize
 
diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
index e8468f0a7a1..8e35a9a8665 100644
--- 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
+++ 
b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/MergeOnReadSnapshotRelation.scala
@@ -52,6 +52,10 @@ case class MergeOnReadSnapshotRelation(override val 
sqlContext: SQLContext,
   override def updatePrunedDataSchema(prunedSchema: StructType): 
MergeOnReadSnapshotRelation =
     this.copy(prunedDataSchema = Some(prunedSchema))
 
+  override protected def shouldIncludeLogFiles(): Boolean = {
+    true
+  }
+
 }
 
 /**
@@ -215,8 +219,8 @@ abstract class BaseMergeOnReadSnapshotRelation(sqlContext: 
SQLContext,
       HoodieFileIndex.convertFilterForTimestampKeyGenerator(metaClient, 
partitionFilters)
 
     if (globPaths.isEmpty) {
-      val fileSlices = fileIndex.listFileSlices(convertedPartitionFilters)
-      buildSplits(fileSlices.values.flatten.toSeq)
+      val fileSlices = fileIndex.filterFileSlices(dataFilters, 
convertedPartitionFilters).flatMap(s => s._2)
+      buildSplits(fileSlices)
     } else {
       val fileSlices = listLatestFileSlices(globPaths, partitionFilters, 
dataFilters)
       buildSplits(fileSlices)
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index 157f4fea854..cd9dbc8df79 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -479,7 +479,7 @@ class TestHoodieFileIndex extends HoodieSparkClientTestBase 
with ScalaAssertionS
     val expectedListedFiles = if (enablePartitionPathPrefixAnalysis) {
       getFileCountInPartitionPaths("2021/03/01/0", "2021/03/01/1", 
"2021/03/01/2")
     } else {
-      fileIndex.allFiles.length
+      fileIndex.allBaseFiles.length
     }
 
     assertEquals(expectedListedFiles, perPartitionFilesSeq.map(_.size).sum)
@@ -592,7 +592,7 @@ class TestHoodieFileIndex extends HoodieSparkClientTestBase 
with ScalaAssertionS
       val expectedPartitionPaths = if (testCase._3) {
         testCase._4.map(e => e._1 + "/" + e._2)
       } else {
-        fileIndex.allFiles
+        fileIndex.allBaseFiles
           .map(file => extractPartitionPathFromFilePath(file.getPath))
           .distinct
           .sorted
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/ColumnStatIndexTestBase.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/ColumnStatIndexTestBase.scala
new file mode 100644
index 00000000000..6a9efb3371d
--- /dev/null
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/ColumnStatIndexTestBase.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hadoop.fs.{LocatedFileStatus, Path}
+import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
+import org.apache.hudi.HoodieConversionUtils.toProperties
+import org.apache.hudi.common.config.{HoodieMetadataConfig, 
HoodieStorageConfig}
+import org.apache.hudi.common.model.HoodieTableType
+import org.apache.hudi.common.table.HoodieTableMetaClient
+import org.apache.hudi.functional.ColumnStatIndexTestBase.ColumnStatsTestCase
+import org.apache.hudi.testutils.HoodieSparkClientTestBase
+import org.apache.hudi.{ColumnStatsIndexSupport, DataSourceWriteOptions}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions.typedLit
+import org.apache.spark.sql.types._
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api._
+import org.junit.jupiter.params.provider.Arguments
+
+import java.math.BigInteger
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+@Tag("functional")
+class ColumnStatIndexTestBase extends HoodieSparkClientTestBase {
+  var spark: SparkSession = _
+  var dfList: Seq[DataFrame] = Seq()
+
+  val sourceTableSchema =
+    new StructType()
+      .add("c1", IntegerType)
+      .add("c2", StringType)
+      .add("c3", DecimalType(9, 3))
+      .add("c4", TimestampType)
+      .add("c5", ShortType)
+      .add("c6", DateType)
+      .add("c7", BinaryType)
+      .add("c8", ByteType)
+
+  @BeforeEach
+  override def setUp() {
+    initPath()
+    initSparkContexts()
+    initFileSystem()
+
+    setTableName("hoodie_test")
+    initMetaClient()
+
+    spark = sqlContext.sparkSession
+  }
+
+  @AfterEach
+  override def tearDown() = {
+    cleanupFileSystem()
+    cleanupSparkContexts()
+  }
+
+  protected def doWriteAndValidateColumnStats(testCase: ColumnStatsTestCase,
+                                            metadataOpts: Map[String, String],
+                                            hudiOpts: Map[String, String],
+                                            dataSourcePath: String,
+                                            expectedColStatsSourcePath: String,
+                                            operation: String,
+                                            saveMode: SaveMode,
+                                            shouldValidate: Boolean = true): 
Unit = {
+    val sourceJSONTablePath = 
getClass.getClassLoader.getResource(dataSourcePath).toString
+
+    // NOTE: Schema here is provided for validation that the input date is in 
the appropriate format
+    val inputDF = 
spark.read.schema(sourceTableSchema).json(sourceJSONTablePath)
+
+    inputDF
+      .sort("c1")
+      .repartition(4, new Column("c1"))
+      .write
+      .format("hudi")
+      .options(hudiOpts)
+      .option(HoodieStorageConfig.PARQUET_MAX_FILE_SIZE.key, 10 * 1024)
+      .option(DataSourceWriteOptions.OPERATION.key, operation)
+      .mode(saveMode)
+      .save(basePath)
+    dfList = dfList :+ inputDF
+
+    metaClient = HoodieTableMetaClient.reload(metaClient)
+
+    if (shouldValidate) {
+      // Currently, routine manually validating the column stats (by actually 
reading every column of every file)
+      // only supports parquet files. Therefore we skip such validation when 
delta-log files are present, and only
+      // validate in following cases: (1) COW: all operations; (2) MOR: insert 
only.
+      val shouldValidateColumnStatsManually = testCase.tableType == 
HoodieTableType.COPY_ON_WRITE ||
+        operation.equals(DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
+
+      validateColumnStatsIndex(
+        testCase, metadataOpts, expectedColStatsSourcePath, 
shouldValidateColumnStatsManually)
+    }
+  }
+
+  protected def buildColumnStatsTableManually(tablePath: String,
+                                            includedCols: Seq[String],
+                                            indexedCols: Seq[String],
+                                            indexSchema: StructType): 
DataFrame = {
+    val files = {
+      val it = fs.listFiles(new Path(tablePath), true)
+      var seq = Seq[LocatedFileStatus]()
+      while (it.hasNext) {
+        seq = seq :+ it.next()
+      }
+      seq.filter(fs => fs.getPath.getName.endsWith(".parquet"))
+    }
+
+    spark.createDataFrame(
+      files.flatMap(file => {
+        val df = 
spark.read.schema(sourceTableSchema).parquet(file.getPath.toString)
+        val exprs: Seq[String] =
+          s"'${typedLit(file.getPath.getName)}' AS file" +:
+            s"sum(1) AS valueCount" +:
+            df.columns
+              .filter(col => includedCols.contains(col))
+              .filter(col => indexedCols.contains(col))
+              .flatMap(col => {
+                val minColName = s"${col}_minValue"
+                val maxColName = s"${col}_maxValue"
+                if (indexedCols.contains(col)) {
+                  Seq(
+                    s"min($col) AS $minColName",
+                    s"max($col) AS $maxColName",
+                    s"sum(cast(isnull($col) AS long)) AS ${col}_nullCount"
+                  )
+                } else {
+                  Seq(
+                    s"null AS $minColName",
+                    s"null AS $maxColName",
+                    s"null AS ${col}_nullCount"
+                  )
+                }
+              })
+
+        df.selectExpr(exprs: _*)
+          .collect()
+      }).asJava,
+      indexSchema
+    )
+  }
+
+  protected def validateColumnStatsIndex(testCase: ColumnStatsTestCase,
+                                       metadataOpts: Map[String, String],
+                                       expectedColStatsSourcePath: String,
+                                       validateColumnStatsManually: Boolean): 
Unit = {
+    val metadataConfig = HoodieMetadataConfig.newBuilder()
+      .fromProperties(toProperties(metadataOpts))
+      .build()
+
+    val columnStatsIndex = new ColumnStatsIndexSupport(spark, 
sourceTableSchema, metadataConfig, metaClient)
+
+    val indexedColumns: Set[String] = {
+      val customIndexedColumns = 
metadataConfig.getColumnsEnabledForColumnStatsIndex
+      if (customIndexedColumns.isEmpty) {
+        sourceTableSchema.fieldNames.toSet
+      } else {
+        customIndexedColumns.asScala.toSet
+      }
+    }
+    val (expectedColStatsSchema, _) = 
composeIndexSchema(sourceTableSchema.fieldNames, indexedColumns, 
sourceTableSchema)
+    val validationSortColumns = Seq("c1_maxValue", "c1_minValue", 
"c2_maxValue", "c2_minValue")
+
+    columnStatsIndex.loadTransposed(sourceTableSchema.fieldNames, 
testCase.shouldReadInMemory) { transposedColStatsDF =>
+      // Match against expected column stats table
+      val expectedColStatsIndexTableDf =
+        spark.read
+          .schema(expectedColStatsSchema)
+          
.json(getClass.getClassLoader.getResource(expectedColStatsSourcePath).toString)
+
+      assertEquals(expectedColStatsIndexTableDf.schema, 
transposedColStatsDF.schema)
+      // NOTE: We have to drop the `fileName` column as it contains 
semi-random components
+      //       that we can't control in this test. Nevertheless, since we 
manually verify composition of the
+      //       ColStats Index by reading Parquet footers from individual 
Parquet files, this is not an issue
+      assertEquals(asJson(sort(expectedColStatsIndexTableDf, 
validationSortColumns)),
+        asJson(sort(transposedColStatsDF.drop("fileName"), 
validationSortColumns)))
+
+      if (validateColumnStatsManually) {
+        // TODO(HUDI-4557): support validation of column stats of avro log 
files
+        // Collect Column Stats manually (reading individual Parquet files)
+        val manualColStatsTableDF =
+        buildColumnStatsTableManually(basePath, sourceTableSchema.fieldNames, 
sourceTableSchema.fieldNames, expectedColStatsSchema)
+
+        assertEquals(asJson(sort(manualColStatsTableDF, 
validationSortColumns)),
+          asJson(sort(transposedColStatsDF, validationSortColumns)))
+      }
+    }
+  }
+
+  protected def generateRandomDataFrame(spark: SparkSession): DataFrame = {
+    val sourceTableSchema =
+      new StructType()
+        .add("c1", IntegerType)
+        .add("c2", StringType)
+        // NOTE: We're testing different values for precision of the decimal 
to make sure
+        //       we execute paths bearing different underlying representations 
in Parquet
+        // REF: 
https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#DECIMAL
+        .add("c3a", DecimalType(9, 3))
+        .add("c3b", DecimalType(10, 3))
+        .add("c3c", DecimalType(20, 3))
+        .add("c4", TimestampType)
+        .add("c5", ShortType)
+        .add("c6", DateType)
+        .add("c7", BinaryType)
+        .add("c8", ByteType)
+
+    val rdd = spark.sparkContext.parallelize(0 to 1000, 1).map { item =>
+      val c1 = Integer.valueOf(item)
+      val c2 = Random.nextString(10)
+      val c3a = java.math.BigDecimal.valueOf(Random.nextInt() % (1 << 24), 3)
+      val c3b = java.math.BigDecimal.valueOf(Random.nextLong() % (1L << 32), 3)
+      // NOTE: We cap it at 2^64 to make sure we're not exceeding target 
decimal's range
+      val c3c = new java.math.BigDecimal(new BigInteger(64, new 
java.util.Random()), 3)
+      val c4 = new Timestamp(System.currentTimeMillis())
+      val c5 = java.lang.Short.valueOf(s"${(item + 16) / 10}")
+      val c6 = Date.valueOf(s"${2020}-${item % 11 + 1}-${item % 28 + 1}")
+      val c7 = Array(item).map(_.toByte)
+      val c8 = java.lang.Byte.valueOf("9")
+
+      RowFactory.create(c1, c2, c3a, c3b, c3c, c4, c5, c6, c7, c8)
+    }
+
+    spark.createDataFrame(rdd, sourceTableSchema)
+  }
+
+  protected def asJson(df: DataFrame) =
+    df.toJSON
+      .select("value")
+      .collect()
+      .toSeq
+      .map(_.getString(0))
+      .mkString("\n")
+
+  protected def sort(df: DataFrame): DataFrame = {
+    sort(df, Seq("c1_maxValue", "c1_minValue"))
+  }
+
+  private def sort(df: DataFrame, sortColumns: Seq[String]): DataFrame = {
+    val sortedCols = df.columns.sorted
+    // Sort dataset by specified columns (to minimize non-determinism in case 
multiple files have the same
+    // value of the first column)
+    df.select(sortedCols.head, sortedCols.tail: _*)
+      .sort(sortColumns.head, sortColumns.tail: _*)
+  }
+}
+
+object ColumnStatIndexTestBase {
+
+  case class ColumnStatsTestCase(tableType: HoodieTableType, 
shouldReadInMemory: Boolean)
+
+  def testMetadataColumnStatsIndexParams: java.util.stream.Stream[Arguments] = 
{
+    
java.util.stream.Stream.of(HoodieTableType.values().toStream.flatMap(tableType 
=>
+      Seq(Arguments.arguments(ColumnStatsTestCase(tableType, 
shouldReadInMemory = true)),
+        Arguments.arguments(ColumnStatsTestCase(tableType, shouldReadInMemory 
= false)))
+    ): _*)
+  }
+
+  def testMetadataColumnStatsIndexParamsForMOR: 
java.util.stream.Stream[Arguments] = {
+    java.util.stream.Stream.of(
+      
Seq(Arguments.arguments(ColumnStatsTestCase(HoodieTableType.MERGE_ON_READ, 
shouldReadInMemory = true)),
+        Arguments.arguments(ColumnStatsTestCase(HoodieTableType.MERGE_ON_READ, 
shouldReadInMemory = false)))
+    : _*)
+  }
+}
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala
index a30a72f9bd3..ac83cf81918 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala
@@ -19,7 +19,7 @@
 package org.apache.hudi.functional
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{LocatedFileStatus, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
 import org.apache.hudi.DataSourceWriteOptions.{PRECOMBINE_FIELD, 
RECORDKEY_FIELD}
 import org.apache.hudi.HoodieConversionUtils.toProperties
@@ -28,57 +28,22 @@ import org.apache.hudi.common.model.HoodieTableType
 import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
 import org.apache.hudi.common.util.ParquetUtils
 import org.apache.hudi.config.HoodieWriteConfig
-import org.apache.hudi.functional.TestColumnStatsIndex.ColumnStatsTestCase
-import org.apache.hudi.testutils.HoodieSparkClientTestBase
+import org.apache.hudi.functional.ColumnStatIndexTestBase.ColumnStatsTestCase
 import org.apache.hudi.{ColumnStatsIndexSupport, DataSourceWriteOptions}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, 
GreaterThan, Literal, Or}
-import org.apache.spark.sql.functions.typedLit
 import 
org.apache.spark.sql.hudi.DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr
 import org.apache.spark.sql.types._
 import org.junit.jupiter.api.Assertions.{assertEquals, assertNotNull, 
assertTrue}
 import org.junit.jupiter.api._
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{Arguments, EnumSource, MethodSource, 
ValueSource}
+import org.junit.jupiter.params.provider.{EnumSource, MethodSource, 
ValueSource}
 
-import java.math.BigInteger
-import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._
-import scala.util.Random
 
 @Tag("functional")
-class TestColumnStatsIndex extends HoodieSparkClientTestBase {
-  var spark: SparkSession = _
-
-  val sourceTableSchema =
-    new StructType()
-      .add("c1", IntegerType)
-      .add("c2", StringType)
-      .add("c3", DecimalType(9, 3))
-      .add("c4", TimestampType)
-      .add("c5", ShortType)
-      .add("c6", DateType)
-      .add("c7", BinaryType)
-      .add("c8", ByteType)
-
-  @BeforeEach
-  override def setUp() {
-    initPath()
-    initSparkContexts()
-    initFileSystem()
-
-    setTableName("hoodie_test")
-    initMetaClient()
-
-    spark = sqlContext.sparkSession
-  }
-
-  @AfterEach
-  override def tearDown() = {
-    cleanupFileSystem()
-    cleanupSparkContexts()
-  }
+class TestColumnStatsIndex extends ColumnStatIndexTestBase {
 
   @ParameterizedTest
   @MethodSource(Array("testMetadataColumnStatsIndexParams"))
@@ -125,6 +90,7 @@ class TestColumnStatsIndex extends HoodieSparkClientTestBase 
{
       saveMode = SaveMode.Append)
   }
 
+
   @ParameterizedTest
   @EnumSource(classOf[HoodieTableType])
   def testMetadataColumnStatsIndexValueCount(tableType: HoodieTableType): Unit 
= {
@@ -454,201 +420,4 @@ class TestColumnStatsIndex extends 
HoodieSparkClientTestBase {
       
assertTrue(r.getMinValue.asInstanceOf[Comparable[Object]].compareTo(r.getMaxValue.asInstanceOf[Object])
 <= 0)
     })
   }
-
-  private def doWriteAndValidateColumnStats(testCase: ColumnStatsTestCase,
-                                            metadataOpts: Map[String, String],
-                                            hudiOpts: Map[String, String],
-                                            dataSourcePath: String,
-                                            expectedColStatsSourcePath: String,
-                                            operation: String,
-                                            saveMode: SaveMode): Unit = {
-    val sourceJSONTablePath = 
getClass.getClassLoader.getResource(dataSourcePath).toString
-
-    // NOTE: Schema here is provided for validation that the input date is in 
the appropriate format
-    val inputDF = 
spark.read.schema(sourceTableSchema).json(sourceJSONTablePath)
-
-    inputDF
-      .sort("c1")
-      .repartition(4, new Column("c1"))
-      .write
-      .format("hudi")
-      .options(hudiOpts)
-      .option(HoodieStorageConfig.PARQUET_MAX_FILE_SIZE.key, 10 * 1024)
-      .option(DataSourceWriteOptions.OPERATION.key, operation)
-      .mode(saveMode)
-      .save(basePath)
-
-    metaClient = HoodieTableMetaClient.reload(metaClient)
-
-    // Currently, routine manually validating the column stats (by actually 
reading every column of every file)
-    // only supports parquet files. Therefore we skip such validation when 
delta-log files are present, and only
-    // validate in following cases: (1) COW: all operations; (2) MOR: insert 
only.
-    val shouldValidateColumnStatsManually = testCase.tableType == 
HoodieTableType.COPY_ON_WRITE ||
-      operation.equals(DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL)
-
-    validateColumnStatsIndex(
-      testCase, metadataOpts, expectedColStatsSourcePath, 
shouldValidateColumnStatsManually)
-  }
-
-  private def buildColumnStatsTableManually(tablePath: String,
-                                            includedCols: Seq[String],
-                                            indexedCols: Seq[String],
-                                            indexSchema: StructType): 
DataFrame = {
-    val files = {
-      val it = fs.listFiles(new Path(tablePath), true)
-      var seq = Seq[LocatedFileStatus]()
-      while (it.hasNext) {
-        seq = seq :+ it.next()
-      }
-      seq.filter(fs => fs.getPath.getName.endsWith(".parquet"))
-    }
-
-    spark.createDataFrame(
-      files.flatMap(file => {
-        val df = 
spark.read.schema(sourceTableSchema).parquet(file.getPath.toString)
-        val exprs: Seq[String] =
-          s"'${typedLit(file.getPath.getName)}' AS file" +:
-          s"sum(1) AS valueCount" +:
-            df.columns
-              .filter(col => includedCols.contains(col))
-              .filter(col => indexedCols.contains(col))
-              .flatMap(col => {
-                val minColName = s"${col}_minValue"
-                val maxColName = s"${col}_maxValue"
-                if (indexedCols.contains(col)) {
-                  Seq(
-                    s"min($col) AS $minColName",
-                    s"max($col) AS $maxColName",
-                    s"sum(cast(isnull($col) AS long)) AS ${col}_nullCount"
-                  )
-                } else {
-                  Seq(
-                    s"null AS $minColName",
-                    s"null AS $maxColName",
-                    s"null AS ${col}_nullCount"
-                  )
-                }
-              })
-
-        df.selectExpr(exprs: _*)
-          .collect()
-      }).asJava,
-      indexSchema
-    )
-  }
-
-  private def validateColumnStatsIndex(testCase: ColumnStatsTestCase,
-                                       metadataOpts: Map[String, String],
-                                       expectedColStatsSourcePath: String,
-                                       validateColumnStatsManually: Boolean): 
Unit = {
-    val metadataConfig = HoodieMetadataConfig.newBuilder()
-      .fromProperties(toProperties(metadataOpts))
-      .build()
-
-    val columnStatsIndex = new ColumnStatsIndexSupport(spark, 
sourceTableSchema, metadataConfig, metaClient)
-
-    val indexedColumns: Set[String] = {
-      val customIndexedColumns = 
metadataConfig.getColumnsEnabledForColumnStatsIndex
-      if (customIndexedColumns.isEmpty) {
-        sourceTableSchema.fieldNames.toSet
-      } else {
-        customIndexedColumns.asScala.toSet
-      }
-    }
-    val (expectedColStatsSchema, _) = 
composeIndexSchema(sourceTableSchema.fieldNames, indexedColumns, 
sourceTableSchema)
-    val validationSortColumns = Seq("c1_maxValue", "c1_minValue", 
"c2_maxValue", "c2_minValue")
-
-    columnStatsIndex.loadTransposed(sourceTableSchema.fieldNames, 
testCase.shouldReadInMemory) { transposedColStatsDF =>
-      // Match against expected column stats table
-      val expectedColStatsIndexTableDf =
-        spark.read
-          .schema(expectedColStatsSchema)
-          
.json(getClass.getClassLoader.getResource(expectedColStatsSourcePath).toString)
-
-      assertEquals(expectedColStatsIndexTableDf.schema, 
transposedColStatsDF.schema)
-      // NOTE: We have to drop the `fileName` column as it contains 
semi-random components
-      //       that we can't control in this test. Nevertheless, since we 
manually verify composition of the
-      //       ColStats Index by reading Parquet footers from individual 
Parquet files, this is not an issue
-      assertEquals(asJson(sort(expectedColStatsIndexTableDf, 
validationSortColumns)),
-        asJson(sort(transposedColStatsDF.drop("fileName"), 
validationSortColumns)))
-
-      if (validateColumnStatsManually) {
-        // TODO(HUDI-4557): support validation of column stats of avro log 
files
-        // Collect Column Stats manually (reading individual Parquet files)
-        val manualColStatsTableDF =
-          buildColumnStatsTableManually(basePath, 
sourceTableSchema.fieldNames, sourceTableSchema.fieldNames, 
expectedColStatsSchema)
-
-        assertEquals(asJson(sort(manualColStatsTableDF, 
validationSortColumns)),
-          asJson(sort(transposedColStatsDF, validationSortColumns)))
-      }
-    }
-  }
-
-  private def generateRandomDataFrame(spark: SparkSession): DataFrame = {
-    val sourceTableSchema =
-      new StructType()
-        .add("c1", IntegerType)
-        .add("c2", StringType)
-        // NOTE: We're testing different values for precision of the decimal 
to make sure
-        //       we execute paths bearing different underlying representations 
in Parquet
-        // REF: 
https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#DECIMAL
-        .add("c3a", DecimalType(9, 3))
-        .add("c3b", DecimalType(10, 3))
-        .add("c3c", DecimalType(20, 3))
-        .add("c4", TimestampType)
-        .add("c5", ShortType)
-        .add("c6", DateType)
-        .add("c7", BinaryType)
-        .add("c8", ByteType)
-
-    val rdd = spark.sparkContext.parallelize(0 to 1000, 1).map { item =>
-      val c1 = Integer.valueOf(item)
-      val c2 = Random.nextString(10)
-      val c3a = java.math.BigDecimal.valueOf(Random.nextInt() % (1 << 24), 3)
-      val c3b = java.math.BigDecimal.valueOf(Random.nextLong() % (1L << 32), 3)
-      // NOTE: We cap it at 2^64 to make sure we're not exceeding target 
decimal's range
-      val c3c = new java.math.BigDecimal(new BigInteger(64, new 
java.util.Random()), 3)
-      val c4 = new Timestamp(System.currentTimeMillis())
-      val c5 = java.lang.Short.valueOf(s"${(item + 16) / 10}")
-      val c6 = Date.valueOf(s"${2020}-${item % 11 + 1}-${item % 28 + 1}")
-      val c7 = Array(item).map(_.toByte)
-      val c8 = java.lang.Byte.valueOf("9")
-
-      RowFactory.create(c1, c2, c3a, c3b, c3c, c4, c5, c6, c7, c8)
-    }
-
-    spark.createDataFrame(rdd, sourceTableSchema)
-  }
-
-  private def asJson(df: DataFrame) =
-    df.toJSON
-      .select("value")
-      .collect()
-      .toSeq
-      .map(_.getString(0))
-      .mkString("\n")
-
-  private def sort(df: DataFrame): DataFrame = {
-    sort(df, Seq("c1_maxValue", "c1_minValue"))
-  }
-
-  private def sort(df: DataFrame, sortColumns: Seq[String]): DataFrame = {
-    val sortedCols = df.columns.sorted
-    // Sort dataset by specified columns (to minimize non-determinism in case 
multiple files have the same
-    // value of the first column)
-    df.select(sortedCols.head, sortedCols.tail: _*)
-      .sort(sortColumns.head, sortColumns.tail: _*)
-  }
-}
-
-object TestColumnStatsIndex {
-
-  case class ColumnStatsTestCase(tableType: HoodieTableType, 
shouldReadInMemory: Boolean)
-
-  def testMetadataColumnStatsIndexParams: java.util.stream.Stream[Arguments] = 
{
-    
java.util.stream.Stream.of(HoodieTableType.values().toStream.flatMap(tableType 
=>
-      Seq(Arguments.arguments(ColumnStatsTestCase(tableType, 
shouldReadInMemory = true)),
-        Arguments.arguments(ColumnStatsTestCase(tableType, shouldReadInMemory 
= false)))
-    ): _*)
-  }
 }
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndexWithSQL.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndexWithSQL.scala
new file mode 100644
index 00000000000..1bb35bc150c
--- /dev/null
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndexWithSQL.scala
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.functional
+
+import org.apache.hudi.DataSourceWriteOptions.{DELETE_OPERATION_OPT_VAL, 
PRECOMBINE_FIELD, RECORDKEY_FIELD}
+import org.apache.hudi.client.SparkRDDWriteClient
+import org.apache.hudi.client.common.HoodieSparkEngineContext
+import org.apache.hudi.client.utils.MetadataConversionUtils
+import org.apache.hudi.common.config.HoodieMetadataConfig
+import org.apache.hudi.common.fs.FSUtils
+import org.apache.hudi.common.model.{HoodieCommitMetadata, HoodieTableType, 
WriteOperationType}
+import org.apache.hudi.common.table.HoodieTableConfig
+import org.apache.hudi.common.table.timeline.HoodieInstant
+import org.apache.hudi.config.{HoodieCompactionConfig, HoodieIndexConfig, 
HoodieWriteConfig}
+import org.apache.hudi.functional.ColumnStatIndexTestBase.ColumnStatsTestCase
+import org.apache.hudi.index.HoodieIndex.IndexType.INMEMORY
+import org.apache.hudi.metadata.HoodieMetadataFileSystemView
+import org.apache.hudi.util.JavaConversions
+import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, 
HoodieFileIndex}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, 
Expression, GreaterThan, Literal}
+import org.apache.spark.sql.types.StringType
+import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.MethodSource
+
+import java.util.Properties
+import scala.collection.JavaConverters
+import scala.jdk.CollectionConverters.{asScalaIteratorConverter, 
collectionAsScalaIterableConverter}
+
+class TestColumnStatsIndexWithSQL extends ColumnStatIndexTestBase {
+
+  @ParameterizedTest
+  @MethodSource(Array("testMetadataColumnStatsIndexParams"))
+  def testMetadataColumnStatsIndexWithSQL(testCase: ColumnStatsTestCase): Unit = {
+    // Metadata table with the column-stats index enabled
+    val metadataOpts = Map(
+      HoodieMetadataConfig.ENABLE.key -> "true",
+      HoodieMetadataConfig.ENABLE_METADATA_INDEX_COLUMN_STATS.key -> "true"
+    )
+    val commonOpts = metadataOpts ++ Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.TABLE_TYPE.key -> testCase.tableType.toString,
+      RECORDKEY_FIELD.key -> "c1",
+      PRECOMBINE_FIELD.key -> "c1",
+      HoodieTableConfig.POPULATE_META_FIELDS.key -> "true",
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true",
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL
+    )
+
+    // Insert followed by two upserts (column stats validated), then check queries and file pruning
+    setupTable(testCase, metadataOpts, commonOpts, shouldValidate = true)
+    verifyFileIndexAndSQLQueries(commonOpts)
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("testMetadataColumnStatsIndexParamsForMOR"))
+  def testMetadataColumnStatsIndexSQLWithInMemoryIndex(testCase: ColumnStatsTestCase): Unit = {
+    val metadataOpts = Map(
+      HoodieMetadataConfig.ENABLE.key -> "true",
+      HoodieMetadataConfig.ENABLE_METADATA_INDEX_COLUMN_STATS.key -> "true"
+    )
+    // Same options as the other SQL tests, but with the INMEMORY index type
+    val commonOpts = metadataOpts ++ Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.TABLE_TYPE.key -> testCase.tableType.toString,
+      RECORDKEY_FIELD.key -> "c1",
+      PRECOMBINE_FIELD.key -> "c1",
+      HoodieTableConfig.POPULATE_META_FIELDS.key -> "true",
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true",
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL,
+      HoodieIndexConfig.INDEX_TYPE.key() -> INMEMORY.name()
+    )
+
+    // A single insert, without column-stats validation
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/input-table-json",
+      expectedColStatsSourcePath = "index/colstats/column-stats-index-table.json",
+      operation = DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Overwrite,
+      shouldValidate = false)
+
+    // The insert is expected to produce 4 data files, none of which is a base file (all log files)
+    assertEquals(4, getLatestDataFilesCount(commonOpts))
+    assertEquals(0, getLatestDataFilesCount(commonOpts, includeLogFiles = false))
+    Seq("90", "95").foreach { threshold =>
+      verifyPruningFileCount(commonOpts, GreaterThan(attribute("c5"), literal(threshold)))
+    }
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("testMetadataColumnStatsIndexParams"))
+  def testMetadataColumnStatsIndexDeletionWithSQL(testCase: ColumnStatsTestCase): Unit = {
+    val metadataOpts = Map(
+      HoodieMetadataConfig.ENABLE.key -> "true",
+      HoodieMetadataConfig.ENABLE_METADATA_INDEX_COLUMN_STATS.key -> "true"
+    )
+    val commonOpts = metadataOpts ++ Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.TABLE_TYPE.key -> testCase.tableType.toString,
+      RECORDKEY_FIELD.key -> "c1",
+      PRECOMBINE_FIELD.key -> "c1",
+      HoodieTableConfig.POPULATE_META_FIELDS.key -> "true",
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true",
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL
+    )
+    setupTable(testCase, metadataOpts, commonOpts, shouldValidate = true)
+
+    // Delete the records of `dfList.last` (presumably the most recently written batch, collected by the
+    // base class). The table data should then match its state as of the second instant.
+    val recordsToDelete = dfList.last
+    recordsToDelete.write.format("org.apache.hudi")
+      .options(commonOpts)
+      .option(DataSourceWriteOptions.OPERATION.key, DELETE_OPERATION_OPT_VAL)
+      .mode(SaveMode.Append)
+      .save(basePath)
+    verifyFileIndexAndSQLQueries(commonOpts, isTableDataSameAsAfterSecondInstant = true)
+
+    // Upsert the deleted batch again and re-check the queries
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/update-input-table-json",
+      expectedColStatsSourcePath = "",
+      operation = DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Append,
+      shouldValidate = false)
+    verifyFileIndexAndSQLQueries(commonOpts, verifyFileCount = false)
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("testMetadataColumnStatsIndexParamsForMOR"))
+  def testMetadataColumnStatsIndexCompactionWithSQL(testCase: ColumnStatsTestCase): Unit = {
+    val metadataOpts = Map(
+      HoodieMetadataConfig.ENABLE.key -> "true",
+      HoodieMetadataConfig.ENABLE_METADATA_INDEX_COLUMN_STATS.key -> "true"
+    )
+    // Inline compaction after every delta commit
+    val commonOpts = metadataOpts ++ Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.TABLE_TYPE.key -> testCase.tableType.toString,
+      RECORDKEY_FIELD.key -> "c1",
+      PRECOMBINE_FIELD.key -> "c1",
+      HoodieTableConfig.POPULATE_META_FIELDS.key -> "true",
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true",
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL,
+      HoodieCompactionConfig.INLINE_COMPACT.key() -> "true",
+      HoodieCompactionConfig.INLINE_COMPACT_NUM_DELTA_COMMITS.key() -> "1"
+    )
+    setupTable(testCase, metadataOpts, commonOpts, shouldValidate = false)
+
+    // The last instant should be a compaction, so no log files are expected
+    assertFalse(hasLogFiles())
+    verifyFileIndexAndSQLQueries(commonOpts)
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array("testMetadataColumnStatsIndexParamsForMOR"))
+  def testMetadataColumnStatsIndexScheduledCompactionWithSQL(testCase: ColumnStatsTestCase): Unit = {
+    val metadataOpts = Map(
+      HoodieMetadataConfig.ENABLE.key -> "true",
+      HoodieMetadataConfig.ENABLE_METADATA_INDEX_COLUMN_STATS.key -> "true"
+    )
+
+    val commonOpts = Map(
+      "hoodie.insert.shuffle.parallelism" -> "4",
+      "hoodie.upsert.shuffle.parallelism" -> "4",
+      HoodieWriteConfig.TBL_NAME.key -> "hoodie_test",
+      DataSourceWriteOptions.TABLE_TYPE.key -> testCase.tableType.toString,
+      RECORDKEY_FIELD.key -> "c1",
+      PRECOMBINE_FIELD.key -> "c1",
+      HoodieTableConfig.POPULATE_META_FIELDS.key -> "true",
+      DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "true",
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL,
+      HoodieCompactionConfig.INLINE_COMPACT_NUM_DELTA_COMMITS.key() -> "1"
+    ) ++ metadataOpts
+    setupTable(testCase, metadataOpts, commonOpts, shouldValidate = false)
+
+    // Schedule a compaction (without executing it), then write more data on top of it.
+    // The write client is only needed for scheduling; close it so its resources are not leaked
+    // for the remainder of the test.
+    val writeClient = new SparkRDDWriteClient(new HoodieSparkEngineContext(jsc), getWriteConfig(commonOpts))
+    try {
+      writeClient.scheduleCompaction(org.apache.hudi.common.util.Option.empty())
+    } finally {
+      writeClient.close()
+    }
+
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/update-input-table-json",
+      expectedColStatsSourcePath = "",
+      operation = DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Append,
+      shouldValidate = false)
+    verifyFileIndexAndSQLQueries(commonOpts)
+  }
+
+  /**
+   * Writes the fixture table in three steps: an insert (overwriting any existing table) followed by two
+   * upserts. `shouldValidate` is forwarded only to the last write; the first two use the helper's default.
+   */
+  private def setupTable(testCase: ColumnStatsTestCase, metadataOpts: Map[String, String],
+                         commonOpts: Map[String, String], shouldValidate: Boolean): Unit = {
+    // NOTE: MOR and COW have different fixtures since MOR is bearing delta-log files (holding
+    //       deferred updates), diverging from COW
+    val finalExpectedColStatsSourcePath = testCase.tableType match {
+      case HoodieTableType.COPY_ON_WRITE => "index/colstats/cow-updated2-column-stats-index-table.json"
+      case _ => "index/colstats/mor-updated2-column-stats-index-table.json"
+    }
+
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/input-table-json",
+      expectedColStatsSourcePath = "index/colstats/column-stats-index-table.json",
+      operation = DataSourceWriteOptions.INSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Overwrite)
+
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/another-input-table-json",
+      expectedColStatsSourcePath = "index/colstats/updated-column-stats-index-table.json",
+      operation = DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Append)
+
+    doWriteAndValidateColumnStats(testCase, metadataOpts, commonOpts,
+      dataSourcePath = "index/colstats/update-input-table-json",
+      expectedColStatsSourcePath = finalExpectedColStatsSourcePath,
+      operation = DataSourceWriteOptions.UPSERT_OPERATION_OPT_VAL,
+      saveMode = SaveMode.Append,
+      shouldValidate = shouldValidate)
+  }
+
+  /**
+   * Verifies SQL query results and file-index pruning for the table at `basePath`.
+   *
+   * Baseline counts for two filter queries are taken with data skipping disabled, as of the second
+   * instant on the active timeline. Snapshot, read-optimized and incremental queries are then each
+   * checked against those baselines via `verifySQLQueries`. Finally, when `verifyFileCount` is true,
+   * file pruning through [[HoodieFileIndex]] is checked for several data filters.
+   *
+   * @param opts                                write/read options shared by the test
+   * @param isTableDataSameAsAfterSecondInstant true when the current table data is expected to equal the
+   *                                            data as of the second instant (e.g. after a delete)
+   * @param verifyFileCount                     whether to run the file-pruning assertions
+   */
+  def verifyFileIndexAndSQLQueries(opts: Map[String, String], isTableDataSameAsAfterSecondInstant: Boolean = false,
+                                   verifyFileCount: Boolean = true): Unit = {
+    val baselineDF = spark.read.format("hudi")
+      .options(opts)
+      .option("as.of.instant", metaClient.getActiveTimeline.getInstants.get(1).getTimestamp)
+      .option(DataSourceReadOptions.QUERY_TYPE.key, DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL)
+      .option(DataSourceReadOptions.ENABLE_DATA_SKIPPING.key, "false")
+      .load(basePath)
+    baselineDF.createOrReplaceTempView("tbl")
+    val numRecordsForFirstQuery = spark.sql("select * from tbl where c5 > 70").count()
+    val numRecordsForSecondQuery = spark.sql("select * from tbl where c5 > 70 and c6 >= '2020-03-28'").count()
+
+    Seq(
+      DataSourceReadOptions.QUERY_TYPE_SNAPSHOT_OPT_VAL,
+      DataSourceReadOptions.QUERY_TYPE_READ_OPTIMIZED_OPT_VAL,
+      DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL
+    ).foreach { queryType =>
+      verifySQLQueries(numRecordsForFirstQuery, numRecordsForSecondQuery, queryType, opts,
+        isTableDataSameAsAfterSecondInstant)
+    }
+
+    val pruningOpts =
+      opts + (DataSourceReadOptions.INCREMENTAL_FALLBACK_TO_FULL_TABLE_SCAN_FOR_NON_EXISTING_FILES.key -> "true")
+    // TODO: https://issues.apache.org/jira/browse/HUDI-6657 - Investigate why below assertions fail with full table scan enabled.
+    // verifySQLQueries(numRecordsForFirstQuery, numRecordsForSecondQuery,
+    //   DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL, pruningOpts, isTableDataSameAsAfterSecondInstant)
+
+    if (verifyFileCount) {
+      val c5Above70 = GreaterThan(attribute("c5"), literal("70"))
+      val c5Above90 = GreaterThan(attribute("c5"), literal("90"))
+      // NOTE(review): unlike the SQL queries above, this literal embeds single quotes ("'2020-03-28'")
+      //               and uses `>` rather than `>=`; confirm it is the intended comparison value.
+      def c6Filter: Expression = GreaterThan(attribute("c6"), literal("'2020-03-28'"))
+      Seq[Expression](c5Above70, And(c5Above70, c6Filter), c5Above90, And(c5Above90, c6Filter))
+        .foreach(dataFilter => verifyPruningFileCount(pruningOpts, dataFilter))
+    }
+  }
+
+  /**
+   * Asserts that listing files through [[HoodieFileIndex]] with `dataFilter` returns fewer files than
+   * both the latest data-file count and an equivalent listing with data skipping disabled.
+   */
+  private def verifyPruningFileCount(opts: Map[String, String], dataFilter: Expression): Unit = {
+    val indexOpts = opts + ("path" -> basePath)
+    def countListedFiles(fileIndexOpts: Map[String, String]): Int =
+      HoodieFileIndex(spark, metaClient, None, fileIndexOpts, includeLogFiles = true)
+        .listFiles(Seq(), Seq(dataFilter))
+        .flatMap(_.files)
+        .size
+
+    // with the data-skipping setting taken from `opts`
+    val filteredFilesCount = countListedFiles(indexOpts)
+    assertTrue(filteredFilesCount < getLatestDataFilesCount(opts))
+
+    // with data skipping explicitly disabled
+    val filesCountWithNoSkipping =
+      countListedFiles(indexOpts + (DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "false"))
+    assertTrue(filteredFilesCount < filesCountWithNoSkipping)
+  }
+
+  /**
+   * Counts the files in the latest merged file slices (as of the last instant) across all partitions:
+   * one per present base file, plus every log file when `includeLogFiles` is true.
+   */
+  private def getLatestDataFilesCount(opts: Map[String, String], includeLogFiles: Boolean = true): Long = {
+    val fsView = getTableFileSystemView(opts)
+    fsView.loadAllPartitions()
+    // Loop-invariant: reload the timeline once instead of once per partition. Lazy so that a table with
+    // no partitions still returns 0 without touching the timeline.
+    lazy val latestInstantTime = metaClient.reloadActiveTimeline().lastInstant().get().getTimestamp
+    fsView.getPartitionPaths.asScala.flatMap { partitionPath =>
+      val relativePath = FSUtils.getRelativePartitionPath(metaClient.getBasePathV2, partitionPath)
+      fsView.getLatestMergedFileSlicesBeforeOrOn(relativePath, latestInstantTime).iterator().asScala.toSeq
+    }.map { slice =>
+      val logFileCount = if (includeLogFiles) slice.getLogFiles.count() else 0L
+      val baseFileCount = if (slice.getBaseFile.isPresent) 1L else 0L
+      logFileCount + baseFileCount
+    }.sum
+  }
+
+  /** Builds a metadata-table-backed file-system view over the currently loaded active timeline. */
+  private def getTableFileSystemView(opts: Map[String, String]): HoodieMetadataFileSystemView = {
+    val activeTimeline = metaClient.getActiveTimeline
+    val tableMetadata = metadataWriter(getWriteConfig(opts)).getTableMetadata
+    new HoodieMetadataFileSystemView(metaClient, activeTimeline, tableMetadata)
+  }
+
+  protected def getWriteConfig(hudiOpts: Map[String, String]): 
HoodieWriteConfig = {
+    val props = new Properties()
+    props.putAll(JavaConverters.mapAsJavaMapConverter(hudiOpts).asJava)
+    HoodieWriteConfig.newBuilder()
+      .withProps(props)
+      .withPath(basePath)
+      .build()
+  }
+
+  /**
+   * Builds a nullable `StringType` [[AttributeReference]] named `columnName`, for use in data filters.
+   * NOTE(review): the type is always `StringType`, whatever the column's actual type in the table.
+   */
+  private def attribute(columnName: String): AttributeReference = {
+    AttributeReference(columnName, StringType, nullable = true)()
+  }
+
+  /** Wraps `value` in a Catalyst [[Literal]]. */
+  private def literal(value: String): Literal =
+    Literal.create(value)
+
+  /**
+   * Re-registers `tbl` for `queryType` and checks the two filter queries against the counts observed as of
+   * the previous (second) instant. It then checks that disabling data skipping does not change the results.
+   *
+   * Expected increments over the previous instant:
+   *  - read-optimized query on a table with log files: 1 for both queries (only the insert is visible);
+   *  - after the last batch was deleted (`isLastOperationDelete`): 0;
+   *  - otherwise: 3 for the first query (one insert and two upserts with c5 > 70) and 2 for the second
+   *    query (one insert and one upsert).
+   *
+   * For incremental queries, the incremental view beginning at the second instant is also checked.
+   */
+  private def verifySQLQueries(numRecordsForFirstQueryAtPrevInstant: Long, numRecordsForSecondQueryAtPrevInstant: Long,
+                               queryType: String, opts: Map[String, String], isLastOperationDelete: Boolean): Unit = {
+    val firstQuery = "select * from tbl where c5 > 70"
+    val secondQuery = "select * from tbl where c5 > 70 and c6 >= '2020-03-28'"
+
+    createSQLTable(opts, queryType)
+    val (firstQueryIncrement, secondQueryIncrement) =
+      if (queryType.equals(DataSourceReadOptions.QUERY_TYPE_READ_OPTIMIZED_OPT_VAL) && hasLogFiles()) {
+        (1, 1) // only the insert is visible
+      } else if (isLastOperationDelete) {
+        (0, 0) // no increment
+      } else {
+        (3, 2) // first query: one insert and two upserts; second query: one insert and one upsert
+      }
+
+    // Run each query once with the data-skipping setting from `opts` and reuse the counts below
+    val numRecordsForFirstQueryWithDataSkipping = spark.sql(firstQuery).count()
+    val numRecordsForSecondQueryWithDataSkipping = spark.sql(secondQuery).count()
+    assertEquals(numRecordsForFirstQueryAtPrevInstant + firstQueryIncrement, numRecordsForFirstQueryWithDataSkipping)
+    assertEquals(numRecordsForSecondQueryAtPrevInstant + secondQueryIncrement, numRecordsForSecondQueryWithDataSkipping)
+
+    if (queryType.equals(DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL)) {
+      createIncrementalSQLTable(opts, metaClient.reloadActiveTimeline().getInstants.get(1).getTimestamp)
+      assertEquals(if (isLastOperationDelete) 0L else 3L, spark.sql(firstQuery).count())
+      assertEquals(if (isLastOperationDelete) 0L else 2L, spark.sql(secondQuery).count())
+    }
+
+    // Disabling data skipping must not change the query results
+    createSQLTable(opts + (DataSourceReadOptions.ENABLE_DATA_SKIPPING.key -> "false"), queryType)
+    assertEquals(numRecordsForFirstQueryWithDataSkipping, spark.sql(firstQuery).count())
+    assertEquals(numRecordsForSecondQueryWithDataSkipping, spark.sql(secondQuery).count())
+  }
+
+  /**
+   * Registers temp view `tbl` over the table at `basePath` using `queryType`.
+   * The begin instant is the first instant's timestamp with its first character replaced by "0"
+   * (`replaceFirst` with the regex "."), i.e. an earlier timestamp, assuming the timestamp does not
+   * already start with "0".
+   */
+  private def createSQLTable(hudiOpts: Map[String, String], queryType: String): Unit = {
+    val firstInstantTime = metaClient.getActiveTimeline.getInstants.get(0).getTimestamp
+    val readOpts = hudiOpts ++ Map(
+      DataSourceReadOptions.QUERY_TYPE.key -> queryType,
+      DataSourceReadOptions.BEGIN_INSTANTTIME.key() -> firstInstantTime.replaceFirst(".", "0"))
+    spark.read.format("hudi").options(readOpts).load(basePath).createOrReplaceTempView("tbl")
+  }
+
+  /** Registers temp view `tbl` as an incremental query over `basePath` beginning at `instantTime`. */
+  private def createIncrementalSQLTable(hudiOpts: Map[String, String], instantTime: String): Unit = {
+    val readOpts = hudiOpts ++ Map(
+      DataSourceReadOptions.QUERY_TYPE.key -> DataSourceReadOptions.QUERY_TYPE_INCREMENTAL_OPT_VAL,
+      DataSourceReadOptions.BEGIN_INSTANTTIME.key() -> instantTime)
+    spark.read.format("hudi").options(readOpts).load(basePath).createOrReplaceTempView("tbl")
+  }
+
+  /** True for a MOR table whose last instant is not its latest compaction instant. */
+  private def hasLogFiles(): Boolean = {
+    if (!isTableMOR) {
+      false
+    } else {
+      getLatestCompactionInstant() != metaClient.getActiveTimeline.lastInstant()
+    }
+  }
+
+  /** Whether the table under test is MERGE_ON_READ. */
+  private def isTableMOR(): Boolean =
+    HoodieTableType.MERGE_ON_READ == metaClient.getTableType
+
+  /**
+   * Reloads the active timeline and returns its last instant whose commit metadata has operation type
+   * COMPACT. Instants whose metadata cannot be read are treated as having empty commit metadata.
+   */
+  protected def getLatestCompactionInstant(): org.apache.hudi.common.util.Option[HoodieInstant] = {
+    metaClient.reloadActiveTimeline()
+      .filter(JavaConversions.getPredicate(instant => {
+        val commitMetadata =
+          try {
+            MetadataConversionUtils.getHoodieCommitMetadata(metaClient, instant)
+              .orElse(new HoodieCommitMetadata())
+          } catch {
+            case _: Exception => new HoodieCommitMetadata()
+          }
+        commitMetadata.getOperationType == WriteOperationType.COMPACT
+      }))
+      .lastInstant()
+  }
+}
diff --git 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/procedure/TestClusteringProcedure.scala
 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/procedure/TestClusteringProcedure.scala
index 1cc3a968e09..8da368039d5 100644
--- 
a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/procedure/TestClusteringProcedure.scala
+++ 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/procedure/TestClusteringProcedure.scala
@@ -587,7 +587,7 @@ class TestClusteringProcedure extends 
HoodieSparkProcedureTestBase {
 
       metaClient.reloadActiveTimeline()
       val fileIndex1 = HoodieFileIndex(spark, metaClient, None, queryOpts)
-      val orderAllFiles = fileIndex1.allFiles.size
+      val orderAllFiles = fileIndex1.allBaseFiles.size
       val c2OrderFilterCount = fileIndex1.listFiles(Seq(), 
Seq(dataFilterC2)).head.files.size
       val c3OrderFilterCount = fileIndex1.listFiles(Seq(), 
Seq(dataFilterC3)).head.files.size
 
@@ -604,7 +604,7 @@ class TestClusteringProcedure extends 
HoodieSparkProcedureTestBase {
 
       metaClient.reloadActiveTimeline()
       val fileIndex2 = HoodieFileIndex(spark, metaClient, None, queryOpts)
-      val ZOrderAllFiles = fileIndex2.allFiles.size
+      val ZOrderAllFiles = fileIndex2.allBaseFiles.size
       val c2ZOrderFilterCount = fileIndex2.listFiles(Seq(), 
Seq(dataFilterC2)).head.files.size
       val c3ZOrderFilterCount = fileIndex2.listFiles(Seq(), 
Seq(dataFilterC3)).head.files.size
 

Reply via email to