This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new cef3e0528 test: fix Spark 3.5 tests (#1482)
cef3e0528 is described below
commit cef3e05283f98f294eecae9c3460b0d79b222fcc
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Tue Mar 11 14:22:14 2025 -0700
test: fix Spark 3.5 tests (#1482)
---
.github/workflows/spark_sql_test.yml | 2 +-
.../java/org/apache/comet/parquet/BatchReader.java | 32 +--
.../spark/sql/comet/shims/ShimTaskMetrics.scala | 29 +++
.../spark/sql/comet/shims/ShimTaskMetrics.scala | 29 +++
.../spark/sql/comet/shims/ShimTaskMetrics.scala | 31 +++
.../spark/sql/comet/shims/ShimTaskMetrics.scala | 29 +++
dev/diffs/{3.5.1.diff => 3.5.4.diff} | 214 ++++++++++++---------
docs/source/contributor-guide/spark-sql-tests.md | 6 +-
.../char_varchar_utils/read_side_padding.rs | 22 ++-
pom.xml | 4 +-
10 files changed, 266 insertions(+), 132 deletions(-)
diff --git a/.github/workflows/spark_sql_test.yml
b/.github/workflows/spark_sql_test.yml
index b325a5193..8d60f0769 100644
--- a/.github/workflows/spark_sql_test.yml
+++ b/.github/workflows/spark_sql_test.yml
@@ -45,7 +45,7 @@ jobs:
matrix:
os: [ubuntu-24.04]
java-version: [11]
- spark-version: [{short: '3.4', full: '3.4.3'}, {short: '3.5', full:
'3.5.1'}]
+ spark-version: [{short: '3.4', full: '3.4.3'}, {short: '3.5', full:
'3.5.4'}]
module:
- {name: "catalyst", args1: "catalyst/test", args2: ""}
- {name: "sql/core-1", args1: "", args2: sql/testOnly * -- -l
org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest}
diff --git a/common/src/main/java/org/apache/comet/parquet/BatchReader.java
b/common/src/main/java/org/apache/comet/parquet/BatchReader.java
index 675dae9e7..dbf1b8180 100644
--- a/common/src/main/java/org/apache/comet/parquet/BatchReader.java
+++ b/common/src/main/java/org/apache/comet/parquet/BatchReader.java
@@ -21,8 +21,6 @@ package org.apache.comet.parquet;
import java.io.Closeable;
import java.io.IOException;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
@@ -35,8 +33,6 @@ import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import scala.Option;
-import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -61,9 +57,9 @@ import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
-import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
+import org.apache.spark.sql.comet.shims.ShimTaskMetrics;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import
org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
import org.apache.spark.sql.execution.metric.SQLMetric;
@@ -350,7 +346,8 @@ public class BatchReader extends RecordReader<Void,
ColumnarBatch> implements Cl
// Note that this tries to get thread local TaskContext object, if this is
called at other
// thread, it won't update the accumulator.
if (taskContext != null) {
- Option<AccumulatorV2<?, ?>> accu =
getTaskAccumulator(taskContext.taskMetrics());
+ Option<AccumulatorV2<?, ?>> accu =
+ ShimTaskMetrics.getTaskAccumulator(taskContext.taskMetrics());
if (accu.isDefined() &&
accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
@SuppressWarnings("unchecked")
AccumulatorV2<Integer, Integer> intAccum = (AccumulatorV2<Integer,
Integer>) accu.get();
@@ -637,27 +634,4 @@ public class BatchReader extends RecordReader<Void,
ColumnarBatch> implements Cl
}
}
}
-
- // Signature of externalAccums changed from returning a Buffer to returning
a Seq. If comet is
- // expecting a Buffer but the Spark version returns a Seq or vice versa, we
get a
- // method not found exception.
- @SuppressWarnings("unchecked")
- private Option<AccumulatorV2<?, ?>> getTaskAccumulator(TaskMetrics
taskMetrics) {
- Method externalAccumsMethod;
- try {
- externalAccumsMethod =
TaskMetrics.class.getDeclaredMethod("externalAccums");
- externalAccumsMethod.setAccessible(true);
- String returnType = externalAccumsMethod.getReturnType().getName();
- if (returnType.equals("scala.collection.mutable.Buffer")) {
- return ((Buffer<AccumulatorV2<?, ?>>)
externalAccumsMethod.invoke(taskMetrics))
- .lastOption();
- } else if (returnType.equals("scala.collection.Seq")) {
- return ((Seq<AccumulatorV2<?, ?>>)
externalAccumsMethod.invoke(taskMetrics)).lastOption();
- } else {
- return Option.apply(null); // None
- }
- } catch (NoSuchMethodException | InvocationTargetException |
IllegalAccessException e) {
- return Option.apply(null); // None
- }
- }
}
diff --git
a/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
b/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++
b/common/src/main/spark-3.3/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+ def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_,
_]] =
+ taskMetrics.externalAccums.lastOption
+}
diff --git
a/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
b/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++
b/common/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+ def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_,
_]] =
+ taskMetrics.externalAccums.lastOption
+}
diff --git
a/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
b/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..2ca0ef277
--- /dev/null
+++
b/common/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+ def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_,
_]] =
+ taskMetrics.withExternalAccums(identity[ArrayBuffer[AccumulatorV2[_,
_]]](_)).lastOption
+}
diff --git
a/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
new file mode 100644
index 000000000..5b2a5fb5b
--- /dev/null
+++
b/common/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimTaskMetrics.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.AccumulatorV2
+
+object ShimTaskMetrics {
+
+ def getTaskAccumulator(taskMetrics: TaskMetrics): Option[AccumulatorV2[_,
_]] =
+ taskMetrics.externalAccums.lastOption
+}
diff --git a/dev/diffs/3.5.1.diff b/dev/diffs/3.5.4.diff
similarity index 96%
rename from dev/diffs/3.5.1.diff
rename to dev/diffs/3.5.4.diff
index 762cd948d..47bc3ccd0 100644
--- a/dev/diffs/3.5.1.diff
+++ b/dev/diffs/3.5.4.diff
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
-index 0f504dbee85..430ec217e59 100644
+index 8dc47f391f9..8a3e72133a8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,6 +152,8 @@
@@ -11,9 +11,9 @@ index 0f504dbee85..430ec217e59 100644
<!--
If you changes codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
-@@ -2787,6 +2789,25 @@
- <artifactId>arpack</artifactId>
- <version>${netlib.ludovic.dev.version}</version>
+@@ -2836,6 +2838,25 @@
+ <artifactId>okio</artifactId>
+ <version>${okio.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.datafusion</groupId>
@@ -38,7 +38,7 @@ index 0f504dbee85..430ec217e59 100644
</dependencyManagement>
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
-index c46ab7b8fce..13357e8c7a6 100644
+index 9577de81c20..a37f4a1f89f 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
@@ -203,7 +203,7 @@ index 0efe0877e9b..423d3b3d76d 100644
-- SELECT_HAVING
--
https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
-index 8331a3c10fc..b4e22732a91 100644
+index 9815cb816c9..95b5f9992b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants
@@ -226,10 +226,10 @@ index 8331a3c10fc..b4e22732a91 100644
test("A cached table preserves the partitioning and ordering of its cached
SparkPlan") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
-index 631fcd8c0d8..6df0e1b4176 100644
+index 5a8681aed97..da9d25e2eb4 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
-@@ -27,7 +27,7 @@ import org.apache.spark.{SparkException, SparkThrowable}
+@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
@@ -238,7 +238,7 @@ index 631fcd8c0d8..6df0e1b4176 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-@@ -792,7 +792,7 @@ class DataFrameAggregateSuite extends QueryTest
+@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)
val exchangePlans = collect(aggPlan) {
@@ -263,7 +263,7 @@ index 56e9520fdab..917932336df 100644
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
-index 002719f0689..784d24afe2d 100644
+index 7ee18df3756..64f01a68048 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -40,11 +40,12 @@ import
org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
@@ -280,7 +280,7 @@ index 002719f0689..784d24afe2d 100644
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-@@ -2020,7 +2021,7 @@ class DataFrameSuite extends QueryTest
+@@ -2006,7 +2007,7 @@ class DataFrameSuite extends QueryTest
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
@@ -289,7 +289,7 @@ index 002719f0689..784d24afe2d 100644
case _ =>
}
}
-@@ -2344,7 +2345,7 @@ class DataFrameSuite extends QueryTest
+@@ -2330,7 +2331,7 @@ class DataFrameSuite extends QueryTest
checkAnswer(join, df)
assert(
collect(join.queryExecution.executedPlan) {
@@ -298,7 +298,7 @@ index 002719f0689..784d24afe2d 100644
assert(
collect(join.queryExecution.executedPlan) { case e:
ReusedExchangeExec => true }.size === 1)
val broadcasted = broadcast(join)
-@@ -2352,10 +2353,12 @@ class DataFrameSuite extends QueryTest
+@@ -2338,10 +2339,12 @@ class DataFrameSuite extends QueryTest
checkAnswer(join2, df)
assert(
collect(join2.queryExecution.executedPlan) {
@@ -313,7 +313,7 @@ index 002719f0689..784d24afe2d 100644
assert(
collect(join2.queryExecution.executedPlan) { case e:
ReusedExchangeExec => true }.size == 4)
}
-@@ -2915,7 +2918,7 @@ class DataFrameSuite extends QueryTest
+@@ -2901,7 +2904,7 @@ class DataFrameSuite extends QueryTest
// Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) {
@@ -322,7 +322,7 @@ index 002719f0689..784d24afe2d 100644
}
assert(exchanges.size == 2)
}
-@@ -3364,7 +3367,8 @@ class DataFrameSuite extends QueryTest
+@@ -3350,7 +3353,8 @@ class DataFrameSuite extends QueryTest
assert(df2.isLocal)
}
@@ -333,7 +333,7 @@ index 002719f0689..784d24afe2d 100644
sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
-index c2fe31520ac..0f54b233d14 100644
+index f32b32ffc5a..447d7c6416e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti,
LeftSemi}
@@ -715,10 +715,10 @@ index 7af826583bd..3c3def1eb67 100644
assert(shuffleMergeJoins.size == 1)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-index 9dcf7ec2904..d8b014a4eb8 100644
+index 4d256154c85..43f0bebb00c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-@@ -30,7 +30,8 @@ import
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+@@ -31,7 +31,8 @@ import
org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow,
SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join,
JoinHint, NO_BROADCAST_AND_REPLICATION}
@@ -728,7 +728,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec,
ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins._
-@@ -801,7 +802,8 @@ class JoinSuite extends QueryTest with SharedSparkSession
with AdaptiveSparkPlan
+@@ -802,7 +803,8 @@ class JoinSuite extends QueryTest with SharedSparkSession
with AdaptiveSparkPlan
}
}
@@ -738,7 +738,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0",
SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") {
-@@ -927,10 +929,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -928,10 +930,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val physical = df.queryExecution.sparkPlan
val physicalJoins = physical.collect {
case j: SortMergeJoinExec => j
@@ -751,7 +751,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}
// This only applies to the above tested queries, in which a child
SortMergeJoin always
// contains the SortOrder required by its parent SortMergeJoin. Thus,
SortExec should never
-@@ -1176,9 +1180,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1177,9 +1181,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType)
.groupBy($"k1").count()
.queryExecution.executedPlan
@@ -765,7 +765,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
})
}
-@@ -1195,10 +1201,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1196,10 +1202,11 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
@@ -779,7 +779,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
})
// Test shuffled hash join
-@@ -1208,10 +1215,13 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1209,10 +1216,13 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
.join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType)
.queryExecution
.executedPlan
@@ -796,7 +796,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
})
}
-@@ -1302,12 +1312,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1303,12 +1313,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full")
assert(collect(smjDF.queryExecution.executedPlan) {
@@ -811,7 +811,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
-@@ -1366,12 +1376,14 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1367,12 +1377,14 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs,
"leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
@@ -826,7 +826,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
-@@ -1382,12 +1394,14 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1383,12 +1395,14 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs,
"rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
@@ -841,7 +841,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
-@@ -1431,13 +1445,19 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1432,13 +1446,19 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
assert(shjCodegenDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
case WholeStageCodegenExec(ProjectExec(_, _ :
ShuffledHashJoinExec)) => true
@@ -862,7 +862,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
checkAnswer(shjNonCodegenDF, Seq.empty)
}
}
-@@ -1485,7 +1505,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1486,7 +1506,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val plan = sql(getAggQuery(selectExpr,
joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
// Have shuffle before aggregation
@@ -872,7 +872,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}
def getJoinQuery(selectExpr: String, joinType: String): String = {
-@@ -1514,9 +1535,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1515,9 +1536,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr,
joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
@@ -887,7 +887,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}
// Test output ordering is not preserved
-@@ -1525,9 +1549,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1526,9 +1550,12 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr,
joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true
}.size === 1)
@@ -902,7 +902,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}
// Test singe partition
-@@ -1537,7 +1564,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1538,7 +1565,8 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
|FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2
|""".stripMargin)
val plan = fullJoinDF.queryExecution.executedPlan
@@ -912,7 +912,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
checkAnswer(fullJoinDF, Row(100))
}
}
-@@ -1582,6 +1610,9 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1583,6 +1611,9 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
Seq(semiJoinDF, antiJoinDF).foreach { df =>
assert(collect(df.queryExecution.executedPlan) {
case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey ==
ignoreDuplicatedKey => true
@@ -922,7 +922,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
}.size == 1)
}
}
-@@ -1626,14 +1657,20 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
+@@ -1627,14 +1658,20 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
test("SPARK-43113: Full outer join with duplicate stream-side references in
condition (SMJ)") {
def check(plan: SparkPlan): Unit = {
@@ -946,7 +946,7 @@ index 9dcf7ec2904..d8b014a4eb8 100644
dupStreamSideColTest("SHUFFLE_HASH", check)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
-index b5b34922694..a72403780c4 100644
+index c26757c9cff..d55775f09d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala
@@ -69,7 +69,7 @@ import org.apache.spark.tags.ExtendedSQLTest
@@ -959,10 +959,10 @@ index b5b34922694..a72403780c4 100644
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index cfeccbdf648..803d8734cc4 100644
+index 793a0da6a86..6ccb9d62582 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-@@ -1510,7 +1510,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
+@@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
}
@@ -1004,7 +1004,7 @@ index 8b4ac474f87..3f79f20822f 100644
extensions.injectColumnar(session =>
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule()))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-index fbc256b3396..0821999c7c2 100644
+index 260c992f1ae..b9d8e22337c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,10 +22,11 @@ import scala.collection.mutable.ArrayBuffer
@@ -1043,7 +1043,7 @@ index fbc256b3396..0821999c7c2 100644
assert(exchanges.size === 1)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
-index 52d0151ee46..2b6d493cf38 100644
+index d269290e616..13726a31e07 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._
@@ -1131,7 +1131,7 @@ index cfc8b2cc845..c6fcfd7bd08 100644
} finally {
spark.listenerManager.unregister(listener)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
-index 6b07c77aefb..8277661560e 100644
+index 71e030f535e..d5ae6cbf3d5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row}
@@ -1449,23 +1449,23 @@ index 5a413c77754..a6f97dccb67 100644
val df = spark.read.parquet(path).selectExpr(projection: _*)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-index 68bae34790a..0cc77ad09d7 100644
+index 2f8e401e743..a4f94417dcc 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
-@@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._
-
+@@ -27,9 +27,11 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent,
SparkListenerJobStart}
+ import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row,
SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
- import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec,
LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution,
ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo,
UnionExec}
+ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
- import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-@@ -112,6 +114,7 @@ class AdaptiveQueryExecSuite
+ import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec,
InMemoryTableScanLike}
+@@ -117,6 +119,7 @@ class AdaptiveQueryExecSuite
private def findTopLevelBroadcastHashJoin(plan: SparkPlan):
Seq[BroadcastHashJoinExec] = {
collect(plan) {
case j: BroadcastHashJoinExec => j
@@ -1473,7 +1473,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
}
-@@ -124,30 +127,39 @@ class AdaptiveQueryExecSuite
+@@ -129,36 +132,46 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan):
Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
@@ -1513,7 +1513,14 @@ index 68bae34790a..0cc77ad09d7 100644
}
}
-@@ -191,6 +203,7 @@ class AdaptiveQueryExecSuite
+ private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
+ collect(plan) {
+ case l: CollectLimitExec => l
++ case l: CometCollectLimitExec =>
l.originalPlan.asInstanceOf[CollectLimitExec]
+ }
+ }
+
+@@ -202,6 +215,7 @@ class AdaptiveQueryExecSuite
val parts = rdd.partitions
assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
}
@@ -1521,7 +1528,7 @@ index 68bae34790a..0cc77ad09d7 100644
assert(numShuffles === (numLocalReads.length +
numShufflesWithoutLocalRead))
}
-@@ -199,7 +212,7 @@ class AdaptiveQueryExecSuite
+@@ -210,7 +224,7 @@ class AdaptiveQueryExecSuite
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
val shuffle =
plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
@@ -1530,7 +1537,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
assert(shuffle.size == 1)
assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
-@@ -215,7 +228,8 @@ class AdaptiveQueryExecSuite
+@@ -226,7 +240,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
@@ -1540,7 +1547,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
}
-@@ -242,7 +256,8 @@ class AdaptiveQueryExecSuite
+@@ -253,7 +268,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1550,7 +1557,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
-@@ -274,7 +289,8 @@ class AdaptiveQueryExecSuite
+@@ -285,7 +301,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1560,7 +1567,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
-@@ -288,7 +304,8 @@ class AdaptiveQueryExecSuite
+@@ -299,7 +316,8 @@ class AdaptiveQueryExecSuite
val localReads = collect(adaptivePlan) {
case read: AQEShuffleReadExec if read.isLocalRead => read
}
@@ -1570,7 +1577,7 @@ index 68bae34790a..0cc77ad09d7 100644
val localShuffleRDD0 =
localReads(0).execute().asInstanceOf[ShuffledRowRDD]
val localShuffleRDD1 =
localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// the final parallelism is math.max(1, numReduces / numMappers):
math.max(1, 5/2) = 2
-@@ -313,7 +330,9 @@ class AdaptiveQueryExecSuite
+@@ -324,7 +342,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
@@ -1581,7 +1588,7 @@ index 68bae34790a..0cc77ad09d7 100644
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
-@@ -327,7 +346,9 @@ class AdaptiveQueryExecSuite
+@@ -338,7 +358,9 @@ class AdaptiveQueryExecSuite
.groupBy($"a").count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
@@ -1592,7 +1599,7 @@ index 68bae34790a..0cc77ad09d7 100644
val coalescedReads = collect(plan) {
case r: AQEShuffleReadExec => r
}
-@@ -337,7 +358,7 @@ class AdaptiveQueryExecSuite
+@@ -348,7 +370,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1601,7 +1608,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -352,7 +373,7 @@ class AdaptiveQueryExecSuite
+@@ -363,7 +385,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1610,7 +1617,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -368,7 +389,7 @@ class AdaptiveQueryExecSuite
+@@ -379,7 +401,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1619,7 +1626,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -413,7 +434,7 @@ class AdaptiveQueryExecSuite
+@@ -424,7 +446,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1628,7 +1635,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -458,7 +479,7 @@ class AdaptiveQueryExecSuite
+@@ -469,7 +491,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1637,7 +1644,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
-@@ -504,7 +525,7 @@ class AdaptiveQueryExecSuite
+@@ -515,7 +537,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1646,7 +1653,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -523,7 +544,7 @@ class AdaptiveQueryExecSuite
+@@ -534,7 +556,7 @@ class AdaptiveQueryExecSuite
}
}
@@ -1655,7 +1662,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
-@@ -554,7 +575,9 @@ class AdaptiveQueryExecSuite
+@@ -565,7 +587,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
@@ -1666,7 +1673,7 @@ index 68bae34790a..0cc77ad09d7 100644
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
-@@ -575,7 +598,9 @@ class AdaptiveQueryExecSuite
+@@ -586,7 +610,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
@@ -1677,7 +1684,7 @@ index 68bae34790a..0cc77ad09d7 100644
// Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
-@@ -584,7 +609,8 @@ class AdaptiveQueryExecSuite
+@@ -595,7 +621,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1687,7 +1694,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
-@@ -679,7 +705,8 @@ class AdaptiveQueryExecSuite
+@@ -690,7 +717,8 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// There is still a SMJ, and its two shuffles can't apply local read.
@@ -1697,7 +1704,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
}
-@@ -801,7 +828,8 @@ class AdaptiveQueryExecSuite
+@@ -812,7 +840,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1707,7 +1714,7 @@ index 68bae34790a..0cc77ad09d7 100644
Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint ==
"SHUFFLE_MERGE") {
findTopLevelSortMergeJoin(plan)
-@@ -1019,7 +1047,8 @@ class AdaptiveQueryExecSuite
+@@ -1030,7 +1059,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1717,7 +1724,7 @@ index 68bae34790a..0cc77ad09d7 100644
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT key FROM testData GROUP BY key")
-@@ -1614,7 +1643,7 @@ class AdaptiveQueryExecSuite
+@@ -1625,7 +1655,7 @@ class AdaptiveQueryExecSuite
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
assert(collect(adaptivePlan) {
@@ -1726,7 +1733,7 @@ index 68bae34790a..0cc77ad09d7 100644
}.length == 1)
}
}
-@@ -1694,7 +1723,8 @@ class AdaptiveQueryExecSuite
+@@ -1705,7 +1735,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1736,7 +1743,7 @@ index 68bae34790a..0cc77ad09d7 100644
def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
find(plan) {
case s: ShuffleExchangeLike =>
-@@ -1879,6 +1909,9 @@ class AdaptiveQueryExecSuite
+@@ -1890,6 +1921,9 @@ class AdaptiveQueryExecSuite
def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin):
Unit = {
assert(collect(ds.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin &&
s.numPartitions == 2 => s
@@ -1746,7 +1753,7 @@ index 68bae34790a..0cc77ad09d7 100644
}.size == 1)
ds.collect()
val plan = ds.queryExecution.executedPlan
-@@ -1887,6 +1920,9 @@ class AdaptiveQueryExecSuite
+@@ -1898,6 +1932,9 @@ class AdaptiveQueryExecSuite
}.isEmpty)
assert(collect(plan) {
case s: ShuffleExchangeExec if s.shuffleOrigin == origin &&
s.numPartitions == 2 => s
@@ -1756,7 +1763,7 @@ index 68bae34790a..0cc77ad09d7 100644
}.size == 1)
checkAnswer(ds, testData)
}
-@@ -2043,7 +2079,8 @@ class AdaptiveQueryExecSuite
+@@ -2054,7 +2091,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1766,7 +1773,7 @@ index 68bae34790a..0cc77ad09d7 100644
withTempView("t1", "t2") {
def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
Seq("100", "100000").foreach { size =>
-@@ -2129,7 +2166,8 @@ class AdaptiveQueryExecSuite
+@@ -2140,7 +2178,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1776,7 +1783,7 @@ index 68bae34790a..0cc77ad09d7 100644
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-@@ -2228,7 +2266,7 @@ class AdaptiveQueryExecSuite
+@@ -2239,7 +2278,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM
skewData1 " +
s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
val shuffles1 = collect(adaptive1) {
@@ -1785,7 +1792,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
assert(shuffles1.size == 3)
// shuffles1.head is the top-level shuffle under the Aggregate
operator
-@@ -2241,7 +2279,7 @@ class AdaptiveQueryExecSuite
+@@ -2252,7 +2291,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM
skewData1 " +
s"JOIN skewData2 ON key1 = key2")
val shuffles2 = collect(adaptive2) {
@@ -1794,7 +1801,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
if (hasRequiredDistribution) {
assert(shuffles2.size == 3)
-@@ -2275,7 +2313,8 @@ class AdaptiveQueryExecSuite
+@@ -2286,7 +2325,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1804,7 +1811,17 @@ index 68bae34790a..0cc77ad09d7 100644
CostEvaluator.instantiate(
classOf[SimpleShuffleSortCostEvaluator].getCanonicalName,
spark.sparkContext.getConf)
intercept[IllegalArgumentException] {
-@@ -2419,6 +2458,7 @@ class AdaptiveQueryExecSuite
+@@ -2417,7 +2457,8 @@ class AdaptiveQueryExecSuite
+ }
+
+ test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related
metrics " +
+- "resulting in potentially inaccurate data") {
++ "resulting in potentially inaccurate data",
++ IgnoreComet("https://github.com/apache/datafusion-comet/issues/1501")) {
+ withTable("t3") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+@@ -2452,6 +2493,7 @@ class AdaptiveQueryExecSuite
val (_, adaptive) = runAdaptiveAndVerifyResult(query)
assert(adaptive.collect {
case sort: SortExec => sort
@@ -1812,7 +1829,7 @@ index 68bae34790a..0cc77ad09d7 100644
}.size == 1)
val read = collect(adaptive) {
case read: AQEShuffleReadExec => read
-@@ -2436,7 +2476,8 @@ class AdaptiveQueryExecSuite
+@@ -2469,7 +2511,8 @@ class AdaptiveQueryExecSuite
}
}
@@ -1822,7 +1839,7 @@ index 68bae34790a..0cc77ad09d7 100644
withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key
-> "true",
-@@ -2548,7 +2589,7 @@ class AdaptiveQueryExecSuite
+@@ -2581,7 +2624,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN
skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value2 = value3")
val shuffles1 = collect(adaptive1) {
@@ -1831,7 +1848,7 @@ index 68bae34790a..0cc77ad09d7 100644
}
assert(shuffles1.size == 4)
val smj1 = findTopLevelSortMergeJoin(adaptive1)
-@@ -2559,7 +2600,7 @@ class AdaptiveQueryExecSuite
+@@ -2592,7 +2635,7 @@ class AdaptiveQueryExecSuite
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN
skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value1 = value3")
val shuffles2 = collect(adaptive2) {
@@ -1840,14 +1857,25 @@ index 68bae34790a..0cc77ad09d7 100644
}
assert(shuffles2.size == 4)
val smj2 = findTopLevelSortMergeJoin(adaptive2)
-@@ -2756,6 +2797,7 @@ class AdaptiveQueryExecSuite
+@@ -2850,6 +2893,7 @@ class AdaptiveQueryExecSuite
}.size == (if (firstAccess) 1 else 0))
assert(collect(initialExecutedPlan) {
case s: SortExec => s
+ case s: CometSortExec => s
}.size == (if (firstAccess) 2 else 0))
assert(collect(initialExecutedPlan) {
- case i: InMemoryTableScanExec => i
+ case i: InMemoryTableScanLike => i
+@@ -2980,7 +3024,9 @@ class AdaptiveQueryExecSuite
+
+ val plan =
df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+ assert(plan.inputPlan.isInstanceOf[TakeOrderedAndProjectExec])
+- assert(plan.finalPhysicalPlan.isInstanceOf[WindowExec])
++ assert(
++ plan.finalPhysicalPlan.isInstanceOf[WindowExec] ||
++ plan.finalPhysicalPlan.find(_.isInstanceOf[CometWindowExec]).nonEmpty)
+ plan.inputPlan.output.zip(plan.finalPhysicalPlan.output).foreach { case
(o1, o2) =>
+ assert(o1.semanticEquals(o2), "Different output column order after
AQE optimization")
+ }
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
index 05872d41131..a2c328b9742 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceCustomMetadataStructSuite.scala
@@ -2094,10 +2122,10 @@ index 4f8a9e39716..fb55ac7a955 100644
checkAnswer(
// "fruit" column in this file is encoded using
DELTA_LENGTH_BYTE_ARRAY.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
-index 828ec39c7d7..369b3848192 100644
+index f6472ba3d9d..dc13e00c853 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
-@@ -1041,7 +1041,8 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
+@@ -1067,7 +1067,8 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
checkAnswer(readParquet(schema, path), df)
}
@@ -2107,7 +2135,7 @@ index 828ec39c7d7..369b3848192 100644
val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
checkAnswer(readParquet(schema1, path), df)
val schema2 = "a DECIMAL(3, 0), b DECIMAL(18, 1), c DECIMAL(37, 1)"
-@@ -1063,7 +1064,8 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
+@@ -1089,7 +1090,8 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
val df = sql(s"SELECT 1 a, 123456 b, ${Int.MaxValue.toLong * 10} c,
CAST('1.2' AS BINARY) d")
df.write.parquet(path.toString)
@@ -2117,7 +2145,7 @@ index 828ec39c7d7..369b3848192 100644
checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT
123456.0"))
-@@ -1122,7 +1124,7 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
+@@ -1148,7 +1150,7 @@ abstract class ParquetQuerySuite extends QueryTest with
ParquetTest with SharedS
.where(s"a < ${Long.MaxValue}")
.collect()
}
@@ -2241,7 +2269,7 @@ index b8f3ea3c6f3..bbd44221288 100644
val workDirPath = workDir.getAbsolutePath
val input = spark.range(5).toDF("id")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
-index 6347757e178..6d0fa493308 100644
+index 5cdbdc27b32..307fba16578 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -46,8 +46,10 @@ import org.apache.spark.sql.util.QueryExecutionListener
@@ -2533,7 +2561,7 @@ index d675503a8ba..659fa686fb7 100644
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
-index 75f440caefc..36b1146bc3a 100644
+index 1954cce7fdc..73d1464780e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.paths.SparkPath
@@ -2544,7 +2572,7 @@ index 75f440caefc..36b1146bc3a 100644
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceV2Relation, FileScan, FileTable}
-@@ -748,6 +749,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
+@@ -761,6 +762,8 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
val fileScan = df.queryExecution.executedPlan.collect {
case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
batch.scan.asInstanceOf[FileScan]
@@ -2706,7 +2734,7 @@ index b4c4ec7acbf..20579284856 100644
val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
-index 3e1bc57dfa2..4a8d75ff512 100644
+index aad91601758..201083bd621 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
@@ -2778,7 +2806,7 @@ index abe606ad9c1..2d930b64cca 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
-index dd55fcfe42c..0d66bcccbdc 100644
+index e937173a590..c2e00c53cc3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -2832,7 +2860,7 @@ index dd55fcfe42c..0d66bcccbdc 100644
protected override def withSQLConf(pairs: (String, String)*)(f: => Unit):
Unit = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
-@@ -434,6 +462,8 @@ private[sql] trait SQLTestUtilsBase
+@@ -435,6 +463,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
@@ -2927,7 +2955,7 @@ index dc8b184fcee..dd69a989d40 100644
spark.sql(
"""
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-index 9284b35fb3e..37f91610500 100644
+index 1d646f40b3e..7f2cdb8f061 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -53,25 +53,55 @@ object TestHive
diff --git a/docs/source/contributor-guide/spark-sql-tests.md
b/docs/source/contributor-guide/spark-sql-tests.md
index 0cdef50ab..cb88d2f43 100644
--- a/docs/source/contributor-guide/spark-sql-tests.md
+++ b/docs/source/contributor-guide/spark-sql-tests.md
@@ -72,11 +72,11 @@ of Apache Spark to enable Comet when running tests. This is
a highly manual proc
vary depending on the changes in the new version of Spark, but here is a
general guide to the process.
We typically start by applying a patch from a previous version of Spark. For
example, when enabling the tests
-for Spark version 3.5.1 we may start by applying the existing diff for 3.4.3
first.
+for Spark version 3.5.4 we may start by applying the existing diff for 3.4.3
first.
```shell
cd git/apache/spark
-git checkout v3.5.1
+git checkout v3.5.4
git apply --reject --whitespace=fix ../datafusion-comet/dev/diffs/3.4.3.diff
```
@@ -118,7 +118,7 @@ wiggle --replace
./sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.sc
## Generating The Diff File
```shell
-git diff v3.5.1 > ../datafusion-comet/dev/diffs/3.5.1.diff
+git diff v3.5.4 > ../datafusion-comet/dev/diffs/3.5.4.diff
```
## Running Tests in CI
diff --git
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 1f9400b35..320938a5f 100644
---
a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++
b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -17,7 +17,9 @@
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow_array::builder::GenericStringBuilder;
-use arrow_array::Array;
+use arrow_array::cast::as_dictionary_array;
+use arrow_array::types::Int32Type;
+use arrow_array::{make_array, Array, DictionaryArray};
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{cast::as_generic_string_array, DataFusionError,
ScalarValue};
@@ -45,14 +47,26 @@ fn spark_read_side_padding2(
DataType::LargeUtf8 => {
spark_read_side_padding_internal::<i64>(array, *length,
truncate)
}
- // TODO: handle Dictionary types
+ // Dictionary support required for SPARK-48498
+ DataType::Dictionary(_, value_type) => {
+ let dict = as_dictionary_array::<Int32Type>(array);
+ let col = if value_type.as_ref() == &DataType::Utf8 {
+ spark_read_side_padding_internal::<i32>(dict.values(),
*length, truncate)?
+ } else {
+ spark_read_side_padding_internal::<i64>(dict.values(),
*length, truncate)?
+ };
+ // col consists of an array, so arg of to_array() is not
used. Can be anything
+ let values = col.to_array(0)?;
+ let result = DictionaryArray::try_new(dict.keys().clone(),
values)?;
+ Ok(ColumnarValue::Array(make_array(result.into())))
+ }
other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {other:?} for function
read_side_padding",
+ "Unsupported data type {other:?} for function
rpad/read_side_padding",
))),
}
}
other => Err(DataFusionError::Internal(format!(
- "Unsupported arguments {other:?} for function read_side_padding",
+ "Unsupported arguments {other:?} for function
rpad/read_side_padding",
))),
}
}
diff --git a/pom.xml b/pom.xml
index f55e44316..e236c1dd7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -50,8 +50,8 @@ under the License.
<scala.version>2.12.17</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<scala.plugin.version>4.7.2</scala.plugin.version>
- <scalatest.version>3.2.9</scalatest.version>
- <scalatest-maven-plugin.version>2.0.2</scalatest-maven-plugin.version>
+ <scalatest.version>3.2.16</scalatest.version>
+ <scalatest-maven-plugin.version>2.2.0</scalatest-maven-plugin.version>
<spark.version>3.4.3</spark.version>
<spark.version.short>3.4</spark.version.short>
<spark.maven.scope>provided</spark.maven.scope>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]