fwc commented on code in PR #56190: URL: https://github.com/apache/spark/pull/56190#discussion_r3453819305
########## sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils} + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. + * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df") + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a None will + * be returned. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + private def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row]): Option[String] = { + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${SparkErrorUtils.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + private def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + // SPARK-51349: "null" and null had the same precedence in sorting + case "null" => "__null_string__" + case o => o + }) + } + + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") Review Comment: This code is copy-pasted from `QueryTest`, I'd like to keep it as-is for now. ########## sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils} + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. + * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df") + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not Review Comment: This code is copy-pasted from `QueryTest`, I'd like to keep it as-is for now. ########## core/src/test/scala/org/apache/spark/CheckErrorHelper.scala: ########## @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import org.scalatest.Suite + +trait CheckErrorHelper { self: Suite => + + case class ExpectedContext( + contextType: QueryContextType, + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String, + callSitePattern: String + ) + + object ExpectedContext { + def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { + ExpectedContext("", "", start, stop, fragment) + } + + // Check the fragment only. This is only used when the fragment is distinguished within + // the query text + def apply(fragment: String): ExpectedContext = { + ExpectedContext("", "", -1, -1, fragment) + } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } + } + + /** + * Parameter keys that are omitted from comparison when absent from the expected map. + * For each error condition, the set lists keys that are removed from the actual + * exception parameters before comparison with the expected map. + * Test suites may override this to add or change ignorable parameters per condition. + */ + protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map( + "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath") + ) + + /** + * Checks an exception with an error condition against expected results. + * @param exception The exception to check + * @param condition The expected error condition identifying the error + * @param sqlState Optional the expected SQLSTATE, not verified if not supplied + * @param parameters A map of parameter names and values. The names are as defined + * in the error-classes file. + * @param matchPVals Optionally treat the parameters value as regular expression pattern. + * false if not supplied. + */ + protected def checkError( + exception: SparkThrowable, + condition: String, + sqlState: Option[String] = None, + parameters: Map[String, String] = Map.empty, + matchPVals: Boolean = false, + queryContext: Array[ExpectedContext] = Array.empty): Unit = { Review Comment: This code is copy-pasted from `SparkTestSuite`, I'd like to keep it as-is for now. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
