fwc commented on code in PR #56190:
URL: https://github.com/apache/spark/pull/56190#discussion_r3453822699


##########
sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala:
##########
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.TimeZone
+
+import org.scalatest.Assertions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils}
+
+/**
+ * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests.
+ *
+ * TODO: should be moved to sql/api together with SessionQueryTestBase
+ */
+@Experimental
+trait CheckAnswerHelper extends Assertions {
+
+  /**
+   * Executes `df` and calls `fail` with a descriptive message when the
+   * collected rows do not match `expectedAnswer`.
+   *
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
+   */
+  protected def checkAnswer(
+      df: => DataFrame,
+      expectedAnswer: Seq[Row]): Unit = {
+    // A defined message means execution failed or the rows did not match.
+    getErrorMessageInCheckAnswer(df, expectedAnswer).foreach { message =>
+      fail(message)
+    }
+  }
+
+  /*
+   * Note: when moving this to sql/api, implementation should stay in sql/core
+   * (i.e. only have abstract decl in sql/api)
+   *
+   * Reports whether the logical plan of `df` contains a Sort node. Only
+   * classic DataFrames are handled; any other DataFrame is rejected with a
+   * RuntimeException.
+   */
+  protected def isDfSorted(df: DataFrame): Boolean = df match {
+    case classicDf: classic.DataFrame =>
+      val firstSort = classicDf.logicalPlan.collectFirst {
+        case sort: logical.Sort => sort
+      }
+      firstSort.isDefined
+    case _ =>
+      throw new RuntimeException(
+        s"Cannot determine whether df is sorted: $df")
+  }
+
+  /**
+   * Collects `df` and compares the rows against `expectedAnswer`.
+   *
+   * Returns `Some(message)` when collecting throws an exception or when
+   * the collected rows differ from the expected ones, and `None` otherwise.
+   *
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
+   */
+  private def getErrorMessageInCheckAnswer(
+      df: DataFrame,
+      expectedAnswer: Seq[Row]): Option[String] = {
+    // Left carries the failure report, Right the collected rows.
+    val collected: Either[String, Seq[Row]] =
+      try {
+        Right(df.collect().toSeq)
+      } catch {
+        case cause: Exception =>
+          val report =
+            s"""
+               |Exception thrown while executing query:
+               |${df.queryExecution}
+               |== Exception ==
+               |$cause
+               |${SparkErrorUtils.stackTraceToString(cause)}
+          """.stripMargin
+          Left(report)
+      }
+
+    // Only invoked when the rows differ, so the report is built on demand.
+    def mismatchReport(diff: String): String =
+      s"""
+         |Results do not match for query:
+         |Timezone: ${TimeZone.getDefault}
+         |Timezone Env: ${sys.env.getOrElse("TZ", "")}
+         |
+         |${df.queryExecution}
+         |== Results ==
+         |$diff
+       """.stripMargin
+
+    collected match {
+      case Left(report) => Some(report)
+      case Right(sparkAnswer) =>
+        sameRows(expectedAnswer, sparkAnswer, isDfSorted(df))
+          .map(mismatchReport)
+    }
+  }
+
+  /**
+   * Normalizes every row via `prepareRow` and, unless `isSorted` is set,
+   * orders the rows by their string form.
+   */
+  private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
+    // Convert values to types whose equality works with Scala collections:
+    //  - BigDecimal: the Scala type has a better definition of equality
+    //    (similar to Java's java.math.BigDecimal.compareTo).
+    //  - binary arrays: turned into Seq, so java.util.Arrays.equals is not
+    //    needed for the equality test.
+    val normalized: Seq[Row] = answer.map(row => prepareRow(row))
+    if (isSorted) {
+      normalized
+    } else {
+      normalized.sortBy(row => row.toString())
+    }
+  }
+
+  /**
+   * Rewrites the values of `row` into forms that are easy to compare.
+   * Nested rows are prepared recursively to handle schemas with struct
+   * types.
+   */
+  private def prepareRow(row: Row): Row = {
+    // Equality of WrappedArray differs for AnyVal and AnyRef in Scala
+    // 2.12.2+, so boxed numeric elements are unboxed one by one.
+    def unbox(element: Any): Any = element match {
+      case b: java.lang.Byte => b.byteValue
+      case s: java.lang.Short => s.shortValue
+      case i: java.lang.Integer => i.intValue
+      case l: java.lang.Long => l.longValue
+      case f: java.lang.Float => f.floatValue
+      case d: java.lang.Double => d.doubleValue
+      case other => other
+    }
+
+    def normalize(value: Any): Any = value match {
+      case null => null
+      case bd: java.math.BigDecimal => BigDecimal(bd)
+      case seq: Seq[_] => seq.map(unbox)
+      // Convert array to Seq for easy equality check.
+      case array: Array[_] => array.toSeq
+      case nested: Row => prepareRow(nested)
+      // SPARK-51349: "null" and null had the same precedence in sorting
+      case "null" => "__null_string__"
+      case other => other
+    }
+
+    Row.fromSeq(row.toSeq.map(normalize))
+  }
+
+  /**
+   * Renders the expected and the actual answer next to each other.
+   *
+   * Each column consists of a header with the row count, the schema of the
+   * first row, and the prepared rows in their string form.
+   *
+   * @param expectedAnswer the expected rows
+   * @param sparkAnswer the rows that were actually produced
+   * @param isSorted forwarded to `prepareAnswer`; defaults to false
+   */
+  private def genError(
+      expectedAnswer: Seq[Row],
+      sparkAnswer: Seq[Row],
+      isSorted: Boolean = false): String = {
+    // Schema of the first row, or "struct<>" without a row or a schema.
+    def rowType(first: Option[Row]): String = first match {
+      case Some(r) if r.schema != null => r.schema.catalogString
+      case _ => "struct<>"
+    }
+
+    // Schema line followed by the prepared rows of one answer.
+    def describe(rows: Seq[Row]): Seq[String] =
+      rowType(rows.headOption) +:
+        prepareAnswer(rows, isSorted).map(_.toString())
+
+    val expectedColumn =
+      s"== Correct Answer - ${expectedAnswer.size} ==" +:
+        describe(expectedAnswer)
+    val actualColumn =
+      s"== Spark Answer - ${sparkAnswer.size} ==" +:
+        describe(sparkAnswer)
+    val comparison = SparkStringUtils
+      .sideBySide(expectedColumn, actualColumn)
+      .mkString("\n")
+
+    s"""
+       |== Results ==
+       |$comparison
+    """.stripMargin
+  }
+
+  private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+    case (null, null) => true
+    case (null, _) => false
+    case (_, null) => false
+    case (a: Array[_], b: Array[_]) =>

Review Comment:
   This code is copy-pasted from `QueryTest`, :shrug:



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to