This is an automated email from the ASF dual-hosted git repository.

kazuyukitanimura pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b3e87b5 build: Add spark-4.0 profile and shims (#407)
9b3e87b5 is described below

commit 9b3e87b5e8b616ac6663411090513f29b98330e7
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Tue May 28 12:09:03 2024 -0700

    build: Add spark-4.0 profile and shims (#407)
    
    This PR adds the spark-4.0 profile and shims.
    This is an initial commit. Tests with the spark-4.0 profile do not pass
    yet. Tests for spark-3.x should pass.
---
 .github/workflows/pr_build.yml                     | 113 +++++++++++++++++++++
 .../apache/comet/parquet/CometParquetUtils.scala   |  61 +----------
 .../sql/comet/shims/ShimCometParquetUtils.scala}   |  22 +---
 .../org/apache/comet/shims/ShimBatchReader.scala   |  26 +++--
 .../org/apache/comet/shims/ShimFileFormat.scala    |  20 ++--
 .../comet/shims/ShimResolveDefaultColumns.scala    |  18 ++--
 .../sql/comet/shims/ShimCometParquetUtils.scala    |  34 +++----
 dev/ensure-jars-have-correct-contents.sh           |   2 +
 pom.xml                                            |  27 +++--
 spark/pom.xml                                      |   2 +
 .../comet/parquet/CometParquetFileFormat.scala     |   9 +-
 .../CometParquetPartitionReaderFactory.scala       |   3 +-
 .../org/apache/comet/parquet/ParquetFilters.scala  |   9 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  23 ++---
 .../src/main/scala/org/apache/spark/Plugins.scala  |  24 ++---
 .../sql/comet/CometBroadcastExchangeExec.scala     |   2 +-
 .../org/apache/spark/sql/comet/CometScanExec.scala |   9 +-
 .../apache/spark/sql/comet/DecimalPrecision.scala  |   9 +-
 .../shuffle/CometShuffleExchangeExec.scala         |  36 ++++---
 .../org/apache/comet/shims/CometExprShim.scala     |   4 +-
 .../org/apache/comet/shims/CometExprShim.scala     |   4 +-
 .../org/apache/comet/shims/CometExprShim.scala     |   4 +-
 .../comet/shims/ShimCometShuffleExchangeExec.scala |  10 ++
 .../shims/ShimCometSparkSessionExtensions.scala    |   4 -
 .../org/apache/comet/shims/ShimSQLConf.scala       |   4 +
 .../shims/ShimCometBroadcastExchangeExec.scala     |   2 +-
 .../spark/comet/shims/ShimCometDriverPlugin.scala  |  38 +++++++
 .../sql}/comet/shims/ShimCometScanExec.scala       |  29 ++++--
 .../shims/ShimCometShuffleWriteProcessor.scala}    |  38 +++----
 .../org/apache/comet/shims/CometExprShim.scala     |   6 +-
 .../comet/shims/ShimCometBatchScanExec.scala}      |  23 +++--
 .../shims/ShimCometBroadcastHashJoinExec.scala}    |  18 ++--
 .../comet/shims/ShimCometShuffleExchangeExec.scala |  15 +--
 .../shims/ShimCometSparkSessionExtensions.scala}   |  29 +++---
 .../ShimCometTakeOrderedAndProjectExec.scala}      |  15 +--
 .../apache/comet/shims/ShimQueryPlanSerde.scala    |  37 +++++++
 .../org/apache/comet/shims/ShimSQLConf.scala}      |  20 ++--
 .../shims/ShimCometBroadcastExchangeExec.scala}    |  21 ++--
 .../spark/comet/shims/ShimCometDriverPlugin.scala} |  23 ++---
 .../spark/sql/comet/shims/ShimCometScanExec.scala  |  83 +++++++++++++++
 .../shims/ShimCometShuffleWriteProcessor.scala}    |  17 +---
 .../apache/spark/sql/CometTPCDSQuerySuite.scala    |   7 +-
 .../org/apache/spark/sql/CometTPCHQuerySuite.scala |   5 +-
 .../scala/org/apache/spark/sql/CometTestBase.scala |   1 -
 .../spark/sql/benchmark/CometReadBenchmark.scala   |   5 +-
 .../apache/comet/exec/CometExec3_4PlusSuite.scala} |   4 +-
 .../comet/shims/ShimCometTPCHQuerySuite.scala}     |  17 ++--
 .../apache/spark/comet/shims/ShimTestUtils.scala}  |   5 +-
 .../comet/shims/ShimCometTPCDSQuerySuite.scala}    |  16 +--
 .../comet/shims/ShimCometTPCHQuerySuite.scala}     |  14 +--
 .../apache/spark/comet/shims/ShimTestUtils.scala}  |  18 ++--
 .../comet/shims/ShimCometTPCDSQuerySuite.scala}    |  15 +--
 52 files changed, 586 insertions(+), 414 deletions(-)

diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml
index 1e347250..410f1e1f 100644
--- a/.github/workflows/pr_build.yml
+++ b/.github/workflows/pr_build.yml
@@ -76,6 +76,44 @@ jobs:
           # upload test reports only for java 17
           upload-test-reports: ${{ matrix.java_version == '17' }}
 
+  linux-test-with-spark4_0:
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        java_version: [17]
+        test-target: [java]
+        spark-version: ['4.0']
+        is_push_event:
+          - ${{ github.event_name == 'push' }}
+      fail-fast: false
+    name: ${{ matrix.os }}/java ${{ matrix.java_version 
}}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+    runs-on: ${{ matrix.os }}
+    container:
+      image: amd64/rust
+    steps:
+      - uses: actions/checkout@v4
+      - name: Setup Rust & Java toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{env.RUST_VERSION}}
+          jdk-version: ${{ matrix.java_version }}
+      - name: Clone Spark
+        uses: actions/checkout@v4
+        with:
+          repository: "apache/spark"
+          path: "apache-spark"
+      - name: Install Spark
+        shell: bash
+        working-directory: ./apache-spark
+        run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+      - name: Java test steps
+        uses: ./.github/actions/java-test
+        with:
+          # TODO: remove -DskipTests after fixing tests
+          maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+          # TODO: upload test reports after enabling tests
+          upload-test-reports: false
+
   linux-test-with-old-spark:
     strategy:
       matrix:
@@ -169,6 +207,81 @@ jobs:
         with:
           maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ 
matrix.scala-version }}
 
+  macos-test-with-spark4_0:
+    strategy:
+      matrix:
+        os: [macos-13]
+        java_version: [17]
+        test-target: [java]
+        spark-version: ['4.0']
+      fail-fast: false
+    if: github.event_name == 'push'
+    name: ${{ matrix.os }}/java ${{ matrix.java_version 
}}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+      - name: Setup Rust & Java toolchain
+        uses: ./.github/actions/setup-macos-builder
+        with:
+          rust-version: ${{env.RUST_VERSION}}
+          jdk-version: ${{ matrix.java_version }}
+      - name: Clone Spark
+        uses: actions/checkout@v4
+        with:
+          repository: "apache/spark"
+          path: "apache-spark"
+      - name: Install Spark
+        shell: bash
+        working-directory: ./apache-spark
+        run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+      - name: Java test steps
+        uses: ./.github/actions/java-test
+        with:
+          # TODO: remove -DskipTests after fixing tests
+          maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+          # TODO: upload test reports after enabling tests
+          upload-test-reports: false
+
+  macos-aarch64-test-with-spark4_0:
+    strategy:
+      matrix:
+        java_version: [17]
+        test-target: [java]
+        spark-version: ['4.0']
+        is_push_event:
+          - ${{ github.event_name == 'push' }}
+        exclude: # exclude java 11 for pull_request event
+          - java_version: 11
+            is_push_event: false
+      fail-fast: false
+    name: macos-14(Silicon)/java ${{ matrix.java_version 
}}-spark-${{matrix.spark-version}}/${{ matrix.test-target }}
+    runs-on: macos-14
+    steps:
+      - uses: actions/checkout@v4
+      - name: Setup Rust & Java toolchain
+        uses: ./.github/actions/setup-macos-builder
+        with:
+          rust-version: ${{env.RUST_VERSION}}
+          jdk-version: ${{ matrix.java_version }}
+          jdk-architecture: aarch64
+          protoc-architecture: aarch_64
+      - name: Clone Spark
+        uses: actions/checkout@v4
+        with:
+          repository: "apache/spark"
+          path: "apache-spark"
+      - name: Install Spark
+        shell: bash
+        working-directory: ./apache-spark
+        run: build/mvn install -Phive -Phadoop-cloud -DskipTests
+      - name: Java test steps
+        uses: ./.github/actions/java-test
+        with:
+          # TODO: remove -DskipTests after fixing tests
+          maven_opts: "-Pspark-${{ matrix.spark-version }} -DskipTests"
+          # TODO: upload test reports after enabling tests
+          upload-test-reports: false
+
   macos-aarch64-test-with-old-spark:
     strategy:
       matrix:
diff --git 
a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala 
b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
index d851067b..d03252d0 100644
--- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
+++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
@@ -20,10 +20,10 @@
 package org.apache.comet.parquet
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.comet.shims.ShimCometParquetUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
 
-object CometParquetUtils {
+object CometParquetUtils extends ShimCometParquetUtils {
   private val PARQUET_FIELD_ID_WRITE_ENABLED = 
"spark.sql.parquet.fieldId.write.enabled"
   private val PARQUET_FIELD_ID_READ_ENABLED = 
"spark.sql.parquet.fieldId.read.enabled"
   private val IGNORE_MISSING_PARQUET_FIELD_ID = 
"spark.sql.parquet.fieldId.read.ignoreMissing"
@@ -39,61 +39,4 @@ object CometParquetUtils {
 
   def ignoreMissingIds(conf: SQLConf): Boolean =
     conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean
-
-  // The following is copied from QueryExecutionErrors
-  // TODO: remove after dropping Spark 3.2.0 support and directly use
-  //       QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError
-  def foundDuplicateFieldInFieldIdLookupModeError(
-      requiredId: Int,
-      matchedFields: String): Throwable = {
-    new RuntimeException(s"""
-         |Found duplicate field(s) "$requiredId": $matchedFields
-         |in id mapping mode
-     """.stripMargin.replaceAll("\n", " "))
-  }
-
-  // The followings are copied from 
org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
-  // TODO: remove after dropping Spark 3.2.0 support and directly use 
ParquetUtils
-  /**
-   * A StructField metadata key used to set the field id of a column in the 
Parquet schema.
-   */
-  val FIELD_ID_METADATA_KEY = "parquet.field.id"
-
-  /**
-   * Whether there exists a field in the schema, whether inner or leaf, has 
the parquet field ID
-   * metadata.
-   */
-  def hasFieldIds(schema: StructType): Boolean = {
-    def recursiveCheck(schema: DataType): Boolean = {
-      schema match {
-        case st: StructType =>
-          st.exists(field => hasFieldId(field) || 
recursiveCheck(field.dataType))
-
-        case at: ArrayType => recursiveCheck(at.elementType)
-
-        case mt: MapType => recursiveCheck(mt.keyType) || 
recursiveCheck(mt.valueType)
-
-        case _ =>
-          // No need to really check primitive types, just to terminate the 
recursion
-          false
-      }
-    }
-    if (schema.isEmpty) false else recursiveCheck(schema)
-  }
-
-  def hasFieldId(field: StructField): Boolean =
-    field.metadata.contains(FIELD_ID_METADATA_KEY)
-
-  def getFieldId(field: StructField): Int = {
-    require(
-      hasFieldId(field),
-      s"The key `$FIELD_ID_METADATA_KEY` doesn't exist in the metadata of " + 
field)
-    try {
-      Math.toIntExact(field.metadata.getLong(FIELD_ID_METADATA_KEY))
-    } catch {
-      case _: ArithmeticException | _: ClassCastException =>
-        throw new IllegalArgumentException(
-          s"The key `$FIELD_ID_METADATA_KEY` must be a 32-bit integer")
-    }
-  }
 }
diff --git 
a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala 
b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
similarity index 76%
copy from common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
copy to 
common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
index d851067b..f22ac406 100644
--- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala
+++ 
b/common/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
@@ -17,29 +17,11 @@
  * under the License.
  */
 
-package org.apache.comet.parquet
+package org.apache.spark.sql.comet.shims
 
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-object CometParquetUtils {
-  private val PARQUET_FIELD_ID_WRITE_ENABLED = 
"spark.sql.parquet.fieldId.write.enabled"
-  private val PARQUET_FIELD_ID_READ_ENABLED = 
"spark.sql.parquet.fieldId.read.enabled"
-  private val IGNORE_MISSING_PARQUET_FIELD_ID = 
"spark.sql.parquet.fieldId.read.ignoreMissing"
-
-  def writeFieldId(conf: SQLConf): Boolean =
-    conf.getConfString(PARQUET_FIELD_ID_WRITE_ENABLED, "false").toBoolean
-
-  def writeFieldId(conf: Configuration): Boolean =
-    conf.getBoolean(PARQUET_FIELD_ID_WRITE_ENABLED, false)
-
-  def readFieldId(conf: SQLConf): Boolean =
-    conf.getConfString(PARQUET_FIELD_ID_READ_ENABLED, "false").toBoolean
-
-  def ignoreMissingIds(conf: SQLConf): Boolean =
-    conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean
-
+trait ShimCometParquetUtils {
   // The following is copied from QueryExecutionErrors
   // TODO: remove after dropping Spark 3.2.0 support and directly use
   //       QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala
similarity index 64%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala
index 0c45a9c2..448d0886 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimBatchReader.scala
@@ -16,18 +16,22 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.paths.SparkPath
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.PartitionedFile
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+object ShimBatchReader {
+  def newPartitionedFile(partitionValues: InternalRow, file: String): 
PartitionedFile =
+    PartitionedFile(
+      partitionValues,
+      SparkPath.fromUrlString(file),
+      -1, // -1 means we read the entire file
+      -1,
+      Array.empty[String],
+      0,
+      0
+    )
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala
similarity index 64%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala
index 0c45a9c2..2f386869 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala
@@ -16,18 +16,16 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+object ShimFileFormat {
+  // A name for a temporary column that holds row indexes computed by the file 
format reader
+  // until they can be placed in the _metadata struct.
+  val ROW_INDEX_TEMPORARY_COLUMN_NAME = 
ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME
+
+  val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala
similarity index 69%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala
index 0c45a9c2..60e21765 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/common/src/main/spark-4.0/org/apache/comet/shims/ShimResolveDefaultColumns.scala
@@ -16,18 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.types.{StructField, StructType}
+
+object ShimResolveDefaultColumns {
+  def getExistenceDefaultValue(field: StructField): Any =
+    
ResolveDefaultColumns.getExistenceDefaultValues(StructType(Seq(field))).head
 }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala 
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
similarity index 54%
copy from spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
copy to 
common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
index ff60ef96..d402cd78 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
+++ 
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometParquetUtils.scala
@@ -17,26 +17,22 @@
  * under the License.
  */
 
-package org.apache.comet.shims
+package org.apache.spark.sql.comet.shims
 
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
+import org.apache.spark.sql.types._
 
-trait ShimSQLConf {
+trait ShimCometParquetUtils {
+  def foundDuplicateFieldInFieldIdLookupModeError(
+      requiredId: Int,
+      matchedFields: String): Throwable = {
+    
QueryExecutionErrors.foundDuplicateFieldInFieldIdLookupModeError(requiredId, 
matchedFields)
+  }
 
-  /**
-   * Spark 3.4 renamed parquetFilterPushDownStringStartWith to
-   * parquetFilterPushDownStringPredicate
-   *
-   * TODO: delete after dropping Spark 3.2 & 3.3 support and simply use
-   * parquetFilterPushDownStringPredicate
-   */
-  protected def getPushDownStringPredicate(sqlConf: SQLConf): Boolean =
-    sqlConf.getClass.getMethods
-      .flatMap(m =>
-        m.getName match {
-          case "parquetFilterPushDownStringStartWith" | 
"parquetFilterPushDownStringPredicate" =>
-            Some(m.invoke(sqlConf).asInstanceOf[Boolean])
-          case _ => None
-        })
-      .head
+  def hasFieldIds(schema: StructType): Boolean = ParquetUtils.hasFieldIds(schema)
+
+  def hasFieldId(field: StructField): Boolean = ParquetUtils.hasFieldId(field)
+
+  def getFieldId(field: StructField): Int = ParquetUtils.getFieldId(field)
 }
diff --git a/dev/ensure-jars-have-correct-contents.sh 
b/dev/ensure-jars-have-correct-contents.sh
index 1f97d2d4..12f555b8 100755
--- a/dev/ensure-jars-have-correct-contents.sh
+++ b/dev/ensure-jars-have-correct-contents.sh
@@ -40,9 +40,11 @@ allowed_expr="(^org/$|^org/apache/$"
 # we have to allow the directories that lead to the org/apache/comet dir
 # We allow all the classes under the following packages:
 #   * org.apache.comet
+#   * org.apache.spark.comet
 #   * org.apache.spark.sql.comet
 #   * org.apache.arrow.c
 allowed_expr+="|^org/apache/comet/"
+allowed_expr+="|^org/apache/spark/comet/"
 allowed_expr+="|^org/apache/spark/sql/comet/"
 allowed_expr+="|^org/apache/arrow/c/"
 #   * whatever in the "META-INF" directory
diff --git a/pom.xml b/pom.xml
index 59e0569f..57b4206c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -87,7 +87,7 @@ under the License.
     </extraJavaTestArgs>
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
     <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
-    <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
+    <additional.3_4.test.source>spark-3.4-plus</additional.3_4.test.source>
     <shims.majorVerSrc>spark-3.x</shims.majorVerSrc>
     <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
   </properties>
@@ -512,7 +512,6 @@ under the License.
         <spark.version>3.3.2</spark.version>
         <spark.version.short>3.3</spark.version.short>
         <parquet.version>1.12.0</parquet.version>
-        <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
         <additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
         <shims.minorVerSrc>spark-3.3</shims.minorVerSrc>
       </properties>
@@ -524,9 +523,25 @@ under the License.
         <scala.version>2.12.17</scala.version>
         <spark.version.short>3.4</spark.version.short>
         <parquet.version>1.13.1</parquet.version>
-        <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
-        <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
-        <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
+      </properties>
+    </profile>
+
+    <profile>
+      <!-- FIXME: this is WIP. Tests may fail -->
+      <id>spark-4.0</id>
+      <properties>
+        <!-- Use Scala 2.13 by default -->
+        <scala.version>2.13.13</scala.version>
+        <scala.binary.version>2.13</scala.binary.version>
+        <spark.version>4.0.0-SNAPSHOT</spark.version>
+        <spark.version.short>4.0</spark.version.short>
+        <parquet.version>1.13.1</parquet.version>
+        <shims.majorVerSrc>spark-4.0</shims.majorVerSrc>
+        <shims.minorVerSrc>not-needed-yet</shims.minorVerSrc>
+        <!-- Use jdk17 by default -->
+        <java.version>17</java.version>
+        <maven.compiler.source>${java.version}</maven.compiler.source>
+        <maven.compiler.target>${java.version}</maven.compiler.target>
       </properties>
     </profile>
 
@@ -605,7 +620,7 @@ under the License.
                   <compilerPlugin>
                     <groupId>org.scalameta</groupId>
                     <artifactId>semanticdb-scalac_${scala.version}</artifactId>
-                    <version>4.7.5</version>
+                    <version>4.8.8</version>
                   </compilerPlugin>
                 </compilerPlugins>
               </configuration>
diff --git a/spark/pom.xml b/spark/pom.xml
index 21fa09fc..84e2e501 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -252,6 +252,8 @@ under the License.
               <sources>
                 <source>src/test/${additional.3_3.test.source}</source>
                 <source>src/test/${additional.3_4.test.source}</source>
+                <source>src/test/${shims.majorVerSrc}</source>
+                <source>src/test/${shims.minorVerSrc}</source>
               </sources>
             </configuration>
           </execution>
diff --git 
a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala 
b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
index ac871cf6..52d8d09a 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/CometParquetFileFormat.scala
@@ -37,7 +37,6 @@ import 
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.datasources.parquet.ParquetReadSupport
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.{DateType, StructType, TimestampType}
 import org.apache.spark.util.SerializableConfiguration
@@ -144,7 +143,7 @@ class CometParquetFileFormat extends ParquetFileFormat with 
MetricsSupport with
         isCaseSensitive,
         useFieldId,
         ignoreMissingIds,
-        datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+        datetimeRebaseSpec.mode == CORRECTED,
         partitionSchema,
         file.partitionValues,
         JavaConverters.mapAsJavaMap(metrics))
@@ -161,7 +160,7 @@ class CometParquetFileFormat extends ParquetFileFormat with 
MetricsSupport with
   }
 }
 
-object CometParquetFileFormat extends Logging {
+object CometParquetFileFormat extends Logging with ShimSQLConf {
 
   /**
    * Populates Parquet related configurations from the input `sqlConf` to the 
`hadoopConf`
@@ -210,7 +209,7 @@ object CometParquetFileFormat extends Logging {
         case _ => false
       })
 
-    if (hasDateOrTimestamp && datetimeRebaseSpec.mode == 
LegacyBehaviorPolicy.LEGACY) {
+    if (hasDateOrTimestamp && datetimeRebaseSpec.mode == LEGACY) {
       if (exceptionOnRebase) {
         logWarning(
           s"""Found Parquet file $file that could potentially contain 
dates/timestamps that were
@@ -222,7 +221,7 @@ object CometParquetFileFormat extends Logging {
               calendar, please disable Comet for this query.""")
       } else {
         // do not throw exception on rebase - read as it is
-        datetimeRebaseSpec = 
datetimeRebaseSpec.copy(LegacyBehaviorPolicy.CORRECTED)
+        datetimeRebaseSpec = datetimeRebaseSpec.copy(CORRECTED)
       }
     }
 
diff --git 
a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
 
b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
index 693af125..e48d7638 100644
--- 
a/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
+++ 
b/spark/src/main/scala/org/apache/comet/parquet/CometParquetPartitionReaderFactory.scala
@@ -37,7 +37,6 @@ import 
org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -135,7 +134,7 @@ case class CometParquetPartitionReaderFactory(
         isCaseSensitive,
         useFieldId,
         ignoreMissingIds,
-        datetimeRebaseSpec.mode == LegacyBehaviorPolicy.CORRECTED,
+        datetimeRebaseSpec.mode == CORRECTED,
         partitionSchema,
         file.partitionValues,
         JavaConverters.mapAsJavaMap(metrics))
diff --git a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala 
b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
index 5994dfb4..58c2aeb4 100644
--- a/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
+++ b/spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala
@@ -38,11 +38,11 @@ import 
org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
 import org.apache.parquet.schema.Type.Repetition
 import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, CaseInsensitiveMap, 
DateTimeUtils, IntervalUtils}
 import 
org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, 
rebaseGregorianToJulianMicros, RebaseSpec}
-import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.sources
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.shims.ShimSQLConf
 
 /**
  * Copied from Spark 3.2 & 3.4, in order to fix Parquet shading issue. TODO: 
find a way to remove
@@ -58,7 +58,8 @@ class ParquetFilters(
     pushDownStringPredicate: Boolean,
     pushDownInFilterThreshold: Int,
     caseSensitive: Boolean,
-    datetimeRebaseSpec: RebaseSpec) {
+    datetimeRebaseSpec: RebaseSpec)
+    extends ShimSQLConf {
   // A map which contains parquet field name and data type, if predicate push 
down applies.
   //
   // Each key in `nameToParquetField` represents a column; `dots` are used as 
separators for
@@ -153,7 +154,7 @@ class ParquetFilters(
       case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
     }
     datetimeRebaseSpec.mode match {
-      case LegacyBehaviorPolicy.LEGACY => 
rebaseGregorianToJulianDays(gregorianDays)
+      case LEGACY => rebaseGregorianToJulianDays(gregorianDays)
       case _ => gregorianDays
     }
   }
@@ -164,7 +165,7 @@ class ParquetFilters(
       case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
     }
     datetimeRebaseSpec.mode match {
-      case LegacyBehaviorPolicy.LEGACY =>
+      case LEGACY =>
         rebaseGregorianToJulianMicros(datetimeRebaseSpec.timeZone, 
gregorianMicros)
       case _ => gregorianMicros
     }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6333650d..a717e066 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1987,18 +1987,17 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
         // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called 
to pad spaces for
         // char types. Use rpad to achieve the behavior.
         // See https://github.com/apache/spark/pull/38151
-        case StaticInvoke(
-              _: Class[CharVarcharCodegenUtils],
-              _: StringType,
-              "readSidePadding",
-              arguments,
-              _,
-              true,
-              false,
-              true) if arguments.size == 2 =>
+        case s: StaticInvoke
+            if s.staticObject == classOf[CharVarcharCodegenUtils] && // erasure-safe class check
+              s.dataType.isInstanceOf[StringType] &&
+              s.functionName == "readSidePadding" &&
+              s.arguments.size == 2 &&
+              s.propagateNull &&
+              !s.returnNullable &&
+              s.isDeterministic =>
           val argsExpr = Seq(
-            exprToProtoInternal(Cast(arguments(0), StringType), inputs),
-            exprToProtoInternal(arguments(1), inputs))
+            exprToProtoInternal(Cast(s.arguments(0), StringType), inputs),
+            exprToProtoInternal(s.arguments(1), inputs))
 
           if (argsExpr.forall(_.isDefined)) {
             val builder = ExprOuterClass.ScalarFunc.newBuilder()
@@ -2007,7 +2006,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
             Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
           } else {
-            withInfo(expr, arguments: _*)
+            withInfo(expr, s.arguments: _*)
             None
           }
 
diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index 97838448..dcc00f66 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -23,9 +23,9 @@ import java.{util => ju}
 import java.util.Collections
 
 import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}
+import org.apache.spark.comet.shims.ShimCometDriverPlugin
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
-import org.apache.spark.resource.ResourceProfile
 
 import org.apache.comet.{CometConf, CometSparkSessionExtensions}
 
@@ -40,9 +40,7 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions}
  *
  * To enable this plugin, set the config "spark.plugins" to `org.apache.spark.CometPlugin`.
  */
-class CometDriverPlugin extends DriverPlugin with Logging {
-  import CometDriverPlugin._
-
+class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin {
   override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = {
     logInfo("CometDriverPlugin init")
 
@@ -52,14 +50,10 @@ class CometDriverPlugin extends DriverPlugin with Logging {
       } else {
         // By default, executorMemory * spark.executor.memoryOverheadFactor, with minimum of 384MB
         val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key)
-        val memoryOverheadFactor =
-          sc.getConf.getDouble(
-            EXECUTOR_MEMORY_OVERHEAD_FACTOR,
-            EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT)
-
-        Math.max(
-          (executorMemory * memoryOverheadFactor).toInt,
-          ResourceProfile.MEMORY_OVERHEAD_MIN_MIB)
+        val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf)
+        val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf)
+
+        Math.max((executorMemory * memoryOverheadFactor).toLong, memoryOverheadMinMib)
       }
 
       val cometMemOverhead = CometSparkSessionExtensions.getCometMemoryOverheadInMiB(sc.getConf)
@@ -100,12 +94,6 @@ class CometDriverPlugin extends DriverPlugin with Logging {
   }
 }
 
-object CometDriverPlugin {
-  // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added since Spark 3.3.0
-  val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor"
-  val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1
-}
-
 /**
  * The Comet plugin for Spark. To enable this plugin, set the config "spark.plugins" to
  * `org.apache.spark.CometPlugin`
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 7bd34deb..38247b2c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -27,6 +27,7 @@ import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
 import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -44,7 +45,6 @@ import org.apache.spark.util.io.ChunkedByteBuffer
 import com.google.common.base.Objects
 
 import org.apache.comet.CometRuntimeException
-import org.apache.comet.shims.ShimCometBroadcastExchangeExec
 
 /**
  * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index 14a66410..9a5b55d6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.comet.shims.ShimCometScanExec
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -43,7 +44,7 @@ import org.apache.spark.util.collection._
 
 import org.apache.comet.{CometConf, MetricsSupport}
 import org.apache.comet.parquet.{CometParquetFileFormat, CometParquetPartitionReaderFactory}
-import org.apache.comet.shims.{ShimCometScanExec, ShimFileFormat}
+import org.apache.comet.shims.ShimFileFormat
 
 /**
  * Comet physical scan node for DataSource V1. Most of the code here follow Spark's
@@ -271,7 +272,7 @@ case class CometScanExec(
       selectedPartitions
         .flatMap { p =>
           p.files.map { f =>
-            PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
+            getPartitionedFile(f, p)
           }
         }
         .groupBy { f =>
@@ -358,7 +359,7 @@ case class CometScanExec(
               // SPARK-39634: Allow file splitting in combination with row index generation once
               // the fix for PARQUET-2161 is available.
               !isNeededForSchema(requiredSchema)
-            PartitionedFileUtil.splitFiles(
+            super.splitFiles(
               sparkSession = relation.sparkSession,
               file = file,
               filePath = filePath,
@@ -409,7 +410,7 @@ case class CometScanExec(
         Map.empty)
     } else {
       newFileScanRDD(
-        fsRelation.sparkSession,
+        fsRelation,
         readFile,
         partitions,
         new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields),
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
index 13f26ce5..a2cdf421 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala
@@ -59,7 +59,7 @@ object DecimalPrecision {
         }
         CheckOverflow(add, resultType, nullOnOverflow)
 
-      case sub @ Subtract(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+      case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
         val resultScale = max(s1, s2)
         val resultType = if (allowPrecisionLoss) {
           DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
@@ -68,7 +68,7 @@ object DecimalPrecision {
         }
         CheckOverflow(sub, resultType, nullOnOverflow)
 
-      case mul @ Multiply(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+      case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
         val resultType = if (allowPrecisionLoss) {
           DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
         } else {
@@ -76,7 +76,7 @@ object DecimalPrecision {
         }
         CheckOverflow(mul, resultType, nullOnOverflow)
 
-      case div @ Divide(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+      case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
         val resultType = if (allowPrecisionLoss) {
           // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
           // Scale: max(6, s1 + p2 + 1)
@@ -96,7 +96,7 @@ object DecimalPrecision {
         }
         CheckOverflow(div, resultType, nullOnOverflow)
 
-      case rem @ Remainder(DecimalType.Expression(p1, s1), DecimalType.Expression(p2, s2), _) =>
+      case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
         val resultType = if (allowPrecisionLoss) {
           DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
         } else {
@@ -108,6 +108,7 @@ object DecimalPrecision {
     }
   }
 
+  // TODO: consider to use `org.apache.spark.sql.types.DecimalExpression` for Spark 3.5+
   object DecimalExpression {
     def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
       case t: DecimalType => Some((t.precision, t.scale))
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 49c263f3..3f4d7bfd 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -28,10 +28,10 @@ import scala.concurrent.Future
 
 import org.apache.spark._
 import org.apache.spark.internal.config
-import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsReporter}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
@@ -39,12 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
+import org.apache.spark.sql.comet.shims.ShimCometShuffleWriteProcessor
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.MutablePair
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -68,7 +68,8 @@ case class CometShuffleExchangeExec(
     shuffleType: ShuffleType = CometNativeShuffle,
     advisoryPartitionSize: Option[Long] = None)
     extends ShuffleExchangeLike
-    with CometPlan {
+    with CometPlan
+    with ShimCometShuffleExchangeExec {
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
@@ -127,6 +128,9 @@ case class CometShuffleExchangeExec(
     Statistics(dataSize, Some(rowCount))
   }
 
+  // TODO: add `override` keyword after dropping Spark-3.x supports
+  def shuffleId: Int = getShuffleId(shuffleDependency)
+
   /**
    * A [[ShuffleDependency]] that will partition rows of its child based on the partitioning
    * scheme defined in `newPartitioning`. Those partitions of the returned ShuffleDependency will
@@ -386,7 +390,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
           val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
 
           val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
-            StructType.fromAttributes(outputAttributes),
+            fromAttributes(outputAttributes),
             recordComparatorSupplier,
             prefixComparator,
             prefixComputer,
@@ -430,7 +434,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
         serializer,
         shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
         shuffleType = CometColumnarShuffle,
-        schema = Some(StructType.fromAttributes(outputAttributes)))
+        schema = Some(fromAttributes(outputAttributes)))
 
     dependency
   }
@@ -445,7 +449,7 @@ class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
     metrics: Map[String, SQLMetric])
-    extends ShuffleWriteProcessor {
+    extends ShimCometShuffleWriteProcessor {
 
   private val OFFSET_LENGTH = 8
 
@@ -455,11 +459,11 @@ class CometShuffleWriteProcessor(
   }
 
   override def write(
-      rdd: RDD[_],
+      inputs: Iterator[_],
       dep: ShuffleDependency[_, _, _],
       mapId: Long,
-      context: TaskContext,
-      partition: Partition): MapStatus = {
+      mapIndex: Int,
+      context: TaskContext): MapStatus = {
     val shuffleBlockResolver =
       SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
     val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
@@ -469,10 +473,6 @@ class CometShuffleWriteProcessor(
     val tempDataFilePath = Paths.get(tempDataFilename)
     val tempIndexFilePath = Paths.get(tempIndexFilename)
 
-    // Getting rid of the fake partitionId
-    val cometRDD =
-      rdd.asInstanceOf[MapPartitionsRDD[_, _]].prev.asInstanceOf[RDD[ColumnarBatch]]
-
     // Call native shuffle write
     val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename)
 
@@ -482,8 +482,12 @@ class CometShuffleWriteProcessor(
       "elapsed_compute" -> metrics("shuffleReadElapsedCompute"))
     val nativeMetrics = CometMetricNode(nativeSQLMetrics)
 
-    val rawIter = cometRDD.iterator(partition, context)
-    val cometIter = CometExec.getCometIterator(Seq(rawIter), nativePlan, nativeMetrics)
+    // Getting rid of the fake partitionId
+    val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)
+    val cometIter = CometExec.getCometIterator(
+      Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
+      nativePlan,
+      nativeMetrics)
 
     while (cometIter.hasNext) {
       cometIter.next()
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
index 0c45a9c2..f5a578f8 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 trait CometExprShim {
     /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
+     * Returns a tuple of expressions for the `unhex` function.
+     */
     def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(false))
     }
diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
index 0c45a9c2..f5a578f8 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 trait CometExprShim {
     /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
+     * Returns a tuple of expressions for the `unhex` function.
+     */
     def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(false))
     }
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index 409e1c94..3f2301f0 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 trait CometExprShim {
     /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
+     * Returns a tuple of expressions for the `unhex` function.
+     */
     def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(unhex.failOnError))
     }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index 6b4fad97..350aeb9f 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -19,8 +19,11 @@
 
 package org.apache.comet.shims
 
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types.{StructField, StructType}
 
 trait ShimCometShuffleExchangeExec {
   // TODO: remove after dropping Spark 3.2 and 3.3 support
@@ -37,4 +40,11 @@ trait ShimCometShuffleExchangeExec {
       shuffleType,
       advisoryPartitionSize)
   }
+
+  // TODO: remove after dropping Spark 3.x support
+  protected def fromAttributes(attributes: Seq[Attribute]): StructType =
+    StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+  // TODO: remove after dropping Spark 3.x support
+  protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): Int = 0
 }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index eb04c68a..37748533 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -24,8 +24,6 @@ import org.apache.spark.sql.execution.{LimitExec, QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 
 trait ShimCometSparkSessionExtensions {
-  import org.apache.comet.shims.ShimCometSparkSessionExtensions._
-
   /**
    * TODO: delete after dropping Spark 3.2.0 support and directly call scan.pushedAggregate
    */
@@ -45,9 +43,7 @@ trait ShimCometSparkSessionExtensions {
    *       SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
    */
   protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = "spark.sql.extendedExplainProviders"
-}
 
-object ShimCometSparkSessionExtensions {
   private def getOffsetOpt(plan: SparkPlan): Option[Int] = plan.getClass.getDeclaredFields
     .filter(_.getName == "offset")
     .map { a => a.setAccessible(true); a.get(plan) }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
index ff60ef96..c3d0c56e 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
@@ -20,6 +20,7 @@
 package org.apache.comet.shims
 
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 
 trait ShimSQLConf {
 
@@ -39,4 +40,7 @@ trait ShimSQLConf {
           case _ => None
         })
       .head
+
+  protected val LEGACY = LegacyBehaviorPolicy.LEGACY
+  protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED
 }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
similarity index 98%
rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
rename to spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
index 63ff2a2c..aede4795 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.comet.shims
+package org.apache.spark.comet.shims
 
 import scala.reflect.ClassTag
 
diff --git a/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
new file mode 100644
index 00000000..cfb6a008
--- /dev/null
+++ b/spark/src/main/spark-3.x/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet.shims
+
+import org.apache.spark.SparkConf
+
+trait ShimCometDriverPlugin {
+  // `org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR` was added since Spark 3.3.0
+  private val EXECUTOR_MEMORY_OVERHEAD_FACTOR = "spark.executor.memoryOverheadFactor"
+  private val EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT = 0.1
+  // `org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD` was added since Spark 4.0.0
+  private val EXECUTOR_MIN_MEMORY_OVERHEAD = "spark.executor.minMemoryOverhead"
+  private val EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT = 384L
+
+  def getMemoryOverheadFactor(sc: SparkConf): Double =
+    sc.getDouble(
+      EXECUTOR_MEMORY_OVERHEAD_FACTOR,
+      EXECUTOR_MEMORY_OVERHEAD_FACTOR_DEFAULT)
+  def getMemoryOverheadMinMib(sc: SparkConf): Long =
+    sc.getLong(EXECUTOR_MIN_MEMORY_OVERHEAD, EXECUTOR_MIN_MEMORY_OVERHEAD_DEFAULT)
+}
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
similarity index 81%
rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala
rename to spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
index 544a6738..02b97f9f 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometScanExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -17,17 +17,21 @@
  * under the License.
  */
 
-package org.apache.comet.shims
+package org.apache.spark.sql.comet.shims
+
+import org.apache.comet.shims.ShimFileFormat
 
 import scala.language.implicitConversions
 
+import org.apache.hadoop.fs.{FileStatus, Path}
+
 import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile}
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
 import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -63,7 +67,7 @@ trait ShimCometScanExec {
 
   // TODO: remove after dropping Spark 3.2 support and directly call new FileScanRDD
   protected def newFileScanRDD(
-      sparkSession: SparkSession,
+      fsRelation: HadoopFsRelation,
       readFunction: PartitionedFile => Iterator[InternalRow],
       filePartitions: Seq[FilePartition],
       readSchema: StructType,
@@ -73,12 +77,12 @@ trait ShimCometScanExec {
       .filter(c => List(3, 5, 6).contains(c.getParameterCount()) )
       .map { c =>
         c.getParameterCount match {
-          case 3 => c.newInstance(sparkSession, readFunction, filePartitions)
+          case 3 => c.newInstance(fsRelation.sparkSession, readFunction, filePartitions)
           case 5 =>
-            c.newInstance(sparkSession, readFunction, filePartitions, readSchema, metadataColumns)
+            c.newInstance(fsRelation.sparkSession, readFunction, filePartitions, readSchema, metadataColumns)
           case 6 =>
             c.newInstance(
-              sparkSession,
+              fsRelation.sparkSession,
               readFunction,
               filePartitions,
               readSchema,
@@ -123,4 +127,15 @@ trait ShimCometScanExec {
   protected def isNeededForSchema(sparkSchema: StructType): Boolean = {
     findRowIndexColumnIndexInSchema(sparkSchema) >= 0
   }
+
+  protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile =
+    PartitionedFileUtil.getPartitionedFile(f, f.getPath, p.values)
+
+  protected def splitFiles(sparkSession: SparkSession,
+                           file: FileStatus,
+                           filePath: Path,
+                           isSplitable: Boolean,
+                           maxSplitBytes: Long,
+                           partitionValues: InternalRow): Seq[PartitionedFile] =
+    PartitionedFileUtil.splitFiles(sparkSession, file, filePath, isSplitable, maxSplitBytes, partitionValues)
 }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
similarity index 52%
copy from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
copy to spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
index 6b4fad97..9100b90c 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
@@ -17,24 +17,28 @@
  * under the License.
  */
 
-package org.apache.comet.shims
+package org.apache.spark.sql.comet.shims
 
-import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, ShuffleType}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.{Partition, ShuffleDependency, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.ShuffleWriteProcessor
 
-trait ShimCometShuffleExchangeExec {
-  // TODO: remove after dropping Spark 3.2 and 3.3 support
-  def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): CometShuffleExchangeExec = {
-    val advisoryPartitionSize = s.getClass.getDeclaredMethods
-      .filter(_.getName == "advisoryPartitionSize")
-      .flatMap(_.invoke(s).asInstanceOf[Option[Long]])
-      .headOption
-    CometShuffleExchangeExec(
-      s.outputPartitioning,
-      s.child,
-      s,
-      s.shuffleOrigin,
-      shuffleType,
-      advisoryPartitionSize)
+trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor {
+  override def write(
+      rdd: RDD[_],
+      dep: ShuffleDependency[_, _, _],
+      mapId: Long,
+      context: TaskContext,
+      partition: Partition): MapStatus = {
+    val rawIter = rdd.iterator(partition, context)
+    write(rawIter, dep, mapId, partition.index, context)
   }
+
+  def write(
+    inputs: Iterator[_],
+    dep: ShuffleDependency[_, _, _],
+    mapId: Long,
+    mapIndex: Int,
+    context: TaskContext): MapStatus
 }
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
similarity index 88%
copy from spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index 409e1c94..01f92320 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 trait CometExprShim {
     /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+     * Returns a tuple of expressions for the `unhex` function.
+     */
+    protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(unhex.failOnError))
     }
 }
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala
similarity index 63%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala
index 0c45a9c2..167b539f 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBatchScanExec.scala
@@ -16,18 +16,19 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimCometBatchScanExec {
+  def wrapped: BatchScanExec
+
+  def keyGroupedPartitioning: Option[Seq[Expression]] = wrapped.keyGroupedPartitioning
+
+  def inputPartitions: Seq[InputPartition] = wrapped.inputPartitions
+
+  def ordering: Option[Seq[SortOrder]] = wrapped.ordering
 }
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
similarity index 68%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
index 0c45a9c2..1f689b40 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
@@ -16,18 +16,16 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, 
Partitioning}
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
+trait ShimCometBroadcastHashJoinExec {
+  protected def getHashPartitioningLikeExpressions(partitioning: 
Partitioning): Seq[Expression] =
+    partitioning match {
+      case p: HashPartitioningLike => p.expressions
+      case _ => Seq()
     }
 }
diff --git 
a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
similarity index 73%
copy from 
spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
copy to 
spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index 6b4fad97..559e327b 100644
--- 
a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -19,22 +19,25 @@
 
 package org.apache.comet.shims
 
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleExchangeExec, 
ShuffleType}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.types.StructType
 
 trait ShimCometShuffleExchangeExec {
-  // TODO: remove after dropping Spark 3.2 and 3.3 support
   def apply(s: ShuffleExchangeExec, shuffleType: ShuffleType): 
CometShuffleExchangeExec = {
-    val advisoryPartitionSize = s.getClass.getDeclaredMethods
-      .filter(_.getName == "advisoryPartitionSize")
-      .flatMap(_.invoke(s).asInstanceOf[Option[Long]])
-      .headOption
     CometShuffleExchangeExec(
       s.outputPartitioning,
       s.child,
       s,
       s.shuffleOrigin,
       shuffleType,
-      advisoryPartitionSize)
+      s.advisoryPartitionSize)
   }
+
+  protected def fromAttributes(attributes: Seq[Attribute]): StructType = 
DataTypeUtils.fromAttributes(attributes)
+
+  protected def getShuffleId(shuffleDependency: ShuffleDependency[Int, _, _]): 
Int = shuffleDependency.shuffleId
 }
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
similarity index 54%
copy from spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
copy to 
spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
index ff60ef96..9fb7355e 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimSQLConf.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
@@ -19,24 +19,19 @@
 
 package org.apache.comet.shims
 
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.execution.{CollectLimitExec, GlobalLimitExec, 
LocalLimitExec, QueryExecution}
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.internal.SQLConf
 
-trait ShimSQLConf {
+trait ShimCometSparkSessionExtensions {
+  protected def getPushedAggregate(scan: ParquetScan): Option[Aggregation] = 
scan.pushedAggregate
 
-  /**
-   * Spark 3.4 renamed parquetFilterPushDownStringStartWith to
-   * parquetFilterPushDownStringPredicate
-   *
-   * TODO: delete after dropping Spark 3.2 & 3.3 support and simply use
-   * parquetFilterPushDownStringPredicate
-   */
-  protected def getPushDownStringPredicate(sqlConf: SQLConf): Boolean =
-    sqlConf.getClass.getMethods
-      .flatMap(m =>
-        m.getName match {
-          case "parquetFilterPushDownStringStartWith" | 
"parquetFilterPushDownStringPredicate" =>
-            Some(m.invoke(sqlConf).asInstanceOf[Boolean])
-          case _ => None
-        })
-      .head
+  protected def getOffset(limit: LocalLimitExec): Int = 0
+  protected def getOffset(limit: GlobalLimitExec): Int = limit.offset
+  protected def getOffset(limit: CollectLimitExec): Int = limit.offset
+
+  protected def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = true
+
+  protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = 
SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala
similarity index 69%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala
index 0c45a9c2..5a8ac97b 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometTakeOrderedAndProjectExec.scala
@@ -16,18 +16,11 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.TakeOrderedAndProjectExec
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimCometTakeOrderedAndProjectExec {
+  protected def getOffset(plan: TakeOrderedAndProjectExec): Option[Int] = 
Some(plan.offset)
 }
diff --git 
a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
new file mode 100644
index 00000000..4d261f3c
--- /dev/null
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, 
BinaryExpression, BloomFilterMightContain, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+
+trait ShimQueryPlanSerde {
+  protected def getFailOnError(b: BinaryArithmetic): Boolean =
+    b.getClass.getMethod("failOnError").invoke(b).asInstanceOf[Boolean]
+
+  protected def getFailOnError(aggregate: Sum): Boolean = 
aggregate.initQueryContext().isDefined
+  protected def getFailOnError(aggregate: Average): Boolean = 
aggregate.initQueryContext().isDefined
+
+  protected def isLegacyMode(aggregate: Sum): Boolean = 
aggregate.evalMode.equals(EvalMode.LEGACY)
+  protected def isLegacyMode(aggregate: Average): Boolean = 
aggregate.evalMode.equals(EvalMode.LEGACY)
+
+  protected def isBloomFilterMightContain(binary: BinaryExpression): Boolean =
+    binary.isInstanceOf[BloomFilterMightContain]
+}
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
similarity index 69%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
index 0c45a9c2..57496776 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
@@ -16,18 +16,16 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.LegacyBehaviorPolicy
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimSQLConf {
+  protected def getPushDownStringPredicate(sqlConf: SQLConf): Boolean =
+    sqlConf.parquetFilterPushDownStringPredicate
+
+  protected val LEGACY = LegacyBehaviorPolicy.LEGACY
+  protected val CORRECTED = LegacyBehaviorPolicy.CORRECTED
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
similarity index 67%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
index 0c45a9c2..ba87a251 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometBroadcastExchangeExec.scala
@@ -16,18 +16,15 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.comet.shims
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkContext
+import org.apache.spark.broadcast.Broadcast
+
+trait ShimCometBroadcastExchangeExec {
+  protected def doBroadcast[T: ClassTag](sparkContext: SparkContext, value: 
T): Broadcast[Any] =
+    sparkContext.broadcastInternal(value, true)
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
similarity index 62%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
index 0c45a9c2..f7a57a64 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
@@ -16,18 +16,17 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.comet.shims
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.EXECUTOR_MEMORY_OVERHEAD_FACTOR
+import org.apache.spark.internal.config.EXECUTOR_MIN_MEMORY_OVERHEAD
+
+trait ShimCometDriverPlugin {
+  protected def getMemoryOverheadFactor(sparkConf: SparkConf): Double = 
sparkConf.get(
+    EXECUTOR_MEMORY_OVERHEAD_FACTOR)
+
+  protected def getMemoryOverheadMinMib(sparkConf: SparkConf): Long = 
sparkConf.get(
+    EXECUTOR_MIN_MEMORY_OVERHEAD)
 }
diff --git 
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
new file mode 100644
index 00000000..543116c1
--- /dev/null
+++ 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
+import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.SparkContext
+
+trait ShimCometScanExec {
+  def wrapped: FileSourceScanExec
+
+  lazy val fileConstantMetadataColumns: Seq[AttributeReference] =
+    wrapped.fileConstantMetadataColumns
+
+  protected def newDataSourceRDD(
+      sc: SparkContext,
+      inputPartitions: Seq[Seq[InputPartition]],
+      partitionReaderFactory: PartitionReaderFactory,
+      columnarReads: Boolean,
+      customMetrics: Map[String, SQLMetric]): DataSourceRDD =
+    new DataSourceRDD(sc, inputPartitions, partitionReaderFactory, 
columnarReads, customMetrics)
+
+  protected def newFileScanRDD(
+      fsRelation: HadoopFsRelation,
+      readFunction: PartitionedFile => Iterator[InternalRow],
+      filePartitions: Seq[FilePartition],
+      readSchema: StructType,
+      options: ParquetOptions): FileScanRDD = {
+    new FileScanRDD(
+      fsRelation.sparkSession,
+      readFunction,
+      filePartitions,
+      readSchema,
+      fileConstantMetadataColumns,
+      fsRelation.fileFormat.fileConstantMetadataExtractors,
+      options)
+  }
+
+  protected def invalidBucketFile(path: String, sparkVersion: String): 
Throwable =
+    QueryExecutionErrors.invalidBucketFile(path)
+
+  // see SPARK-39634
+  protected def isNeededForSchema(sparkSchema: StructType): Boolean = false
+
+  protected def getPartitionedFile(f: FileStatusWithMetadata, p: 
PartitionDirectory): PartitionedFile =
+    PartitionedFileUtil.getPartitionedFile(f, p.values, 0, f.getLen)
+
+  protected def splitFiles(sparkSession: SparkSession,
+                           file: FileStatusWithMetadata,
+                           filePath: Path,
+                           isSplitable: Boolean,
+                           maxSplitBytes: Long,
+                           partitionValues: InternalRow): Seq[PartitionedFile] 
=
+    PartitionedFileUtil.splitFiles(file, isSplitable, maxSplitBytes, 
partitionValues)
+}
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
similarity index 67%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
index 0c45a9c2..f875e3f3 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
@@ -16,18 +16,11 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.shuffle.ShuffleWriteProcessor
+
+trait ShimCometShuffleWriteProcessor extends ShuffleWriteProcessor {
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
 }
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
index 1357d654..53186b13 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
@@ -21,13 +21,13 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, 
MEMORY_OFFHEAP_SIZE}
+import org.apache.spark.sql.comet.shims.ShimCometTPCDSQuerySuite
 
 import org.apache.comet.CometConf
 
 class CometTPCDSQuerySuite
     extends {
-      // This is private in `TPCDSBase`.
-      val excludedTpcdsQueries: Seq[String] = Seq()
+      override val excludedTpcdsQueries: Set[String] = Set()
 
       // This is private in `TPCDSBase` and `excludedTpcdsQueries` is private 
too.
       // So we cannot override `excludedTpcdsQueries` to exclude the queries.
@@ -145,7 +145,8 @@ class CometTPCDSQuerySuite
       override val tpcdsQueries: Seq[String] =
         tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
     }
-    with TPCDSQueryTestSuite {
+    with TPCDSQueryTestSuite
+    with ShimCometTPCDSQuerySuite {
   override def sparkConf: SparkConf = {
     val conf = super.sparkConf
     conf.set("spark.sql.extensions", 
"org.apache.comet.CometSparkSessionExtensions")
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 1abe5fae..ec87f19e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.test.{SharedSparkSession, 
TestSparkSession}
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.shims.ShimCometTPCHQuerySuite
 
 /**
  * End-to-end tests to check TPCH query results.
@@ -49,7 +50,7 @@ import 
org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
  *     ./mvnw -Dsuites=org.apache.spark.sql.CometTPCHQuerySuite test
  * }}}
  */
-class CometTPCHQuerySuite extends QueryTest with CometTPCBase with 
SQLQueryTestHelper {
+class CometTPCHQuerySuite extends QueryTest with CometTPCBase with 
ShimCometTPCHQuerySuite {
 
   private val tpchDataPath = sys.env.get("SPARK_TPCH_DATA")
 
@@ -142,7 +143,7 @@ class CometTPCHQuerySuite extends QueryTest with 
CometTPCBase with SQLQueryTestH
     val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins
     withSQLConf(conf.toSeq: _*) {
       try {
-        val (schema, output) = handleExceptions(getNormalizedResult(spark, 
query))
+        val (schema, output) = 
handleExceptions(getNormalizedQueryExecutionResult(spark, query))
         val queryString = query.trim
         val outputString = output.mkString("\n").replaceAll("\\s+$", "")
         if (shouldRegenerateGoldenFiles) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 0530d764..d8c82f12 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -48,7 +48,6 @@ import org.apache.spark.sql.types.StructType
 import org.apache.comet._
 import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 import org.apache.comet.shims.ShimCometSparkSessionExtensions
-import 
org.apache.comet.shims.ShimCometSparkSessionExtensions.supportsExtendedExplainInfo
 
 /**
  * Base class for testing. This exists in `org.apache.spark.sql` since 
[[SQLTestUtils]] is
diff --git 
a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala 
b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
index 4c2f832a..fc454944 100644
--- 
a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
+++ 
b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala
@@ -25,11 +25,12 @@ import scala.collection.JavaConverters._
 import scala.util.Random
 
 import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.comet.shims.ShimTestUtils
 import 
org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnVector
 
-import org.apache.comet.{CometConf, TestUtils}
+import org.apache.comet.CometConf
 import org.apache.comet.parquet.BatchReader
 
 /**
@@ -123,7 +124,7 @@ object CometReadBenchmark extends CometBenchmarkBase {
             (col: ColumnVector, i: Int) => longSum += 
col.getUTF8String(i).toLongExact
         }
 
-        val files = TestUtils.listDirectory(new File(dir, "parquetV1"))
+        val files = ShimTestUtils.listDirectory(new File(dir, "parquetV1"))
 
         sqlBenchmark.addCase("ParquetReader Spark") { _ =>
           files.map(_.asInstanceOf[String]).foreach { p =>
diff --git 
a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala 
b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
similarity index 98%
rename from 
spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala
rename to 
spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
index 019b4f03..31d1ffbf 100644
--- a/spark/src/test/spark-3.4/org/apache/comet/exec/CometExec3_4Suite.scala
+++ 
b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
@@ -27,9 +27,9 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.comet.CometConf
 
 /**
- * This test suite contains tests for only Spark 3.4.
+ * This test suite contains tests for only Spark 3.4+.
  */
-class CometExec3_4Suite extends CometTestBase {
+class CometExec3_4PlusSuite extends CometTestBase {
   import testImplicits._
 
   override protected def test(testName: String, testTags: Tag*)(testFun: => 
Any)(implicit
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
similarity index 69%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
index 0c45a9c2..caa943c2 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/test/spark-3.x/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
@@ -16,18 +16,13 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.{SQLQueryTestHelper, SparkSession}
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper {
+  protected def getNormalizedQueryExecutionResult(session: SparkSession, sql: 
String): (String, Seq[String]) = {
+    getNormalizedResult(session, sql)
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/TestUtils.scala 
b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
similarity index 96%
rename from spark/src/test/scala/org/apache/comet/TestUtils.scala
rename to 
spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
index d4e77156..fcb543f9 100644
--- a/spark/src/test/scala/org/apache/comet/TestUtils.scala
+++ b/spark/src/test/spark-3.x/org/apache/spark/comet/shims/ShimTestUtils.scala
@@ -17,13 +17,12 @@
  * under the License.
  */
 
-package org.apache.comet
+package org.apache.spark.comet.shims
 
 import java.io.File
-
 import scala.collection.mutable.ArrayBuffer
 
-object TestUtils {
+object ShimTestUtils {
 
   /**
    * Spark 3.3.0 moved {{{SpecificParquetRecordReaderBase.listDirectory}}} to
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
similarity index 67%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
index 0c45a9c2..f8d621c7 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/test/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
@@ -16,18 +16,10 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.sql.comet.shims
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimCometTPCDSQuerySuite {
+  // This is private in `TPCDSBase`.
+  val excludedTpcdsQueries: Set[String] = Set()
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
similarity index 69%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
index 0c45a9c2..ec9823e5 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/test/spark-4.0/org/apache/comet/shims/ShimCometTPCHQuerySuite.scala
@@ -16,18 +16,10 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.SQLQueryTestHelper
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+trait ShimCometTPCHQuerySuite extends SQLQueryTestHelper {
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala
similarity index 67%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala
index 0c45a9c2..923ae68f 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/test/spark-4.0/org/apache/spark/comet/shims/ShimTestUtils.scala
@@ -16,18 +16,12 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.comet.shims
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+import java.io.File
+
+object ShimTestUtils {
+  def listDirectory(path: File): Array[String] =
+    org.apache.spark.TestUtils.listDirectory(path)
 }
diff --git 
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala 
b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
similarity index 67%
copy from spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
copy to 
spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
index 0c45a9c2..43917df6 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ 
b/spark/src/test/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometTPCDSQuerySuite.scala
@@ -16,18 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.spark.sql.comet.shims
+
+trait ShimCometTPCDSQuerySuite {
 
-/**
- * `CometExprShim` acts as a shim for for parsing expressions from different 
Spark versions.
- */
-trait CometExprShim {
-    /**
-      * Returns a tuple of expressions for the `unhex` function.
-      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to