ionagamed commented on code in PR #56190: URL: https://github.com/apache/spark/pull/56190#discussion_r3441613379
########## sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala: ########## @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +/** + * Override of [[sql.SessionQueryTest]] that provides [[SparkSession classic.SparkSession]]. + * + * Can be used to declare classic-specific tests: + * {{{ + * class FooSuite extends sql.SessionQueryTest { + * // shared classic/connect-agnostic testcases + * } + * + * // no need to extend FooSuite as sql.SessionQueryTest + * // already executes shared tests via classic internally. + * class FooClassicSuite extends classic.SessionQueryTest { Review Comment: QQ: What should a test author do when they want to reuse some common utilities from `FooSuite` in their `FooClassicSuite`? `FooBaseSuite` should work fine, but I think it might be worth mentioning the recommended way to handle these situations. ########## core/src/test/scala/org/apache/spark/CheckErrorHelper.scala: ########## @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +import org.scalatest.Suite + +trait CheckErrorHelper { self: Suite => + + case class ExpectedContext( + contextType: QueryContextType, + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String, + callSitePattern: String + ) + + object ExpectedContext { + def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { + ExpectedContext("", "", start, stop, fragment) + } + + // Check the fragment only. This is only used when the fragment is distinguished within + // the query text + def apply(fragment: String): ExpectedContext = { + ExpectedContext("", "", -1, -1, fragment) + } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } + } + + /** + * Parameter keys that are omitted from comparison when absent from the expected map. 
+ * For each error condition, the set lists keys that are removed from the actual + * exception parameters before comparison with the expected map. + * Test suites may override this to add or change ignorable parameters per condition. + */ + protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map( + "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath") + ) + + /** + * Checks an exception with an error condition against expected results. + * @param exception The exception to check + * @param condition The expected error condition identifying the error + * @param sqlState Optional the expected SQLSTATE, not verified if not supplied + * @param parameters A map of parameter names and values. The names are as defined + * in the error-classes file. + * @param matchPVals Optionally treat the parameters value as regular expression pattern. + * false if not supplied. + */ + protected def checkError( + exception: SparkThrowable, + condition: String, + sqlState: Option[String] = None, + parameters: Map[String, String] = Map.empty, + matchPVals: Boolean = false, + queryContext: Array[ExpectedContext] = Array.empty): Unit = { Review Comment: nit: `queryContext` seems missing from the doc above. ########## sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTestBase.scala: ########## @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +// scalastyle:off funsuite +import org.scalatest.funsuite.AnyFunSuite +// scalastyle:on + +/** + * TODO should be moved to sql/api + * + * base for fully sql/core independent tests, i.e. this trait could be moved to sql/api and then + * used in sql/connect/client. + */ +trait SessionQueryTestBase + extends AnyFunSuite + with SparkSessionProvider + with CheckAnswerHelper + with QueryCleanupHelper { + + /** + * Documents used session so that tests can handle and document session-specific behaviour + * + * {{{ + * test(...) { + * val df = // query with connect-specific behaviour + * if (sessionType == "connect") { + * checkError(...) + * } else { + * checkAnswer(df, ...) + * } + * } + * }}} + */ + def sessionType: String Review Comment: QQ: Why is this escape hatch here? The PR's overall approach seems to discourage implementation-specific tests, and this appears to go against that spirit. It might be worth either keeping the escape hatch and documenting that its use is discouraged, or leaving it out altogether. ########## sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala: ########## @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import java.util.UUID + +import org.apache.spark.{SparkEnv, SparkFunSuite} +import org.apache.spark.sql +import org.apache.spark.sql.classic +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService + +/** + * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server. + * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a + * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client + * session on top by starting the gRPC service in-process. + */ +trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite => + + private var _connectSpark: SparkSession = _ + + protected override def spark: SparkSession = _connectSpark + + /** The underlying classic session used by the in-process server. */ + protected def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] + + override protected def beforeAll(): Unit = { + super.beforeAll() + // Other suites using mocks leave a mess in the global executionManager, + // shut it down so that it's cleared before starting server. 
+ SparkConnectService.executionManager.shutdown() + val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT) + try { + // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port. + SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0) + SparkConnectService.start(classicSpark.sparkContext) + } finally { + SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort) + } + } + + override def beforeEach(): Unit = { + val client = SparkConnectClient Review Comment: QQ: Why are we creating a fresh `client` for each test? The current `SharedSparkSession` is using the suite-level hooks to provide a single shared session for the whole suite. ########## sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils} + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. 
+ * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df") + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not Review Comment: nit: ```suggestion * If there was an exception during the execution or the contents of the DataFrame do not ``` ########## sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils} + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. + * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df") + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a None will + * be returned. 
+ * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + private def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row]): Option[String] = { + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${SparkErrorUtils.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + private def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. 
+ case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + // SPARK-51349: "null" and null had the same precedence in sorting + case "null" => "__null_string__" + case o => o + }) + } + + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") Review Comment: nit: this might read better? ```suggestion row.flatMap(r => Option(r.schema)) .map(s => s"${s.catalogString}") .getOrElse("struct<>") ``` ########## sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils} + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. 
+ * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => throw new RuntimeException(s"Cannot determine whether df is sorted: $df") + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a None will + * be returned. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. 
+ */ + private def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row]): Option[String] = { + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${SparkErrorUtils.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + private def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. 
+ case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + // SPARK-51349: "null" and null had the same precedence in sorting + case "null" => "__null_string__" + case o => o + }) + } + + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + + s""" + |== Results == + |${ + SparkStringUtils.sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") + } + """.stripMargin + } + + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => Review Comment: QQ: Aren't `Array`s always converted to `Seq`s above? ########## sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala: ########## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite + +/** + * Provides connect-compatible test utils to write suites that have 'connect variants': + * {{{ + * // in sql/core + * FooSuite extends SessionQueryTest { test("") { ... } } Review Comment: nit: same as my comment from above, consider adding guidance on how to handle `FooConnectSuite` needing utils from `FooSuite`.
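
Sketch for the `getRowType` nit on `CheckAnswerHelper.scala` above: a minimal standalone version of the `Option`-based rewrite, for illustration only. `Schema` and `Row` here are hypothetical stand-ins so the snippet compiles without Spark on the classpath; they are not Spark's `StructType` and `Row`. The point it tries to show is that wrapping a possibly-null schema with `Option(...)` inside the lambda calls for `flatMap` rather than `map`, since a plain `map` would yield a nested `Option[Option[Schema]]`.

```scala
// Hypothetical stand-ins for Spark's StructType and Row, defined only so this
// sketch is self-contained. The real classes have far richer APIs.
final case class Schema(catalogString: String)
final case class Row(schema: Schema) // schema may be null, as for schema-less rows

object GetRowTypeSketch {
  // Option(r.schema) maps a null schema to None. flatMap flattens the result,
  // so one getOrElse covers both "no row at all" and "row without a schema".
  val getRowType: Option[Row] => String = row =>
    row.flatMap(r => Option(r.schema))
      .map(_.catalogString)
      .getOrElse("struct<>")

  def main(args: Array[String]): Unit = {
    println(getRowType(None))                                // no row
    println(getRowType(Some(Row(null))))                     // row with null schema
    println(getRowType(Some(Row(Schema("struct<id:int>"))))) // row with a schema
  }
}
```

Under these assumptions the first two calls fall back to `struct<>` and the third prints the schema's catalog string, which matches the behaviour of the original `if (row.schema == null)` branch.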
