miland-db commented on code in PR #47609:
URL: https://github.com/apache/spark/pull/47609#discussion_r1721474497
##########
sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala:
##########
@@ -29,32 +30,48 @@ import org.apache.spark.sql.test.SharedSparkSession
  * Output from the interpreter (iterator over executable statements) is then checked - statements
  * are executed and output DataFrames are compared with expected outputs.
  */
-class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
+class SqlScriptingInterpreterSuite extends SparkFunSuite with SharedSparkSession {
   // Helpers
-  private def runSqlScript(sqlText: String): Array[DataFrame] = {
-    val interpreter = SqlScriptingInterpreter()
-    val compoundBody = spark.sessionState.sqlParser.parseScript(sqlText)
-    val executionPlan = interpreter.buildExecutionPlan(compoundBody, spark)
-    executionPlan.flatMap {
-      case statement: SingleStatementExec =>
-        if (statement.isExecuted) {
-          None
-        } else {
-          Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker))
-        }
-      case _ => None
-    }.toArray
+  private def runSqlScript(sqlText: String): Seq[Array[Row]] = {
+    val interpreter = SqlScriptingInterpreter(spark)
+    val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
+    interpreter.executeInternal(compoundBody).toSeq
   }
-  private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
+  private def verifySqlScriptResult(sqlText: String, expected: Seq[Array[Row]]): Unit = {
     val result = runSqlScript(sqlText)
     assert(result.length == expected.length)
-    result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) }
+    result.zip(expected).foreach {
+      case (actualAnswer, expectedAnswer) =>
+        assert(actualAnswer.sameElements(expectedAnswer))
+    }
+  }
+
+  // Tests setup
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true")
+  }
+
+  protected override def afterAll(): Unit = {
+    spark.conf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "false")
Review Comment:
Done.
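
For anyone following along, here is a minimal sketch of how the reworked helpers would be called from a test case. The test name and the script body are illustrative assumptions, not taken from this diff; the sketch also assumes `executeInternal` yields one `Array[Row]` per statement result:

```scala
// Illustrative test against the new helpers. The test name and the SQL
// script are assumptions; only verifySqlScriptResult, Row, and the
// Seq[Array[Row]] expected shape come from the diff above.
test("single select") {
  val sqlScript =
    """
      |BEGIN
      |  SELECT 1;
      |END
      |""".stripMargin
  // One Array[Row] per statement result, compared element-wise.
  verifySqlScriptResult(sqlScript, expected = Seq(Array(Row(1))))
}
```

One note on the comparison itself: Scala `Array` equality is reference equality, so `actualAnswer == expectedAnswer` would not work here; `sameElements` does the element-wise check, which is why the helper moves off `checkAnswer` now that the results are `Array[Row]` rather than DataFrames.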