This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 14494d3a feat: Enable columnar shuffle by default (#250)
14494d3a is described below
commit 14494d3a06338b28ce8ad31d032ac60b75f4c227
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed May 8 17:38:40 2024 -0700
feat: Enable columnar shuffle by default (#250)
* feat: Enable columnar shuffle by default
* Update plan stability
* Fix
* Update diff
* Add Comet memoryOverhead for Spark SQL tests
* Update plan stability
* Update diff
* Update more diff
* Update DataFusion commit
* Update diff
* Update diff
* Update diff
* Update diff
* Update diff
* Fix more tests
* Fix more
* Fix
* Fix more
* Fix more
* Fix more
* Fix more
* Fix more
* Update diff
* Fix memory leak
* Update plan stability
* Restore diff
* Update core/src/execution/datafusion/planner.rs
Co-authored-by: Andy Grove <[email protected]>
* Update core/src/execution/datafusion/planner.rs
Co-authored-by: Andy Grove <[email protected]>
* Fix style
* Use ShuffleExchangeLike instead
---------
Co-authored-by: Andy Grove <[email protected]>
---
.../main/scala/org/apache/comet/CometConf.scala | 16 +-
.../scala/org/apache/comet/vector/NativeUtil.scala | 2 +-
.../org/apache/spark/sql/comet/util/Utils.scala | 7 +-
core/Cargo.lock | 196 ++--
core/src/execution/datafusion/planner.rs | 25 +
dev/diffs/3.4.2.diff | 1128 +++++++++++++++++++-
docs/source/user-guide/configs.md | 2 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 2 +-
.../comet/plans/AliasAwareOutputExpression.scala | 3 +-
.../apache/spark/sql/CometTPCDSQuerySuite.scala | 3 +-
10 files changed, 1207 insertions(+), 177 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index f9caee9d..1858f765 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -140,14 +140,14 @@ object CometConf {
.booleanConf
.createWithDefault(false)
- val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = conf(
- "spark.comet.columnar.shuffle.enabled")
- .doc(
- "Force Comet to only use columnar shuffle for CometScan and Spark
regular operators. " +
- "If this is enabled, Comet native shuffle will not be enabled but only
Arrow shuffle. " +
- "By default, this config is false.")
- .booleanConf
- .createWithDefault(false)
+ val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
+ conf("spark.comet.columnar.shuffle.enabled")
+ .doc(
+ "Whether to enable Arrow-based columnar shuffle for Comet and Spark
regular operators. " +
+ "If this is enabled, Comet prefers columnar shuffle than native
shuffle. " +
+ "By default, this config is true.")
+ .booleanConf
+ .createWithDefault(true)
val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.shuffle.enforceMode.enabled")
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index eb731f9d..595c0a42 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -66,7 +66,7 @@ class NativeUtil {
val arrowArray = ArrowArray.allocateNew(allocator)
Data.exportVector(
allocator,
- getFieldVector(valueVector),
+ getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 7d920e1b..2300e109 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -242,7 +242,7 @@ object Utils {
}
}
- getFieldVector(valueVector)
+ getFieldVector(valueVector, "serialize")
case c =>
throw new SparkException(
@@ -253,14 +253,15 @@ object Utils {
(fieldVectors, provider)
}
- def getFieldVector(valueVector: ValueVector): FieldVector = {
+ def getFieldVector(valueVector: ValueVector, reason: String): FieldVector = {
valueVector match {
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
_: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
v.asInstanceOf[FieldVector]
- case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}")
+ case _ =>
+ throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
}
}
}
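
The getFieldVector change above is a small diagnostics improvement: callers now pass a short label ("export" in NativeUtil, "serialize" in Utils) so the unsupported-vector error says which code path failed. A standalone sketch of the same pattern with made-up types (not the Comet API), just to show the shape:

    object ContextualErrors {
      sealed trait Vec
      case object IntVec extends Vec
      case object MapVec extends Vec // pretend this one is unsupported

      // The extra `reason` argument only feeds the error message, mirroring
      // getFieldVector(valueVector, reason) in the diff above.
      def asSupported(v: Vec, reason: String): Vec = v match {
        case IntVec => IntVec
        case other =>
          throw new IllegalArgumentException(s"Unsupported vector for $reason: $other")
      }

      def main(args: Array[String]): Unit = {
        println(asSupported(IntVec, "export"))
        try asSupported(MapVec, "serialize")
        catch { case e: IllegalArgumentException => println(e.getMessage) } // shows the call site
      }
    }
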
diff --git a/core/Cargo.lock b/core/Cargo.lock
index 3fb7b5f6..52f10559 100644
--- a/core/Cargo.lock
+++ b/core/Cargo.lock
@@ -57,9 +57,9 @@ dependencies = [
[[package]]
name = "allocator-api2"
-version = "0.2.16"
+version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
+checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
[[package]]
name = "android-tzdata"
@@ -90,9 +90,9 @@ checksum =
"8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc"
[[package]]
name = "anyhow"
-version = "1.0.81"
+version = "1.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
+checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
[[package]]
name = "arc-swap"
@@ -327,13 +327,13 @@ checksum =
"0c24e9d990669fbd16806bff449e4ac644fd9b1fca014760087732fe4102f131"
[[package]]
name = "async-trait"
-version = "0.1.79"
+version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681"
+checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -438,9 +438,9 @@ dependencies = [
[[package]]
name = "bumpalo"
-version = "3.15.4"
+version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa"
+checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]]
name = "bytemuck"
@@ -468,9 +468,9 @@ checksum =
"37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
-version = "1.0.90"
+version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5"
+checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7"
dependencies = [
"jobserver",
"libc",
@@ -490,14 +490,14 @@ checksum =
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
-version = "0.4.37"
+version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e"
+checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"num-traits",
- "windows-targets 0.52.4",
+ "windows-targets 0.52.5",
]
[[package]]
@@ -576,9 +576,9 @@ checksum =
"98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
[[package]]
name = "combine"
-version = "4.6.6"
+version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
+checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
@@ -625,7 +625,7 @@ dependencies = [
"parquet-format",
"paste",
"pprof",
- "prost 0.12.3",
+ "prost 0.12.4",
"prost-build",
"rand",
"regex",
@@ -643,12 +643,12 @@ dependencies = [
[[package]]
name = "comfy-table"
-version = "7.1.0"
+version = "7.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686"
+checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7"
dependencies = [
- "strum 0.25.0",
- "strum_macros 0.25.3",
+ "strum",
+ "strum_macros",
"unicode-width",
]
@@ -923,8 +923,8 @@ dependencies = [
"datafusion-common",
"paste",
"sqlparser",
- "strum 0.26.2",
- "strum_macros 0.26.2",
+ "strum",
+ "strum_macros",
]
[[package]]
@@ -1044,7 +1044,7 @@ dependencies = [
"datafusion-expr",
"log",
"sqlparser",
- "strum 0.26.2",
+ "strum",
]
[[package]]
@@ -1092,9 +1092,9 @@ checksum =
"fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]]
name = "either"
-version = "1.10.0"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
+checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2"
[[package]]
name = "equivalent"
@@ -1227,7 +1227,7 @@ checksum =
"87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -1272,9 +1272,9 @@ dependencies = [
[[package]]
name = "getrandom"
-version = "0.2.12"
+version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
+checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
dependencies = [
"cfg-if",
"libc",
@@ -1526,9 +1526,9 @@ checksum =
"8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
-version = "0.1.28"
+version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6"
+checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2"
dependencies = [
"libc",
]
@@ -1794,9 +1794,9 @@ dependencies = [
[[package]]
name = "num"
-version = "0.4.1"
+version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
+checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41"
dependencies = [
"num-bigint",
"num-complex",
@@ -2142,9 +2142,9 @@ checksum =
"5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
-version = "1.0.79"
+version = "1.0.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e"
+checksum = "a56dea16b0a29e94408b9aa5e2940a4eedbd128a1ba20e8f7ae60fd3d465af0e"
dependencies = [
"unicode-ident",
]
@@ -2161,12 +2161,12 @@ dependencies = [
[[package]]
name = "prost"
-version = "0.12.3"
+version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a"
+checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922"
dependencies = [
"bytes",
- "prost-derive 0.12.3",
+ "prost-derive 0.12.4",
]
[[package]]
@@ -2204,15 +2204,15 @@ dependencies = [
[[package]]
name = "prost-derive"
-version = "0.12.3"
+version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e"
+checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48"
dependencies = [
"anyhow",
- "itertools 0.11.0",
+ "itertools 0.12.1",
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2236,9 +2236,9 @@ dependencies = [
[[package]]
name = "quote"
-version = "1.0.35"
+version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
+checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [
"proc-macro2",
]
@@ -2370,9 +2370,9 @@ dependencies = [
[[package]]
name = "rustversion"
-version = "1.0.14"
+version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
+checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47"
[[package]]
name = "ryu"
@@ -2434,14 +2434,14 @@ checksum =
"7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
name = "serde_json"
-version = "1.0.115"
+version = "1.0.116"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd"
+checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813"
dependencies = [
"itoa",
"ryu",
@@ -2545,7 +2545,7 @@ checksum =
"01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2566,32 +2566,13 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb"
-[[package]]
-name = "strum"
-version = "0.25.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
-
[[package]]
name = "strum"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
dependencies = [
- "strum_macros 0.26.2",
-]
-
-[[package]]
-name = "strum_macros"
-version = "0.25.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0"
-dependencies = [
- "heck 0.4.1",
- "proc-macro2",
- "quote",
- "rustversion",
- "syn 2.0.57",
+ "strum_macros",
]
[[package]]
@@ -2604,7 +2585,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2649,9 +2630,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.57"
+version = "2.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11a6ae1e52eb25aab8f3fb9fca13be982a373b8f1157ca14b897a825ba4a2d35"
+checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a"
dependencies = [
"proc-macro2",
"quote",
@@ -2687,7 +2668,7 @@ checksum =
"c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2790,7 +2771,7 @@ checksum =
"5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2823,7 +2804,7 @@ checksum =
"34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
@@ -2971,7 +2952,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
"wasm-bindgen-shared",
]
@@ -2993,7 +2974,7 @@ checksum =
"e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@@ -3063,7 +3044,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
- "windows-targets 0.52.4",
+ "windows-targets 0.52.5",
]
[[package]]
@@ -3081,7 +3062,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
- "windows-targets 0.52.4",
+ "windows-targets 0.52.5",
]
[[package]]
@@ -3116,17 +3097,18 @@ dependencies = [
[[package]]
name = "windows-targets"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
+checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
dependencies = [
- "windows_aarch64_gnullvm 0.52.4",
- "windows_aarch64_msvc 0.52.4",
- "windows_i686_gnu 0.52.4",
- "windows_i686_msvc 0.52.4",
- "windows_x86_64_gnu 0.52.4",
- "windows_x86_64_gnullvm 0.52.4",
- "windows_x86_64_msvc 0.52.4",
+ "windows_aarch64_gnullvm 0.52.5",
+ "windows_aarch64_msvc 0.52.5",
+ "windows_i686_gnu 0.52.5",
+ "windows_i686_gnullvm",
+ "windows_i686_msvc 0.52.5",
+ "windows_x86_64_gnu 0.52.5",
+ "windows_x86_64_gnullvm 0.52.5",
+ "windows_x86_64_msvc 0.52.5",
]
[[package]]
@@ -3143,9 +3125,9 @@ checksum =
"2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
+checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
[[package]]
name = "windows_aarch64_msvc"
@@ -3161,9 +3143,9 @@ checksum =
"dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
+checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
[[package]]
name = "windows_i686_gnu"
@@ -3179,9 +3161,15 @@ checksum =
"a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
-version = "0.52.4"
+version = "0.52.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
+
+[[package]]
+name = "windows_i686_gnullvm"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
+checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
[[package]]
name = "windows_i686_msvc"
@@ -3197,9 +3185,9 @@ checksum =
"8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
+checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
[[package]]
name = "windows_x86_64_gnu"
@@ -3215,9 +3203,9 @@ checksum =
"53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
+checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
[[package]]
name = "windows_x86_64_gnullvm"
@@ -3233,9 +3221,9 @@ checksum =
"0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
+checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
[[package]]
name = "windows_x86_64_msvc"
@@ -3251,9 +3239,9 @@ checksum =
"ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
-version = "0.52.4"
+version = "0.52.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
+checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "zerocopy"
@@ -3272,7 +3260,7 @@ checksum =
"9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.57",
+ "syn 2.0.59",
]
[[package]]
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index b5f8201a..59818857 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1040,6 +1040,23 @@ impl PhysicalPlanner {
.collect();
let full_schema = Arc::new(Schema::new(all_fields));
+ // Because we cast dictionary array to array in scan operator,
+ // we need to change dictionary type to data type for join filter expression.
+ let fields: Vec<_> = full_schema
+ .fields()
+ .iter()
+ .map(|f| match f.data_type() {
+ DataType::Dictionary(_, val_type) => Arc::new(Field::new(
+ f.name(),
+ val_type.as_ref().clone(),
+ f.is_nullable(),
+ )),
+ _ => f.clone(),
+ })
+ .collect();
+
+ let full_schema = Arc::new(Schema::new(fields));
+
let physical_expr = self.create_expr(expr, full_schema)?;
let (left_field_indices, right_field_indices) =
expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?;
@@ -1058,6 +1075,14 @@ impl PhysicalPlanner {
.into_iter()
.map(|i| right.schema().field(i).clone()),
)
+ // Because we cast dictionary array to array in scan operator,
+ // we need to change dictionary type to data type for join filter expression.
+ .map(|f| match f.data_type() {
+ DataType::Dictionary(_, val_type) => {
+ Field::new(f.name(), val_type.as_ref().clone(), f.is_nullable())
+ }
+ _ => f.clone(),
+ })
.collect_vec();
let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());
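
The planner change above rewrites the schema used for join filter expressions: because the scan operator already unpacks dictionary-encoded arrays, any Dictionary field is replaced by a field of its value type before the filter expression is planned. A simplified Scala model of that rewrite, using toy types rather than the real Arrow/DataFusion ones:

    object DictionaryStripping {
      sealed trait DType
      case object Int32 extends DType
      case object Utf8 extends DType
      case class Dictionary(indexType: DType, valueType: DType) extends DType

      case class Field(name: String, dataType: DType, nullable: Boolean)

      // Replace each dictionary-encoded field with its value type, keeping the
      // field name and nullability, as the join-filter schema rewrite does.
      def stripDictionaries(schema: Seq[Field]): Seq[Field] =
        schema.map {
          case Field(name, Dictionary(_, valueType), nullable) => Field(name, valueType, nullable)
          case other => other
        }

      def main(args: Array[String]): Unit = {
        val schema = Seq(
          Field("id", Int32, nullable = false),
          Field("name", Dictionary(Int32, Utf8), nullable = true))
        println(stripDictionaries(schema))
        // List(Field(id,Int32,false), Field(name,Utf8,true))
      }
    }
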
diff --git a/dev/diffs/3.4.2.diff b/dev/diffs/3.4.2.diff
index 4154a705..19bf6dd4 100644
--- a/dev/diffs/3.4.2.diff
+++ b/dev/diffs/3.4.2.diff
@@ -210,6 +210,51 @@ index 0efe0877e9b..423d3b3d76d 100644
--
-- SELECT_HAVING
--
https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+index cf40e944c09..bdd5be4f462 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
+ import org.apache.spark.sql.execution.{ColumnarToRowExec,
ExecSubqueryExpression, RDDScanExec, SparkPlan}
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+ import org.apache.spark.sql.execution.columnar._
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+ import
org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+ import org.apache.spark.sql.functions._
+ import org.apache.spark.sql.internal.SQLConf
+@@ -516,7 +516,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
+ */
+ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
+ assert(
+- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec
=> e }.size == expected)
++ collect(df.queryExecution.executedPlan) {
++ case _: ShuffleExchangeLike => 1 }.size == expected)
+ }
+
+ test("A cached table preserves the partitioning and ordering of its cached
SparkPlan") {
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+index ea5e47ede55..814b92d090f 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+@@ -27,7 +27,7 @@ import org.apache.spark.SparkException
+ import org.apache.spark.sql.execution.WholeStageCodegenExec
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+ import org.apache.spark.sql.expressions.Window
+ import org.apache.spark.sql.functions._
+ import org.apache.spark.sql.internal.SQLConf
+@@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest
+ assert(objHashAggPlans.nonEmpty)
+
+ val exchangePlans = collect(aggPlan) {
+- case shuffle: ShuffleExchangeExec => shuffle
++ case shuffle: ShuffleExchangeLike => shuffle
+ }
+ assert(exchangePlans.length == 1)
+ }
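
Most of the Spark test changes in this diff follow the pattern above: shuffles are counted by matching on the ShuffleExchangeLike trait instead of the concrete ShuffleExchangeExec, so Comet's shuffle exchange is counted as well. A minimal helper in that style, assuming only Spark's public classes (for adaptive plans the suites use AdaptiveSparkPlanHelper.collect rather than the plain collect shown here):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

    object ShuffleCounting {
      // Counts every shuffle exchange in the executed plan, regardless of
      // whether Spark's or Comet's implementation was planned.
      def numShuffles(df: DataFrame): Int =
        df.queryExecution.executedPlan.collect {
          case _: ShuffleExchangeLike => 1
        }.size
    }

    // Usage in a test body (illustrative): assert(ShuffleCounting.numShuffles(df) == 2)
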
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 56e9520fdab..917932336df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -226,9 +271,54 @@ index 56e9520fdab..917932336df 100644
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
-index 9ddb4abe98b..1bebe99f1cc 100644
+index 9ddb4abe98b..1b9269acef1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+@@ -43,7 +43,7 @@ import org.apache.spark.sql.connector.FakeV2Provider
+ import org.apache.spark.sql.execution.{FilterExec, LogicalRDD,
QueryExecution, SortExec, WholeStageCodegenExec}
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeLike}
+ import org.apache.spark.sql.expressions.{Aggregator, Window}
+ import org.apache.spark.sql.functions._
+ import org.apache.spark.sql.internal.SQLConf
+@@ -1981,7 +1981,7 @@ class DataFrameSuite extends QueryTest
+ fail("Should not have back to back Aggregates")
+ }
+ atFirstAgg = true
+- case e: ShuffleExchangeExec => atFirstAgg = false
++ case e: ShuffleExchangeLike => atFirstAgg = false
+ case _ =>
+ }
+ }
+@@ -2291,7 +2291,7 @@ class DataFrameSuite extends QueryTest
+ checkAnswer(join, df)
+ assert(
+ collect(join.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => true }.size === 1)
++ case _: ShuffleExchangeLike => true }.size === 1)
+ assert(
+ collect(join.queryExecution.executedPlan) { case e:
ReusedExchangeExec => true }.size === 1)
+ val broadcasted = broadcast(join)
+@@ -2299,7 +2299,7 @@ class DataFrameSuite extends QueryTest
+ checkAnswer(join2, df)
+ assert(
+ collect(join2.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => true }.size == 1)
++ case _: ShuffleExchangeLike => true }.size == 1)
+ assert(
+ collect(join2.queryExecution.executedPlan) {
+ case e: BroadcastExchangeExec => true }.size === 1)
+@@ -2862,7 +2862,7 @@ class DataFrameSuite extends QueryTest
+
+ // Assert that no extra shuffle introduced by cogroup.
+ val exchanges = collect(df3.queryExecution.executedPlan) {
+- case h: ShuffleExchangeExec => h
++ case h: ShuffleExchangeLike => h
+ }
+ assert(exchanges.size == 2)
+ }
@@ -3311,7 +3311,8 @@ class DataFrameSuite extends QueryTest
assert(df2.isLocal)
}
@@ -239,8 +329,30 @@ index 9ddb4abe98b..1bebe99f1cc 100644
withTable("tbl") {
sql(
"""
+diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+index 7dec558f8df..840dda15033 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti,
LeftSemi}
+ import org.apache.spark.sql.catalyst.util.sideBySide
+ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec, ShuffleExchangeLike}
+ import org.apache.spark.sql.execution.streaming.MemoryStream
+ import org.apache.spark.sql.expressions.UserDefinedFunction
+ import org.apache.spark.sql.functions._
+@@ -2254,7 +2254,7 @@ class DatasetSuite extends QueryTest
+
+ // Assert that no extra shuffle introduced by cogroup.
+ val exchanges = collect(df3.queryExecution.executedPlan) {
+- case h: ShuffleExchangeExec => h
++ case h: ShuffleExchangeLike => h
+ }
+ assert(exchanges.size == 2)
+ }
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
-index f33432ddb6f..6160c8d241a 100644
+index f33432ddb6f..060f874ea72 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
@@ -261,7 +373,17 @@ index f33432ddb6f..6160c8d241a 100644
case _ => Nil
}
}
-@@ -1238,7 +1242,8 @@ abstract class DynamicPartitionPruningSuiteBase
+@@ -1187,7 +1191,8 @@ abstract class DynamicPartitionPruningSuiteBase
+ }
+ }
+
+- test("Make sure dynamic pruning works on uncorrelated queries") {
++ test("Make sure dynamic pruning works on uncorrelated queries",
++ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
+ withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key ->
"true") {
+ val df = sql(
+ """
+@@ -1238,7 +1243,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}
@@ -271,7 +393,7 @@ index f33432ddb6f..6160c8d241a 100644
Given("dynamic pruning filter on the build side")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key ->
"true") {
val df = sql(
-@@ -1485,7 +1490,7 @@ abstract class DynamicPartitionPruningSuiteBase
+@@ -1485,7 +1491,7 @@ abstract class DynamicPartitionPruningSuiteBase
}
test("SPARK-38148: Do not add dynamic partition pruning if there exists
static partition " +
@@ -280,7 +402,7 @@ index f33432ddb6f..6160c8d241a 100644
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq(
"f.store_id = 1" -> false,
-@@ -1729,6 +1734,8 @@ abstract class DynamicPartitionPruningV1Suite extends
DynamicPartitionPruningDat
+@@ -1729,6 +1735,8 @@ abstract class DynamicPartitionPruningV1Suite extends
DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields =
100).contains("f1")))
@@ -290,7 +412,7 @@ index f33432ddb6f..6160c8d241a 100644
}
assert(scanOption.isDefined)
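
Several tests in this suite (and in others below) are now tagged with IgnoreComet("...") so they are skipped when Comet is enabled. The tag itself is defined elsewhere in the Comet patch; a plausible shape for such a tag, assuming ScalaTest's Tag mechanism and an illustrative tag name, not necessarily Comet's exact definition:

    import org.scalatest.Tag

    // Hypothetical: carries the reason a test is skipped under Comet; runners
    // can then exclude tests by this tag name (e.g. ScalaTest's -l option).
    case class IgnoreComet(reason: String) extends Tag("org.apache.spark.sql.IgnoreComet")

    // Usage inside a suite (illustrative):
    //   test("some behavior", IgnoreComet("TODO: not supported by Comet yet")) { ... }
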
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
-index a6b295578d6..a5cb616945a 100644
+index a6b295578d6..91acca4306f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with
DisableAdaptiveExecutionSuite
@@ -303,19 +425,38 @@ index a6b295578d6..a5cb616945a 100644
withTempDir { dir =>
Seq("parquet", "orc", "csv", "json").foreach { fmt =>
val basePath = dir.getCanonicalPath + "/" + fmt
+@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with
DisableAdaptiveExecutionSuite
+ }
+ }
+
+-class ExplainSuiteAE extends ExplainSuiteHelper with
EnableAdaptiveExecutionSuite {
++// Ignored when Comet is enabled. Comet changes expected query plans.
++class ExplainSuiteAE extends ExplainSuiteHelper with
EnableAdaptiveExecutionSuite
++ with IgnoreCometSuite {
+ import testImplicits._
+
+ test("SPARK-35884: Explain Formatted") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
-index 2796b1cf154..94591f83c84 100644
+index 2796b1cf154..be7078b38f4 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT,
NullData, NullUDT}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
GreaterThan, Literal}
import
org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt,
positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
-+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
++import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec,
CometSortMergeJoinExec}
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
-@@ -875,6 +876,7 @@ class FileBasedDataSourceSuite extends QueryTest
+@@ -815,6 +816,7 @@ class FileBasedDataSourceSuite extends QueryTest
+ assert(bJoinExec.isEmpty)
+ val smJoinExec = collect(joinedDF.queryExecution.executedPlan) {
+ case smJoin: SortMergeJoinExec => smJoin
++ case smJoin: CometSortMergeJoinExec => smJoin
+ }
+ assert(smJoinExec.nonEmpty)
+ }
+@@ -875,6 +877,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
@@ -323,7 +464,7 @@ index 2796b1cf154..94591f83c84 100644
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
-@@ -916,6 +918,7 @@ class FileBasedDataSourceSuite extends QueryTest
+@@ -916,6 +919,7 @@ class FileBasedDataSourceSuite extends QueryTest
val fileScan = df.queryExecution.executedPlan collectFirst {
case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
@@ -331,7 +472,7 @@ index 2796b1cf154..94591f83c84 100644
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
-@@ -1100,6 +1103,8 @@ class FileBasedDataSourceSuite extends QueryTest
+@@ -1100,6 +1104,8 @@ class FileBasedDataSourceSuite extends QueryTest
val filters = df.queryExecution.executedPlan.collect {
case f: FileSourceScanLike => f.dataFilters
case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters
@@ -388,8 +529,36 @@ index 00000000000..4b31bea33de
+ }
+ }
+}
+diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+index 1792b4c32eb..1616e6f39bd 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala
+@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft,
BuildRight, BuildSide
+ import org.apache.spark.sql.catalyst.plans.PlanTest
+ import org.apache.spark.sql.catalyst.plans.logical._
+ import org.apache.spark.sql.catalyst.rules.RuleExecutor
++import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+ import org.apache.spark.sql.execution.joins._
+ import org.apache.spark.sql.internal.SQLConf
+@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with
SharedSparkSession with AdaptiveSparkP
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleHashJoins = collect(executedPlan) {
+ case s: ShuffledHashJoinExec => s
++ case c: CometHashJoinExec =>
c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
+ }
+ assert(shuffleHashJoins.size == 1)
+ assert(shuffleHashJoins.head.buildSide == buildSide)
+@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with
SharedSparkSession with AdaptiveSparkP
+ val executedPlan = df.queryExecution.executedPlan
+ val shuffleMergeJoins = collect(executedPlan) {
+ case s: SortMergeJoinExec => s
++ case c: CometSortMergeJoinExec => c
+ }
+ assert(shuffleMergeJoins.size == 1)
+ }
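
The join suites extend their pattern matches so that Comet's physical operators count alongside Spark's; where a test needs details of the wrapped Spark node, it goes through originalPlan as in the JoinHintSuite hunk above. A small helper in that spirit, assuming the Comet operator classes imported by this diff (org.apache.spark.sql.comet.CometHashJoinExec and CometSortMergeJoinExec):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
    import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}

    object JoinCounting {
      // Counts hash joins and sort-merge joins whether they were planned by
      // Spark or replaced by their Comet equivalents.
      def countJoins(plan: SparkPlan): (Int, Int) = {
        val hashJoins = plan.collect { case _: ShuffledHashJoinExec | _: CometHashJoinExec => 1 }.size
        val mergeJoins = plan.collect { case _: SortMergeJoinExec | _: CometSortMergeJoinExec => 1 }.size
        (hashJoins, mergeJoins)
      }
    }
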
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-index 5125708be32..a1f1ae90796 100644
+index 5125708be32..210ab4f3ce1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
@@ -400,7 +569,87 @@ index 5125708be32..a1f1ae90796 100644
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec,
ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec,
ShuffleExchangeLike}
-@@ -1369,9 +1370,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -739,7 +740,8 @@ class JoinSuite extends QueryTest with SharedSparkSession
with AdaptiveSparkPlan
+ }
+ }
+
+- test("test SortMergeJoin (with spill)") {
++ test("test SortMergeJoin (with spill)",
++ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+ SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
+ SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
+@@ -1114,9 +1116,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
+ .groupBy($"k1").count()
+ .queryExecution.executedPlan
+- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size ===
1)
++ assert(collect(plan) {
++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size
=== 1)
+ // No extra shuffle before aggregate
+- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2)
++ assert(collect(plan) {
++ case _: ShuffleExchangeLike => true }.size === 2)
+ })
+ }
+
+@@ -1133,10 +1137,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
+ .queryExecution
+ .executedPlan
+- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
++ assert(collect(plan) {
++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size
=== 2)
+ assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size ===
1)
+ // No extra sort before last sort merge join
+- assert(collect(plan) { case _: SortExec => true }.size === 3)
++ assert(collect(plan) { case _: SortExec | _: CometSortExec => true
}.size === 3)
+ })
+
+ // Test shuffled hash join
+@@ -1146,10 +1151,13 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
+ .queryExecution
+ .executedPlan
+- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2)
+- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size ===
1)
++ assert(collect(plan) {
++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size
=== 2)
++ assert(collect(plan) {
++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size
=== 1)
+ // No extra sort before last sort merge join
+- assert(collect(plan) { case _: SortExec => true }.size === 3)
++ assert(collect(plan) {
++ case _: SortExec | _: CometSortExec => true }.size === 3)
+ })
+ }
+
+@@ -1240,12 +1248,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ inputDFs.foreach { case (df1, df2, joinExprs) =>
+ val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
+ assert(collect(smjDF.queryExecution.executedPlan) {
+- case _: SortMergeJoinExec => true }.size === 1)
++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size
=== 1)
+ val smjResult = smjDF.collect()
+
+ val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full")
+ assert(collect(shjDF.queryExecution.executedPlan) {
+- case _: ShuffledHashJoinExec => true }.size === 1)
++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size
=== 1)
+ // Same result between shuffled hash join and sort merge join
+ checkAnswer(shjDF, smjResult)
+ }
+@@ -1340,7 +1348,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ val plan = sql(getAggQuery(selectExpr,
joinType)).queryExecution.executedPlan
+ assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
+ // Have shuffle before aggregation
+- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size
=== 1)
++ assert(collect(plan) {
++ case _: ShuffleExchangeLike => true }.size === 1)
+ }
+
+ def getJoinQuery(selectExpr: String, joinType: String): String = {
+@@ -1369,9 +1378,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr,
joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
@@ -415,7 +664,7 @@ index 5125708be32..a1f1ae90796 100644
}
// Test output ordering is not preserved
-@@ -1380,9 +1384,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1380,9 +1392,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr,
joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
@@ -430,6 +679,16 @@ index 5125708be32..a1f1ae90796 100644
}
// Test singe partition
+@@ -1392,7 +1407,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+ |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
+ |""".stripMargin)
+ val plan = fullJoinDF.queryExecution.executedPlan
+- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1)
++ assert(collect(plan) {
++ case _: ShuffleExchangeLike => true}.size == 1)
+ checkAnswer(fullJoinDF, Row(100))
+ }
+ }
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
index b5b34922694..a72403780c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -443,20 +702,38 @@ index b5b34922694..a72403780c4 100644
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
+diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+index 525d97e4998..8a3e7457618 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+@@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+ checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
+ }
+
+- test("external sorting updates peak execution memory") {
++ test("external sorting updates peak execution memory",
++ IgnoreComet("TODO: native CometSort does not update peak execution
memory")) {
+ AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external
sort") {
+ sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
+ }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-index 3cfda19134a..278bb1060c4 100644
+index 3cfda19134a..7590b808def 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer
+@@ -21,10 +21,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join,
LogicalPlan, Project, Sort, Union}
+import org.apache.spark.sql.comet.CometScanExec
-+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper,
DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
-@@ -1543,6 +1545,12 @@ class SubquerySuite extends QueryTest
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+ import org.apache.spark.sql.execution.joins.{BaseJoinExec,
BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
+ import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.sql.test.SharedSparkSession
+@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
@@ -469,14 +746,78 @@ index 3cfda19134a..278bb1060c4 100644
case _ => false
})
}
-@@ -2109,6 +2117,7 @@ class SubquerySuite extends QueryTest
+@@ -2108,7 +2115,7 @@ class SubquerySuite extends QueryTest
+
df.collect()
val exchanges = collect(df.queryExecution.executedPlan) {
- case s: ShuffleExchangeExec => s
-+ case s: CometShuffleExchangeExec => s
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
}
assert(exchanges.size === 1)
}
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+index 02990a7a40d..bddf5e1ccc2 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._
+
+ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+ import org.apache.spark.sql.catalyst.InternalRow
++import org.apache.spark.sql.comet.CometSortExec
+ import org.apache.spark.sql.connector.catalog.{PartitionInternalRow,
SupportsRead, Table, TableCapability, TableProvider}
+ import org.apache.spark.sql.connector.catalog.TableCapability._
+ import org.apache.spark.sql.connector.expressions.{Expression,
FieldReference, Literal, NamedReference, NullOrdering, SortDirection,
SortOrder, Transform}
+@@ -33,7 +34,7 @@ import
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning,
+ import org.apache.spark.sql.execution.SortExec
+ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2Relation, DataSourceV2ScanRelation}
+-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
++import org.apache.spark.sql.execution.exchange.{Exchange,
ShuffleExchangeExec, ShuffleExchangeLike}
+ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+ import org.apache.spark.sql.expressions.Window
+ import org.apache.spark.sql.functions._
+@@ -268,13 +269,13 @@ class DataSourceV2Suite extends QueryTest with
SharedSparkSession with AdaptiveS
+ val groupByColJ = df.groupBy($"j").agg(sum($"i"))
+ checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
+ assert(collectFirst(groupByColJ.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => e
++ case e: ShuffleExchangeLike => e
+ }.isDefined)
+
+ val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*"))
+ checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1),
Row(9, 1)))
+ assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => e
++ case e: ShuffleExchangeLike => e
+ }.isDefined)
+ }
+ }
+@@ -334,10 +335,11 @@ class DataSourceV2Suite extends QueryTest with
SharedSparkSession with AdaptiveS
+
+ val (shuffleExpected, sortExpected) = groupByExpects
+ assert(collectFirst(groupBy.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => e
++ case e: ShuffleExchangeLike => e
+ }.isDefined === shuffleExpected)
+ assert(collectFirst(groupBy.queryExecution.executedPlan) {
+ case e: SortExec => e
++ case c: CometSortExec => c
+ }.isDefined === sortExpected)
+ }
+
+@@ -352,10 +354,11 @@ class DataSourceV2Suite extends QueryTest with
SharedSparkSession with AdaptiveS
+
+ val (shuffleExpected, sortExpected) = windowFuncExpects
+
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
+- case e: ShuffleExchangeExec => e
++ case e: ShuffleExchangeLike => e
+ }.isDefined === shuffleExpected)
+
assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) {
+ case e: SortExec => e
++ case c: CometSortExec => c
+ }.isDefined === sortExpected)
+ }
+ }
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index cfc8b2cc845..c6fcfd7bd08 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -502,6 +843,44 @@ index cfc8b2cc845..c6fcfd7bd08 100644
}
} finally {
spark.listenerManager.unregister(listener)
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+index cf76f6ca32c..f454128af06 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row}
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.{Literal,
TransformExpression}
+ import org.apache.spark.sql.catalyst.plans.physical
++import org.apache.spark.sql.comet.CometSortMergeJoinExec
+ import org.apache.spark.sql.connector.catalog.Identifier
+ import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
+ import org.apache.spark.sql.connector.catalog.functions._
+@@ -31,7 +32,7 @@ import
org.apache.spark.sql.connector.expressions.Expressions._
+ import org.apache.spark.sql.execution.SparkPlan
+ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+ import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+ import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.sql.internal.SQLConf._
+@@ -279,13 +280,14 @@ class KeyGroupedPartitioningSuite extends
DistributionAndOrderingSuiteBase {
+ Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
+ }
+
+- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
++ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
+ // here we skip collecting shuffle operators that are not associated with
SMJ
+ collect(plan) {
+ case s: SortMergeJoinExec => s
++ case c: CometSortMergeJoinExec => c.originalPlan
+ }.flatMap(smj =>
+ collect(smj) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ })
+ }
+
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index c0ec8a58bd5..4e8bc6ed3c5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -547,6 +926,45 @@ index 418ca3430bb..eb8267192f8 100644
Seq("json", "orc", "parquet").foreach { format =>
withTempPath { path =>
val dir = path.getCanonicalPath
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+index 743ec41dbe7..9f30d6c8e04 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends
TPCDSQuerySuite with DisableAdaptiv
+ case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child)
+ case p: ProjectExec => isScanPlanTree(p.child)
+ case f: FilterExec => isScanPlanTree(f.child)
++ // Comet produces scan plan tree like:
++ // ColumnarToRow
++ // +- ReusedExchange
++ case _: ReusedExchangeExec => false
+ case _: LeafExecNode => true
+ case _ => false
+ }
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+index 4b3d3a4b805..56e1e0e6f16 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+@@ -18,7 +18,7 @@
+ package org.apache.spark.sql.execution
+
+ import org.apache.spark.rdd.RDD
+-import org.apache.spark.sql.{execution, DataFrame, Row}
++import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row}
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions._
+ import org.apache.spark.sql.catalyst.plans._
+@@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf
+ import org.apache.spark.sql.test.SharedSparkSession
+ import org.apache.spark.sql.types._
+
+-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
++// Ignore this suite when Comet is enabled. This suite tests the Spark
planner and Comet planner
++// comes out with too many difference. Simply ignoring this suite for now.
++class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper
with IgnoreCometSuite {
+ import testImplicits._
+
+ setupTestData()
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
index 9e9d717db3b..91a4f9a38d5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala
@@ -571,11 +989,108 @@ index 9e9d717db3b..91a4f9a38d5 100644
assert(actual == expected)
}
}
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+index 30ce940b032..0d3f6c6c934 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
+
+ import org.apache.spark.sql.{DataFrame, QueryTest}
+ import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning,
UnknownPartitioning}
++import org.apache.spark.sql.comet.CometSortExec
+ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper,
DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
+ import org.apache.spark.sql.execution.joins.ShuffledJoin
+ import org.apache.spark.sql.internal.SQLConf
+@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase
+
+ private def checkNumSorts(df: DataFrame, count: Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length ==
count)
++ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec
=> 1 }.length == count)
+ }
+
+ private def checkSorts(query: String, enabledCount: Int, disabledCount:
Int): Unit = {
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+index 47679ed7865..9ffbaecb98e 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
+@@ -18,6 +18,7 @@
+ package org.apache.spark.sql.execution
+
+ import org.apache.spark.sql.{DataFrame, QueryTest}
++import org.apache.spark.sql.comet.CometHashAggregateExec
+ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper,
DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
+ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
+ import org.apache.spark.sql.internal.SQLConf
+@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase
+ private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount:
Int): Unit = {
+ val plan = df.queryExecution.executedPlan
+ assert(collectWithSubqueries(plan) {
+- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
++ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _:
CometHashAggregateExec ) => s
+ }.length == hashAggCount)
+ assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s
}.length == sortAggCount)
+ }
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
-index ac710c32296..37746bd470d 100644
+index ac710c32296..e163c1a6a76 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
-@@ -616,7 +616,9 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
+
+ import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
+ import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats,
CodeAndComment, CodeGenerator}
++import org.apache.spark.sql.comet.CometSortMergeJoinExec
+ import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
+ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
SortAggregateExec}
+ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+@@ -224,6 +225,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint ==
"SHUFFLE_HASH" => true
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint ==
"SHUFFLE_MERGE" => true
++ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF,
+ Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null),
Row(4, 4, null),
+@@ -258,6 +260,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+ .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+ case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
++ case _: CometSortMergeJoinExec => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF,
+ Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4,
null, 4), Row(5, null, 5),
+@@ -280,8 +283,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+ val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2",
"left_semi")
+ .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
+- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
++ case _: SortMergeJoinExec => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
+ }
+@@ -302,8 +304,7 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+ val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2",
"left_anti")
+ .join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti")
+ assert(twoJoinsDF.queryExecution.executedPlan.collect {
+- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
+- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
++ case _: SortMergeJoinExec => true
+ }.size === 2)
+ checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9)))
+ }
+@@ -436,7 +437,9 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
+ val plan = df.queryExecution.executedPlan
+ assert(plan.exists(p =>
+ p.isInstanceOf[WholeStageCodegenExec] &&
+- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
++ p.asInstanceOf[WholeStageCodegenExec].collect {
++ case _: SortExec => true
++ }.nonEmpty))
+ assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+ }
+
+@@ -616,7 +619,9 @@ class WholeStageCodegenSuite extends QueryTest with
SharedSparkSession
.write.mode(SaveMode.Overwrite).parquet(path)
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
@@ -587,18 +1102,31 @@ index ac710c32296..37746bd470d 100644
val df = spark.read.parquet(path).selectExpr(projection: _*)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-index 593bd7bb4ba..be1b82d0030 100644
+index 593bd7bb4ba..7ad55e3ab20 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-@@ -29,6 +29,7 @@ import org.apache.spark.scheduler.{SparkListener,
SparkListenerEvent, SparkListe
- import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
+@@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._
+
+ import org.apache.spark.SparkException
+ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent,
SparkListenerJobStart}
+-import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
++import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row,
SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.comet._
++import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec,
PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec,
ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
-@@ -116,6 +117,9 @@ class AdaptiveQueryExecSuite
+@@ -104,6 +106,7 @@ class AdaptiveQueryExecSuite
+ private def findTopLevelBroadcastHashJoin(plan: SparkPlan):
Seq[BroadcastHashJoinExec] = {
+ collect(plan) {
+ case j: BroadcastHashJoinExec => j
++ case j: CometBroadcastHashJoinExec =>
j.originalPlan.asInstanceOf[BroadcastHashJoinExec]
+ }
+ }
+
+@@ -116,30 +119,38 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan):
Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
@@ -608,6 +1136,331 @@ index 593bd7bb4ba..be1b82d0030 100644
}
}
+ private def findTopLevelShuffledHashJoin(plan: SparkPlan):
Seq[ShuffledHashJoinExec] = {
+ collect(plan) {
+ case j: ShuffledHashJoinExec => j
++ case j: CometHashJoinExec =>
j.originalPlan.asInstanceOf[ShuffledHashJoinExec]
+ }
+ }
+
+ private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
+ collect(plan) {
+ case j: BaseJoinExec => j
++ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec]
++ case c: CometSortMergeJoinExec =>
c.originalPlan.asInstanceOf[BaseJoinExec]
+ }
+ }
+
+ private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = {
+ collect(plan) {
+ case s: SortExec => s
++ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec]
+ }
+ }
+
+ private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec]
= {
+ collect(plan) {
+ case agg: BaseAggregateExec => agg
++ case agg: CometHashAggregateExec =>
agg.originalPlan.asInstanceOf[BaseAggregateExec]
+ }
+ }
+
+@@ -176,6 +187,7 @@ class AdaptiveQueryExecSuite
+ val parts = rdd.partitions
+ assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
+ }
++
+ assert(numShuffles === (numLocalReads.length +
numShufflesWithoutLocalRead))
+ }
+
+@@ -184,7 +196,7 @@ class AdaptiveQueryExecSuite
+ val plan = df.queryExecution.executedPlan
+ assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+ val shuffle =
plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
+ assert(shuffle.size == 1)
+ assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
+@@ -200,7 +212,8 @@ class AdaptiveQueryExecSuite
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+- checkNumLocalShuffleReads(adaptivePlan)
++ // Comet shuffle changes shuffle metrics
++ // checkNumLocalShuffleReads(adaptivePlan)
+ }
+ }
+
+@@ -227,7 +240,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Reuse the parallelism of coalesced shuffle in local shuffle read") {
++ test("Reuse the parallelism of coalesced shuffle in local shuffle read",
++ IgnoreComet("Comet shuffle changes shuffle partition size")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+@@ -259,7 +273,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Reuse the default parallelism in local shuffle read") {
++ test("Reuse the default parallelism in local shuffle read",
++ IgnoreComet("Comet shuffle changes shuffle partition size")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+@@ -273,7 +288,8 @@ class AdaptiveQueryExecSuite
+ val localReads = collect(adaptivePlan) {
+ case read: AQEShuffleReadExec if read.isLocalRead => read
+ }
+- assert(localReads.length == 2)
++ // Comet shuffle changes shuffle metrics
++ assert(localReads.length == 1)
+ val localShuffleRDD0 =
localReads(0).execute().asInstanceOf[ShuffledRowRDD]
+ val localShuffleRDD1 =
localReads(1).execute().asInstanceOf[ShuffledRowRDD]
+ // the final parallelism is math.max(1, numReduces / numMappers):
math.max(1, 5/2) = 2
+@@ -322,7 +338,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Scalar subquery") {
++ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle
metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+@@ -337,7 +353,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Scalar subquery in later stages") {
++ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes
shuffle metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+@@ -353,7 +369,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("multiple joins") {
++ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle
metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+@@ -398,7 +414,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("multiple joins with aggregate") {
++ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes
shuffle metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+@@ -443,7 +459,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("multiple joins with aggregate 2") {
++ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes
shuffle metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+@@ -508,7 +524,7 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Exchange reuse with subqueries") {
++ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes
shuffle metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+@@ -539,7 +555,9 @@ class AdaptiveQueryExecSuite
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+- checkNumLocalShuffleReads(adaptivePlan)
++ // Comet shuffle changes shuffle metrics,
++ // so we can't check the number of local shuffle reads.
++ // checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
+ val ex = findReusedExchange(adaptivePlan)
+ assert(ex.nonEmpty)
+@@ -560,7 +578,9 @@ class AdaptiveQueryExecSuite
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+- checkNumLocalShuffleReads(adaptivePlan)
++ // Comet shuffle changes shuffle metrics,
++ // so we can't check the number of local shuffle reads.
++ // checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
+ val ex = findReusedExchange(adaptivePlan)
+ assert(ex.isEmpty)
+@@ -569,7 +589,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("Broadcast exchange reuse across subqueries") {
++ test("Broadcast exchange reuse across subqueries",
++ IgnoreComet("Comet shuffle changes shuffle metrics")) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
+@@ -664,7 +685,8 @@ class AdaptiveQueryExecSuite
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+ // There is still a SMJ, and its two shuffles can't apply local read.
+- checkNumLocalShuffleReads(adaptivePlan, 2)
++ // Comet shuffle changes shuffle metrics
++ // checkNumLocalShuffleReads(adaptivePlan, 2)
+ }
+ }
+
+@@ -786,7 +808,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-29544: adaptive skew join with different join types") {
++ test("SPARK-29544: adaptive skew join with different join types",
++ IgnoreComet("Comet shuffle has different partition metrics")) {
+ Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
+ def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint ==
"SHUFFLE_MERGE") {
+ findTopLevelSortMergeJoin(plan)
+@@ -1004,7 +1027,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("metrics of the shuffle read") {
++ test("metrics of the shuffle read",
++ IgnoreComet("Comet shuffle changes the metrics")) {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT key FROM testData GROUP BY key")
+@@ -1599,7 +1623,7 @@ class AdaptiveQueryExecSuite
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
+ assert(collect(adaptivePlan) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }.length == 1)
+ }
+ }
+@@ -1679,7 +1703,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-33551: Do not use AQE shuffle read for repartition") {
++ test("SPARK-33551: Do not use AQE shuffle read for repartition",
++ IgnoreComet("Comet shuffle changes partition size")) {
+ def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
+ find(plan) {
+ case s: ShuffleExchangeLike =>
+@@ -1864,6 +1889,9 @@ class AdaptiveQueryExecSuite
+ def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin):
Unit = {
+ assert(collect(ds.queryExecution.executedPlan) {
+ case s: ShuffleExchangeExec if s.shuffleOrigin == origin &&
s.numPartitions == 2 => s
++ case c: CometShuffleExchangeExec
++ if c.originalPlan.shuffleOrigin == origin &&
++ c.originalPlan.numPartitions == 2 => c
+ }.size == 1)
+ ds.collect()
+ val plan = ds.queryExecution.executedPlan
+@@ -1872,6 +1900,9 @@ class AdaptiveQueryExecSuite
+ }.isEmpty)
+ assert(collect(plan) {
+ case s: ShuffleExchangeExec if s.shuffleOrigin == origin &&
s.numPartitions == 2 => s
++ case c: CometShuffleExchangeExec
++ if c.originalPlan.shuffleOrigin == origin &&
++ c.originalPlan.numPartitions == 2 => c
+ }.size == 1)
+ checkAnswer(ds, testData)
+ }
+@@ -2028,7 +2059,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-35264: Support AQE side shuffled hash join formula") {
++ test("SPARK-35264: Support AQE side shuffled hash join formula",
++ IgnoreComet("Comet shuffle changes the partition size")) {
+ withTempView("t1", "t2") {
+ def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
+ Seq("100", "100000").foreach { size =>
+@@ -2114,7 +2146,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-35725: Support optimize skewed partitions in
RebalancePartitions") {
++ test("SPARK-35725: Support optimize skewed partitions in
RebalancePartitions",
++ IgnoreComet("Comet shuffle changes shuffle metrics")) {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+@@ -2213,7 +2246,7 @@ class AdaptiveQueryExecSuite
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM
skewData1 " +
+ s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
+ val shuffles1 = collect(adaptive1) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
+ assert(shuffles1.size == 3)
+ // shuffles1.head is the top-level shuffle under the Aggregate
operator
+@@ -2226,7 +2259,7 @@ class AdaptiveQueryExecSuite
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM
skewData1 " +
+ s"JOIN skewData2 ON key1 = key2")
+ val shuffles2 = collect(adaptive2) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
+ if (hasRequiredDistribution) {
+ assert(shuffles2.size == 3)
+@@ -2260,7 +2293,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-35794: Allow custom plugin for cost evaluator") {
++ test("SPARK-35794: Allow custom plugin for cost evaluator",
++ IgnoreComet("Comet shuffle changes shuffle metrics")) {
+ CostEvaluator.instantiate(
+ classOf[SimpleShuffleSortCostEvaluator].getCanonicalName,
spark.sparkContext.getConf)
+ intercept[IllegalArgumentException] {
+@@ -2404,6 +2438,7 @@ class AdaptiveQueryExecSuite
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(adaptive.collect {
+ case sort: SortExec => sort
++ case sort: CometSortExec => sort
+ }.size == 1)
+ val read = collect(adaptive) {
+ case read: AQEShuffleReadExec => read
+@@ -2421,7 +2456,8 @@ class AdaptiveQueryExecSuite
+ }
+ }
+
+- test("SPARK-37357: Add small partition factor for rebalance partitions") {
++ test("SPARK-37357: Add small partition factor for rebalance partitions",
++ IgnoreComet("Comet shuffle changes shuffle metrics")) {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key
-> "true",
+@@ -2533,7 +2569,7 @@ class AdaptiveQueryExecSuite
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN
skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value2 = value3")
+ val shuffles1 = collect(adaptive1) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
+ assert(shuffles1.size == 4)
+ val smj1 = findTopLevelSortMergeJoin(adaptive1)
+@@ -2544,7 +2580,7 @@ class AdaptiveQueryExecSuite
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN
skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value1 = value3")
+ val shuffles2 = collect(adaptive2) {
+- case s: ShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
+ assert(shuffles2.size == 4)
+ val smj2 = findTopLevelSortMergeJoin(adaptive2)
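A recurring change in the AdaptiveQueryExecSuite hunks above is to match on the
ShuffleExchangeLike interface instead of the concrete ShuffleExchangeExec, so
that Comet's CometShuffleExchangeExec is counted as well. A minimal sketch of
that matching pattern, assuming a Spark 3.4 classpath (the helper name is
illustrative only):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

    // Counts every shuffle in a physical plan, whichever engine produced it.
    def countShuffles(plan: SparkPlan): Int =
      plan.collect { case s: ShuffleExchangeLike => s }.size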
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index bd9c79e5b96..ab7584e768e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -628,6 +1481,29 @@ index bd9c79e5b96..ab7584e768e 100644
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+index ce43edb79c1..c414b19eda7 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+@@ -17,7 +17,7 @@
+
+ package org.apache.spark.sql.execution.datasources
+
+-import org.apache.spark.sql.{QueryTest, Row}
++import org.apache.spark.sql.{IgnoreComet, QueryTest, Row}
+ import org.apache.spark.sql.catalyst.expressions.{Ascending,
AttributeReference, NullsFirst, SortOrder}
+ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
+ import org.apache.spark.sql.execution.{QueryExecution, SortExec}
+@@ -305,7 +305,8 @@ class V1WriteCommandSuite extends QueryTest with
SharedSparkSession with V1Write
+ }
+ }
+
+- test("v1 write with AQE changing SMJ to BHJ") {
++ test("v1 write with AQE changing SMJ to BHJ",
++ IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) {
+ withPlannedWrite { enabled =>
+ withTable("t") {
+ sql(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 1d2e467c94c..3ea82cd1a3f 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -917,10 +1793,22 @@ index 3a0bd35cb70..b28f06a757f 100644
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
-index 26e61c6b58d..cde10983c68 100644
+index 26e61c6b58d..cb09d7e116a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
-@@ -737,7 +737,8 @@ class SQLMetricsSuite extends SharedSparkSession with
SQLMetricsTestUtils
+@@ -45,8 +45,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
+ import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
+
+ // Disable AQE because metric info is different with AQE on/off
++// This test suite asserts on the metrics of physical operators.
++// Disable it for Comet because the metrics are different when Comet is
enabled.
+ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
+- with DisableAdaptiveExecutionSuite {
++ with DisableAdaptiveExecutionSuite with IgnoreCometSuite {
+ import testImplicits._
+
+ /**
+@@ -737,7 +739,8 @@ class SQLMetricsSuite extends SharedSparkSession with
SQLMetricsTestUtils
}
}
@@ -1002,21 +1890,24 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
-index 266bb343526..b33bb677f0d 100644
+index 266bb343526..a426d8396be 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
-@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
+@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec,
SparkPlan}
+import org.apache.spark.sql.comet._
-+import org.apache.spark.sql.comet.execution.shuffle._
+import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec,
SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.BucketingUtils
- import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-@@ -101,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec,
ShuffleExchangeLike}
+ import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+ import org.apache.spark.sql.functions._
+ import org.apache.spark.sql.internal.SQLConf
+@@ -101,12 +102,20 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
}
}
@@ -1039,7 +1930,7 @@ index 266bb343526..b33bb677f0d 100644
// To verify if the bucket pruning works, this function checks two
conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
-@@ -155,7 +165,8 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
+@@ -155,7 +164,8 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
val planWithoutBucketedScan =
bucketedDataFrame.filter(filterCondition)
.queryExecution.executedPlan
val fileScan = getFileScan(planWithoutBucketedScan)
@@ -1049,7 +1940,7 @@ index 266bb343526..b33bb677f0d 100644
val bucketColumnType =
bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
-@@ -451,28 +462,46 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
+@@ -451,28 +461,44 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
val joinOperator = if
(joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -1082,13 +1973,11 @@ index 266bb343526..b33bb677f0d 100644
// check existence of shuffle
assert(
- joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) ==
shuffleLeft,
-+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeExec]
||
-+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleLeft,
++ joinOperator.left.exists(op =>
op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft,
s"expected shuffle in plan to be $shuffleLeft but
found\n${joinOperator.left}")
assert(
- joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) ==
shuffleRight,
-+ joinOperator.right.exists(op =>
op.isInstanceOf[ShuffleExchangeExec] ||
-+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleRight,
++ joinOperator.right.exists(op =>
op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight,
s"expected shuffle in plan to be $shuffleRight but
found\n${joinOperator.right}")
// check existence of sort
@@ -1104,7 +1993,7 @@ index 266bb343526..b33bb677f0d 100644
s"expected sort in the right child to be $sortRight but
found\n${joinOperator.right}")
// check the output partitioning
-@@ -835,11 +864,11 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
+@@ -835,11 +861,11 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8,
"i").saveAsTable("bucketed_table")
val scanDF = spark.table("bucketed_table").select("j")
@@ -1118,14 +2007,13 @@ index 266bb343526..b33bb677f0d 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
-@@ -1026,15 +1055,24 @@ abstract class BucketedReadSuite extends QueryTest
with SQLTestUtils with Adapti
+@@ -1026,15 +1052,23 @@ abstract class BucketedReadSuite extends QueryTest
with SQLTestUtils with Adapti
expectedNumShuffles: Int,
expectedCoalescedNumBuckets: Option[Int]): Unit = {
val plan = sql(query).queryExecution.executedPlan
- val shuffles = plan.collect { case s: ShuffleExchangeExec => s }
+ val shuffles = plan.collect {
-+ case s: ShuffleExchangeExec => s
-+ case s: CometShuffleExchangeExec => s
++ case s: ShuffleExchangeLike => s
+ }
assert(shuffles.length == expectedNumShuffles)
@@ -1303,6 +2191,120 @@ index 2a2a83d35e1..e3b7b290b3e 100644
val initialStateDS = Seq(("keyInStateAndData", new
RunningCount(1))).toDS()
val initialState: KeyValueGroupedDataset[String, RunningCount] =
initialStateDS.groupByKey(_._1).mapValues(_._2)
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+index ef5b8a769fe..84fe1bfabc9 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+@@ -37,6 +37,7 @@ import org.apache.spark.sql._
+ import org.apache.spark.sql.catalyst.plans.logical.{Range,
RepartitionByExpression}
+ import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes,
StreamingRelationV2}
+ import org.apache.spark.sql.catalyst.util.DateTimeUtils
++import org.apache.spark.sql.comet.CometLocalLimitExec
+ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
+ import org.apache.spark.sql.execution.command.ExplainCommand
+ import org.apache.spark.sql.execution.streaming._
+@@ -1103,11 +1104,12 @@ class StreamSuite extends StreamTest {
+ val localLimits = execPlan.collect {
+ case l: LocalLimitExec => l
+ case l: StreamingLocalLimitExec => l
++ case l: CometLocalLimitExec => l
+ }
+
+ require(
+ localLimits.size == 1,
+- s"Cant verify local limit optimization with this plan:\n$execPlan")
++ s"Cant verify local limit optimization (found ${localLimits.size} local limits) with this
plan:\n$execPlan")
+
+ if (expectStreamingLimit) {
+ assert(
+@@ -1115,7 +1117,8 @@ class StreamSuite extends StreamTest {
+ s"Local limit was not StreamingLocalLimitExec:\n$execPlan")
+ } else {
+ assert(
+- localLimits.head.isInstanceOf[LocalLimitExec],
++ localLimits.head.isInstanceOf[LocalLimitExec] ||
++ localLimits.head.isInstanceOf[CometLocalLimitExec],
+ s"Local limit was not LocalLimitExec:\n$execPlan")
+ }
+ }
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+index b4c4ec7acbf..20579284856 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala
+@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils
+ import org.scalatest.Assertions
+
+ import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
++import org.apache.spark.sql.comet.CometHashAggregateExec
+ import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+ import org.apache.spark.sql.execution.streaming.{MemoryStream,
StateStoreRestoreExec, StateStoreSaveExec}
+ import org.apache.spark.sql.functions.count
+@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends
StreamTest
+ // verify aggregations in between, except partial aggregation
+ val allAggregateExecs = query.lastExecution.executedPlan.collect {
+ case a: BaseAggregateExec => a
++ case c: CometHashAggregateExec => c.originalPlan
+ }
+
+ val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
+@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends
StreamTest
+ // verify aggregations in between, except partial aggregation
+ val allAggregateExecs = executedPlan.collect {
+ case a: BaseAggregateExec => a
++ case c: CometHashAggregateExec => c.originalPlan
+ }
+
+ val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
+diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+index 4d92e270539..33f1c2eb75e 100644
+---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+ import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression}
+ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+ import org.apache.spark.sql.execution.streaming.{MemoryStream,
StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec,
StreamingSymmetricHashJoinHelper}
+ import
org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider,
StateStore, StateStoreProviderId}
+ import org.apache.spark.sql.functions._
+@@ -619,14 +619,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite
{
+
+ val numPartitions =
spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS)
+
+- assert(query.lastExecution.executedPlan.collect {
+- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _,
+- ShuffleExchangeExec(opA: HashPartitioning, _, _),
+- ShuffleExchangeExec(opB: HashPartitioning, _, _))
+- if partitionExpressionsColumns(opA.expressions) === Seq("a",
"b")
+- && partitionExpressionsColumns(opB.expressions) === Seq("a",
"b")
+- && opA.numPartitions == numPartitions && opB.numPartitions ==
numPartitions => j
+- }.size == 1)
++ val join = query.lastExecution.executedPlan.collect {
++ case j: StreamingSymmetricHashJoinExec => j
++ }.head
++ val opA = join.left.collect {
++ case s: ShuffleExchangeLike
++ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
++ partitionExpressionsColumns(
++ s.outputPartitioning
++ .asInstanceOf[HashPartitioning].expressions) === Seq("a",
"b") =>
++ s.outputPartitioning
++ .asInstanceOf[HashPartitioning]
++ }.head
++ val opB = join.right.collect {
++ case s: ShuffleExchangeLike
++ if s.outputPartitioning.isInstanceOf[HashPartitioning] &&
++ partitionExpressionsColumns(
++ s.outputPartitioning
++ .asInstanceOf[HashPartitioning].expressions) === Seq("a",
"b") =>
++ s.outputPartitioning
++ .asInstanceOf[HashPartitioning]
++ }.head
++ assert(opA.numPartitions == numPartitions && opB.numPartitions ==
numPartitions)
+ })
+ }
+
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index abe606ad9c1..2d930b64cca 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -1327,7 +2329,7 @@ index abe606ad9c1..2d930b64cca 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
-index dd55fcfe42c..b4776c50e49 100644
+index dd55fcfe42c..293e9dc2986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -1364,20 +2366,20 @@ index dd55fcfe42c..b4776c50e49 100644
+ }
+
+ /**
-+ * Whether Spark should only apply Comet scan optimization. This is only
effective when
++ * Whether to enable ANSI mode. This is only effective when
+ * [[isCometEnabled]] returns true.
+ */
-+ protected def isCometScanOnly: Boolean = {
-+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
++ protected def enableCometAnsiMode: Boolean = {
++ val v = System.getenv("ENABLE_COMET_ANSI_MODE")
+ v != null && v.toBoolean
+ }
+
+ /**
-+ * Whether to enable ansi mode This is only effective when
++ * Whether Spark should only apply Comet scan optimization. This is only
effective when
+ * [[isCometEnabled]] returns true.
+ */
-+ protected def enableCometAnsiMode: Boolean = {
-+ val v = System.getenv("ENABLE_COMET_ANSI_MODE")
++ protected def isCometScanOnly: Boolean = {
++ val v = System.getenv("ENABLE_COMET_SCAN_ONLY")
+ v != null && v.toBoolean
+ }
+
@@ -1394,7 +2396,7 @@ index dd55fcfe42c..b4776c50e49 100644
spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
-index ed2e309fa07..f64cc283903 100644
+index ed2e309fa07..e071fc44960 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -74,6 +74,28 @@ trait SharedSparkSessionBase
@@ -1414,6 +2416,7 @@ index ed2e309fa07..f64cc283903 100644
+ .set("spark.shuffle.manager",
+
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
++ .set("spark.comet.memoryOverhead", "10g")
+ }
+
+ if (enableCometAnsiMode) {
@@ -1421,11 +2424,23 @@ index ed2e309fa07..f64cc283903 100644
+ .set("spark.sql.ansi.enabled", "true")
+ .set("spark.comet.ansi.enabled", "true")
+ }
-+
+ }
conf.set(
StaticSQLConf.WAREHOUSE_PATH,
conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" +
getClass.getCanonicalName)
+diff --git
a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
+index 1510e8957f9..7618419d8ff 100644
+---
a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
++++
b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala
+@@ -43,7 +43,7 @@ class SqlResourceWithActualMetricsSuite
+ import testImplicits._
+
+ // Exclude nodes which may not have the metrics
+- val excludedNodes = List("WholeStageCodegen", "Project",
"SerializeFromObject")
++ val excludedNodes = List("WholeStageCodegen", "Project",
"SerializeFromObject", "RowToColumnar")
+
+ implicit val formats = new DefaultFormats {
+ override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss")
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
index 52abd248f3a..7a199931a08 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala
@@ -1463,10 +2478,10 @@ index 1966e1e64fd..cde97a0aafe 100644
spark.sql(
"""
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-index 07361cfdce9..1763168a808 100644
+index 07361cfdce9..25b0dc3ef7e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-@@ -55,25 +55,54 @@ object TestHive
+@@ -55,25 +55,53 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
@@ -1530,9 +2545,8 @@ index 07361cfdce9..1763168a808 100644
+ .set("spark.sql.ansi.enabled", "true")
+ .set("spark.comet.ansi.enabled", "true")
+ }
-
+ }
-+
+
+ conf
+ }
+ ))
diff --git a/docs/source/user-guide/configs.md
b/docs/source/user-guide/configs.md
index 02ecbd69..22a7a098 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -29,7 +29,7 @@ Comet provides the following configuration settings.
| spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous
shuffle for Arrow-based shuffle. By default, this config is false. | false |
| spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of
threads on an executor used for Comet async columnar shuffle. By default, this
config is 100. This is the upper bound of total number of shuffle threads per
executor. In other words, if the number of cores * the number of shuffle
threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than
this config. Comet will use this config as the number of shuffle threads per
executor instead. | 100 |
| spark.comet.columnar.shuffle.async.thread.num | Number of threads used for
Comet async columnar shuffle per shuffle task. By default, this config is 3.
Note that more threads means more memory requirement to buffer shuffle data
before flushing to disk. Also, more threads may not always improve performance,
and should be set based on the number of cores available. | 3 |
-| spark.comet.columnar.shuffle.enabled | Force Comet to only use columnar
shuffle for CometScan and Spark regular operators. If this is enabled, Comet
native shuffle will not be enabled but only Arrow shuffle. By default, this
config is false. | false |
+| spark.comet.columnar.shuffle.enabled | Whether to enable Arrow-based
columnar shuffle for Comet and Spark regular operators. If this is enabled,
Comet prefers columnar shuffle over native shuffle. By default, this config is
true. | true |
| spark.comet.columnar.shuffle.memory.factor | Fraction of Comet memory to be
allocated per executor process for Comet shuffle. Comet memory size is
specified by `spark.comet.memoryOverhead` or calculated by
`spark.comet.memory.overhead.factor` * `spark.executor.memory`. By default,
this config is 1.0. | 1.0 |
| spark.comet.debug.enabled | Whether to enable debug mode for Comet. By
default, this config is false. When enabled, Comet will do additional checks
for debugging purpose. For example, validating array when importing arrays from
JVM at native side. Note that these checks may be expensive in performance and
should only be enabled for debugging purpose. | false |
| spark.comet.enabled | Whether to enable Comet extension for Spark. When this
is turned on, Spark will use Comet to read Parquet data source. Note that to
enable native vectorized execution, both this config and
'spark.comet.exec.enabled' need to be enabled. By default, this config is the
value of the env var `ENABLE_COMET` if set, or true otherwise. | true |
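For readers of the configuration table above, a small, illustrative sketch of
setting these options explicitly on a SparkSession. It only uses keys that
appear in this commit; columnar shuffle now defaults to true, so that line
merely restates the default, and the other keys mirror what the test harness
changes in this commit set rather than anything the patch prescribes for
applications.

    import org.apache.spark.sql.SparkSession

    // Illustrative only; mirrors the settings used by the test harness in this commit.
    val spark = SparkSession.builder()
      .config("spark.comet.exec.shuffle.enabled", "true")
      .config("spark.comet.columnar.shuffle.enabled", "true")
      .config("spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
      .getOrCreate()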
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index ad477664..7238990a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2530,7 +2530,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
withInfo(join, "SortMergeJoin is not enabled")
None
- case op if isCometSink(op) =>
+ case op if isCometSink(op) && op.output.forall(a =>
supportedDataType(a.dataType)) =>
// These operators are source of Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
index 6e5b44c8..996526e5 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin
trait AliasAwareOutputExpression extends SQLConfHelper {
// `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
// Use a default value for now.
- protected val aliasCandidateLimit = 100
+ protected val aliasCandidateLimit: Int =
+
conf.getConfString("spark.sql.optimizer.expressionProjectionCandidateLimit",
"100").toInt
protected def outputExpressions: Seq[NamedExpression]
/**
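The AliasAwareOutputExpression change above reads the candidate limit straight
from the session conf, falling back to 100 because the typed SQLConf entry only
exists in Spark 3.4+. A tiny sketch of that lookup in isolation (the trait name
below is hypothetical; the key and default are the ones added above):

    import org.apache.spark.sql.catalyst.SQLConfHelper

    trait AliasCandidateLimit extends SQLConfHelper {
      // Falls back to 100 when the key is absent (pre-3.4 Spark).
      protected val aliasCandidateLimit: Int =
        conf.getConfString("spark.sql.optimizer.expressionProjectionCandidateLimit", "100").toInt
    }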
diff --git
a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
index 3342d750..eb27dd36 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala
@@ -157,8 +157,9 @@ class CometTPCDSQuerySuite
conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
+ conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "20g")
conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
- conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
+ conf.set(MEMORY_OFFHEAP_SIZE.key, "20g")
conf
}
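The CometTPCDSQuerySuite change above raises both the Comet memory overhead and
Spark's off-heap allocation, since Comet shuffle draws from the Comet memory
pool (see spark.comet.columnar.shuffle.memory.factor in the configuration table
earlier). A hedged sketch of the equivalent standalone SparkConf, with the
values copied from the diff rather than recommended:

    import org.apache.spark.SparkConf

    // Sizes are the ones used by the test suite above; tune for your executors.
    val conf = new SparkConf()
      .set("spark.comet.memoryOverhead", "20g")
      .set("spark.memory.offHeap.enabled", "true")
      .set("spark.memory.offHeap.size", "20g")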
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]