This is an automated email from the ASF dual-hosted git repository.

csy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new b1b5fefe [AURON #1814][Correctness Testing] Spark 3.3 - Add Sort 
operator related tests  (#1820)
b1b5fefe is described below

commit b1b5fefe9ad593fd684051118e140c68af6e3543
Author: Shreyesh <[email protected]>
AuthorDate: Tue Jan 20 20:47:41 2026 -0800

    [AURON #1814][Correctness Testing] Spark 3.3 - Add Sort operator related 
tests  (#1820)
    
    # Which issue does this PR close?
    Closes #1814
    
    # Rationale for this change
    
    This change is part of the effort for enhancing Auron's correctness
    testing (https://github.com/apache/auron/issues/1745).
    
    # What changes are included in this PR?
    This change adds supports for sort operator related tests.
    
    # Are there any user-facing changes?
    N/A
    
    # How was this patch tested?
    Existing unit tests
    
    ---------
    
    Co-authored-by: sarangat_LinkedIn <[email protected]>
---
 .../auron/utils/AuronSparkTestSettings.scala       |  15 ++
 .../org/apache/spark/sql/AuronDataFrameSuite.scala | 280 +++++++++++++++++++++
 2 files changed, 295 insertions(+)

diff --git 
a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
 
b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
index 4884eda9..55672fdd 100644
--- 
a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
+++ 
b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
@@ -42,6 +42,21 @@ class AuronSparkTestSettings extends SparkTestSettings {
 
   enableSuite[AuronTypedImperativeAggregateSuite]
 
+  enableSuite[AuronDataFrameSuite]
+    // Auron-specific implementations of these tests are provided above
+    .exclude("repartitionByRange")
+    .exclude("distributeBy and localSort")
+    .exclude("reuse exchange")
+    .exclude("SPARK-22520: support code generation for large CaseWhen")
+    .exclude("SPARK-27439: Explain result should match collected result after 
view change")
+    // These tests fail due to Auron native execution differences
+    .exclude("SPARK-28067: Aggregate sum should not return wrong results for 
decimal overflow")
+    .exclude("SPARK-35955: Aggregate avg should not return wrong results for 
decimal overflow")
+    .exclude("NaN is greater than all other non-NaN numeric values")
+    .exclude("SPARK-20897: cached self-join should not fail")
+    .exclude("SPARK-22271: mean overflows and returns null for some decimal 
variables")
+    .exclude("SPARK-32764: -0.0 and 0.0 should be equal")
+
   // Will be implemented in the future.
   override def getSQLQueryTestSettings = new SQLQueryTestSettings {
     override def getResourceFilePath: String = ???
diff --git 
a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronDataFrameSuite.scala
 
b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronDataFrameSuite.scala
new file mode 100644
index 00000000..74e27698
--- /dev/null
+++ 
b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronDataFrameSuite.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.io.ByteArrayOutputStream
+
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression}
+import org.apache.spark.sql.execution.auron.plan.{NativeAggExec, 
NativeShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestData.TestData2
+
+class AuronDataFrameSuite extends DataFrameSuite with SparkQueryTestsBase {
+
+  testAuron("repartitionByRange") {
+    val partitionNum = 10
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> partitionNum.toString) {
+      import testImplicits._
+      val data1d = Random.shuffle(0.to(partitionNum - 1))
+      val data2d = data1d.map(i => (i, data1d.size - i))
+
+      checkAnswer(
+        data1d
+          .toDF("val")
+          .repartitionByRange(data1d.size, $"val".asc)
+          .select(spark_partition_id().as("id"), $"val"),
+        data1d.map(i => Row(i, i)))
+
+      checkAnswer(
+        data1d
+          .toDF("val")
+          .repartitionByRange(data1d.size, $"val".desc)
+          .select(spark_partition_id().as("id"), $"val"),
+        data1d.map(i => Row(i, data1d.size - 1 - i)))
+
+      checkAnswer(
+        data1d
+          .toDF("val")
+          .repartitionByRange(data1d.size, lit(42))
+          .select(spark_partition_id().as("id"), $"val"),
+        data1d.map(i => Row(0, i)))
+
+      checkAnswer(
+        data1d
+          .toDF("val")
+          .repartitionByRange(data1d.size, lit(null), $"val".asc, rand())
+          .select(spark_partition_id().as("id"), $"val"),
+        data1d.map(i => Row(i, i)))
+
+      checkAnswer(
+        data2d
+          .toDF("a", "b")
+          .repartitionByRange(data2d.size, $"a".desc, $"b")
+          .select(spark_partition_id().as("id"), $"a", $"b"),
+        data2d
+          .toDF("a", "b")
+          .repartitionByRange(data2d.size, $"a".desc, $"b".asc)
+          .select(spark_partition_id().as("id"), $"a", $"b"))
+
+      intercept[IllegalArgumentException] {
+        data1d.toDF("val").repartitionByRange(data1d.size)
+      }
+      intercept[IllegalArgumentException] {
+        data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*)
+      }
+    }
+  }
+
+  testAuron("distributeBy and localSort") {
+    import testImplicits._
+    val data = spark.sparkContext.parallelize((1 to 100).map(i => TestData2(i 
% 10, i))).toDF()
+
+    var partitionNum = 1
+    val original = testData.repartition(partitionNum)
+    assert(original.rdd.partitions.length == partitionNum)
+
+    val df6 = data.repartition(partitionNum, $"a").sortWithinPartitions("b")
+    df6.rdd.foreachPartition { p =>
+      var previousValue: Int = -1
+      var allSequential: Boolean = true
+      p.foreach { r =>
+        val v: Int = r.getInt(1)
+        if (previousValue != -1) {
+          if (previousValue > v) throw new Exception("Partition is not 
ordered.")
+          if (v - 1 != previousValue) allSequential = false
+        }
+        previousValue = v
+      }
+      if (!allSequential) {
+        throw new Exception("Partition should contain all sequential values")
+      }
+    }
+
+    partitionNum = 5
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> partitionNum.toString) {
+      val df = original.repartition(partitionNum, $"key")
+      assert(df.rdd.partitions.length == partitionNum)
+      checkAnswer(original.select(), df.select())
+
+      val df4 = data.repartition(partitionNum, 
$"a").sortWithinPartitions($"b".desc)
+      df4.rdd.foreachPartition { p =>
+        if (p.hasNext) {
+          var previousValue: Int = -1
+          var allSequential: Boolean = true
+          p.foreach { r =>
+            val v: Int = r.getInt(1)
+            if (previousValue != -1) {
+              if (previousValue < v) throw new Exception("Partition is not 
ordered.")
+              if (v + 1 != previousValue) allSequential = false
+            }
+            previousValue = v
+          }
+          if (allSequential) throw new Exception("Partition should not be 
globally ordered")
+        }
+      }
+    }
+
+    partitionNum = 10
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> partitionNum.toString) {
+      val df2 = original.repartition(partitionNum, $"key")
+      assert(df2.rdd.partitions.length == partitionNum)
+      checkAnswer(original.select(), df2.select())
+    }
+
+    val df3 = testData.repartition($"key").groupBy("key").count()
+    verifyNonExchangingAgg(df3)
+    verifyNonExchangingAgg(
+      testData
+        .repartition($"key", $"value")
+        .groupBy("key", "value")
+        .count())
+
+    verifyExchangingAgg(
+      testData
+        .repartition($"key", $"value")
+        .groupBy("key")
+        .count())
+
+    partitionNum = 2
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> partitionNum.toString) {
+      val df5 = data.repartition(partitionNum, 
$"a").sortWithinPartitions($"b".asc, $"a".asc)
+      df5.rdd.foreachPartition { p =>
+        var previousValue: Int = -1
+        var allSequential: Boolean = true
+        p.foreach { r =>
+          val v: Int = r.getInt(1)
+          if (previousValue != -1) {
+            if (previousValue > v) throw new Exception("Partition is not 
ordered.")
+            if (v - 1 != previousValue) allSequential = false
+          }
+          previousValue = v
+        }
+        if (allSequential) throw new Exception("Partition should not be all 
sequential")
+      }
+    }
+  }
+
+  testAuron("reuse exchange") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
+      val df = spark.range(100).toDF()
+      val join = df.join(df, "id")
+      checkAnswer(join, df)
+      val shuffleCount = collect(join.queryExecution.executedPlan) {
+        case e: NativeShuffleExchangeExec =>
+          true
+      }.size
+      assert(shuffleCount === 1, s"Expected 1 shuffle exchange, got 
$shuffleCount")
+      assert(collect(join.queryExecution.executedPlan) { case e: 
ReusedExchangeExec =>
+        true
+      }.size === 1)
+      val broadcasted = broadcast(join)
+      val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
+      checkAnswer(join2, df)
+      val shuffleCount2 = collect(join2.queryExecution.executedPlan) {
+        case e: NativeShuffleExchangeExec =>
+          true
+      }.size
+      assert(shuffleCount2 == 1, s"Expected 1 shuffle exchange in join2, got 
$shuffleCount2")
+      assert(collect(join2.queryExecution.executedPlan) { case e: 
ReusedExchangeExec =>
+        true
+      }.size == 4)
+    }
+  }
+
+  testAuron("SPARK-22520: support code generation for large CaseWhen") {
+    import 
org.apache.spark.sql.catalyst.dsl.expressions.StringToAttributeConversionHelper
+    val N = 30
+    var expr1 = when(equalizer($"id", lit(0)), 0)
+    var expr2 = when(equalizer($"id", lit(0)), 10)
+    (1 to N).foreach { i =>
+      expr1 = expr1.when(equalizer($"id", lit(i)), -i)
+      expr2 = expr2.when(equalizer($"id", lit(i + 10)), i)
+    }
+    val df = spark.range(1).select(expr1, expr2.otherwise(0))
+    checkAnswer(df, Row(0, 10) :: Nil)
+  }
+
+  testAuron("SPARK-27439: Explain result should match collected result after 
view change") {
+    withTempView("test", "test2", "tmp") {
+      spark.range(10).createOrReplaceTempView("test")
+      spark.range(5).createOrReplaceTempView("test2")
+      spark.sql("select * from test").createOrReplaceTempView("tmp")
+      val df = spark.sql("select * from tmp")
+      spark.sql("select * from test2").createOrReplaceTempView("tmp")
+
+      val captured = new ByteArrayOutputStream()
+      Console.withOut(captured) {
+        df.explain(extended = true)
+      }
+      checkAnswer(df, spark.range(10).toDF)
+      val output = captured.toString
+      assert(output.contains("""== Parsed Logical Plan ==
+                               |'Project [*]
+                               |+- 'UnresolvedRelation [tmp]""".stripMargin))
+    }
+  }
+
+  private def withExpr(newExpr: Expression): Column = new Column(newExpr)
+
+  def equalizer(expr: Expression, other: Any): Column = withExpr {
+    val right = lit(other).expr
+    if (expr == right) {
+      logWarning(
+        s"Constructing trivially true equals predicate, '$expr = $right'. " +
+          "Perhaps you need to use aliases.")
+    }
+    EqualTo(expr, right)
+  }
+
+  private def verifyNonExchangingAgg(df: DataFrame): Unit = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case _: NativeAggExec =>
+        atFirstAgg = !atFirstAgg
+      case _ =>
+        if (atFirstAgg) {
+          fail("Should not have operators between the two aggregations")
+        }
+    }
+  }
+
+  private def verifyExchangingAgg(df: DataFrame): Unit = {
+    var atFirstAgg: Boolean = false
+    df.queryExecution.executedPlan.foreach {
+      case _: NativeAggExec =>
+        if (atFirstAgg) {
+          fail("Should not have back to back Aggregates")
+        }
+        atFirstAgg = true
+      case _: NativeShuffleExchangeExec =>
+        atFirstAgg = false
+      case _ =>
+    }
+  }
+}

Reply via email to