dejankrak-db commented on code in PR #48950: URL: https://github.com/apache/spark/pull/48950#discussion_r1857523544
########## sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala: ########## @@ -0,0 +1,1057 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.CompoundBody +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * SQL Scripting interpreter tests. + * Output from the parser is provided to the interpreter. + * Output from the interpreter (iterator over executable statements) is then checked - statements + * are executed and output DataFrames are compared with expected outputs. 
+ */ +class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { + + // Tests setup + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true") + } + + // Helpers + private def runSqlScript( + sqlText: String, + args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = { + val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] + val sse = new SqlScriptingExecution(compoundBody, spark, args) + sse.map { df => df.collect() }.toList + } + + private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { + val result = runSqlScript(sqlText) + assert(result.length == expected.length) + result.zip(expected).foreach { + case (actualAnswer, expectedAnswer) => + assert(actualAnswer.sameElements(expectedAnswer)) + } + } + + // Tests + test("multi statement - simple") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT a, b FROM t WHERE a = 12; + |SELECT a FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // select + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("multi statement - count") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT + | CASE WHEN COUNT(*) > 10 THEN true + | ELSE false + | END AS MoreThanTen + |FROM t; + |END + |""".stripMargin + val expected = Seq(Seq(Row(false))) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("session vars - set and read (SET VAR)") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set 
and read (SET)") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set and read scoped") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE var = 1; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 2; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 3; + | SET VAR var = var + 1; + | SELECT var; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(1)), // select + Seq(Row(2)), // select + Seq(Row(4)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - drop var statement") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |DROP TEMPORARY VARIABLE var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("if") { + val commands = + """ + |BEGIN + | IF 1=1 THEN + | SELECT 42; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if nested") { + val commands = + """ + |BEGIN + | IF 1=1 THEN + | IF 2=1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END IF; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if else going in if") { + val commands = + """ + |BEGIN + | IF 1=1 + | THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if else if going in else if") { + val commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE IF 1=1 + | THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("if else going in else") { + val 
commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("if else if going in else") { + val commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE IF 1=3 + | THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + + test("if with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |IF (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("if else if with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | IF (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case nested") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1=1 THEN + | CASE + | WHEN 2=1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case second case") { + val commands = + 
""" + |BEGIN + | CASE + | WHEN 1 = (SELECT 2) THEN + | SELECT 1; + | WHEN 2 = 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case going in else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 2 = 1 THEN + | SELECT 1; + | WHEN 3 IN (1,2) THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("searched case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | WHEN (SELECT COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 2 THEN + | SELECT 42; + | WHEN 1 = 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("simple case") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + 
verifySqlScriptResult(commands, expected) + } + + test("simple case nested") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | CASE 2 + | WHEN (SELECT 3) THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case second case") { + val commands = + """ + |BEGIN + | CASE (SELECT 2) + | WHEN 1 THEN + | SELECT 1; + | WHEN 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case going in else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 1; + | WHEN 3 THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("simple case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 41; + | WHEN 2 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (2, 'b', 2.0); + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case no cases matched no else") { + val commands = + """ + 
|BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("simple case compare with null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT) USING parquet; + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("while") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | WHILE i < 3 DO + | SELECT i; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select i + Seq(Row(1)), // select i + Seq(Row(2)) // select i + ) + verifySqlScriptResult(commands, expected) + } + + test("while: not entering body") { + val commands = + """ + |BEGIN + | DECLARE i = 3; + | WHILE i < 3 DO + | SELECT i; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("nested while") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | DECLARE j = 0; + | WHILE i < 2 DO + | SET VAR j = 0; + | WHILE j < 2 DO + | SELECT i, j; + | SET VAR j = j + 1; + | END WHILE; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0, 0)), // select i, j + Seq(Row(0, 1)), // select i, j + Seq(Row(1, 0)), // select i, j + Seq(Row(1, 1)) // select i, j + ) + verifySqlScriptResult(commands, expected) + } + + test("while with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |WHILE (SELECT COUNT(*) < 2 FROM t) DO + | SELECT 42; + | INSERT INTO t VALUES (1, 'a', 1.0); + |END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(42)), // select + Seq(Row(42)) // select + ) + verifySqlScriptResult(commands, 
expected) + } + } + + test("repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | i = 3 + | END REPEAT; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select i + Seq(Row(1)), // select i + Seq(Row(2)) // select i + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat: enters body only once") { + val commands = + """ + |BEGIN + | DECLARE i = 3; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq(Seq(Row(3))) + verifySqlScriptResult(commands, expected) + } + + test("nested repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | DECLARE j = 0; + | REPEAT + | SET VAR j = 0; + | REPEAT + | SELECT i, j; + | SET VAR j = j + 1; + | UNTIL j >= 2 + | END REPEAT; + | SET VAR i = i + 1; + | UNTIL i >= 2 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(0, 0)), // select i, j + Seq(Row(0, 1)), // select i, j + Seq(Row(1, 0)), // select i, j + Seq(Row(1, 1)) // select i, j + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |REPEAT + | SELECT 42; + | INSERT INTO t VALUES (1, 'a', 1.0); + |UNTIL (SELECT COUNT(*) >= 2 FROM t) + |END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(42)), // select + Seq(Row(42)) // select + ) + verifySqlScriptResult(commands, expected) + } + } + + test("leave compound block") { + val sqlScriptText = + """ + |BEGIN + | lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + | SELECT 2; + | END; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val expected = 
Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | lbl2: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + Seq(Row(2)) // 
select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested compounds in loop - leave in inner compound") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | BEGIN + | SELECT 1; + | lbl2: BEGIN + | SELECT 2; + | LEAVE lbl2; + | SELECT 3; + | END; + | END; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(2)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | lbl2: REPEAT + | SELECT 1; + | ITERATE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + Seq(Row(2)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("loop statement with leave") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | SELECT x; + | IF x > 2 + | THEN + | LEAVE lbl; + | END IF; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select x + Seq(Row(2)), // select x + Seq(Row(3)), // select x + Seq(Row(3)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested loop statement with leave") { + val commands = + """ + |BEGIN + | DECLARE x = 0; + | DECLARE y = 0; + | lbl1: LOOP + | SET VAR y = 0; + | lbl2: LOOP + | SELECT x, y; + | SET VAR y = y + 1; + | IF y >= 2 THEN + | LEAVE lbl2; + | END IF; + | END LOOP; + | SET VAR x = x + 1; + | IF x >= 2 THEN + | LEAVE lbl1; + | END IF; + | END LOOP; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(0, 0)), // select x, y + Seq(Row(0, 1)), // select x, y + 
Seq(Row(1, 0)), // select x, y + Seq(Row(1, 1)) // select x, y + ) + verifySqlScriptResult(commands, expected) + } + + test("iterate loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 1 THEN + | LEAVE lbl; + | END IF; + | ITERATE lbl; + | SET x = x + 2; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) Review Comment: Perhaps add a comment explaining that, since we immediately leave the outer loop in the first iteration, we expect only a single row in the result set. ########## sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala: ########## @@ -0,0 +1,1057 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.CompoundBody +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * SQL Scripting interpreter tests. + * Output from the parser is provided to the interpreter. + * Output from the interpreter (iterator over executable statements) is then checked - statements + * are executed and output DataFrames are compared with expected outputs. + */ +class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { + + // Tests setup + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true") + } + + // Helpers + private def runSqlScript( + sqlText: String, + args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = { + val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] + val sse = new SqlScriptingExecution(compoundBody, spark, args) + sse.map { df => df.collect() }.toList + } + + private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { + val result = runSqlScript(sqlText) + assert(result.length == expected.length) + result.zip(expected).foreach { + case (actualAnswer, expectedAnswer) => + assert(actualAnswer.sameElements(expectedAnswer)) + } + } + + // Tests + test("multi statement - simple") { + withTable("t") { + val sqlScript = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT a, b FROM t WHERE a = 12; + |SELECT a FROM t; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // select + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("multi statement - count") { + withTable("t") { + val sqlScript = + """ + |BEGIN + 
|CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |SELECT + | CASE WHEN COUNT(*) > 10 THEN true + | ELSE false + | END AS MoreThanTen + |FROM t; + |END + |""".stripMargin + val expected = Seq(Seq(Row(false))) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("session vars - set and read (SET VAR)") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set and read (SET)") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET var = var + 1; + |SELECT var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - set and read scoped") { + val sqlScript = + """ + |BEGIN + | BEGIN + | DECLARE var = 1; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 2; + | SELECT var; + | END; + | BEGIN + | DECLARE var = 3; + | SET VAR var = var + 1; + | SELECT var; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(1)), // select + Seq(Row(2)), // select + Seq(Row(4)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("session vars - drop var statement") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SET VAR var = var + 1; + |SELECT var; + |DROP TEMPORARY VARIABLE var; + |END + |""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScript, expected) + } + + test("if") { + val commands = + """ + |BEGIN + | IF 1=1 THEN + | SELECT 42; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if nested") { + val commands = + """ + |BEGIN + | IF 1=1 THEN + | IF 2=1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END IF; + | END IF; + |END + |""".stripMargin + val expected = 
Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if else going in if") { + val commands = + """ + |BEGIN + | IF 1=1 + | THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("if else if going in else if") { + val commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE IF 1=1 + | THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("if else going in else") { + val commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("if else if going in else") { + val commands = + """ + |BEGIN + | IF 1=2 + | THEN + | SELECT 42; + | ELSE IF 1=3 + | THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + + test("if with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |IF (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END IF; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("if else if with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | IF (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE IF (SELECT COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END IF; + |END + |""".stripMargin + val expected = 
Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case nested") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1=1 THEN + | CASE + | WHEN 2=1 THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case second case") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = (SELECT 2) THEN + | SELECT 1; + | WHEN 2 = 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("searched case going in else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 2 = 1 THEN + | SELECT 1; + | WHEN 3 IN (1,2) THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("searched case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (1, 'a', 1.0); + | CASE + | WHEN (SELECT COUNT(*) > 2 FROM t) THEN + | SELECT 42; + | WHEN (SELECT 
COUNT(*) > 1 FROM t) THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("searched case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE + | WHEN 1 = 2 THEN + | SELECT 42; + | WHEN 1 = 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("simple case") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | SELECT 42; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case nested") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 1 THEN + | CASE 2 + | WHEN (SELECT 3) THEN + | SELECT 41; + | ELSE + | SELECT 42; + | END CASE; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case second case") { + val commands = + """ + |BEGIN + | CASE (SELECT 2) + | WHEN 1 THEN + | SELECT 1; + | WHEN 2 THEN + | SELECT 42; + | WHEN (SELECT * FROM t) THEN + | SELECT * FROM b; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + + test("simple case going in else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 1; + | WHEN 3 THEN + | SELECT 2; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + + test("simple case with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |INSERT INTO t VALUES (1, 'a', 1.0); + |INSERT INTO t VALUES (1, 'a', 1.0); + |CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 41; + | WHEN 2 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + 
|""".stripMargin + val expected = Seq(Seq(Row(42))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case else with count") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (2, 'b', 2.0); + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | ELSE + | SELECT 44; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(44))) + verifySqlScriptResult(commands, expected) + } + } + + test("simple case no cases matched no else") { + val commands = + """ + |BEGIN + | CASE 1 + | WHEN 2 THEN + | SELECT 42; + | WHEN 3 THEN + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("simple case compare with null") { + withTable("t") { + val commands = + """ + |BEGIN + | CREATE TABLE t (a INT) USING parquet; + | CASE (SELECT COUNT(*) FROM t) + | WHEN 1 THEN + | SELECT 42; + | ELSE + | SELECT 43; + | END CASE; + |END + |""".stripMargin + val expected = Seq(Seq(Row(43))) + verifySqlScriptResult(commands, expected) + } + } + + test("while") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | WHILE i < 3 DO + | SELECT i; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select i + Seq(Row(1)), // select i + Seq(Row(2)) // select i + ) + verifySqlScriptResult(commands, expected) + } + + test("while: not entering body") { + val commands = + """ + |BEGIN + | DECLARE i = 3; + | WHILE i < 3 DO + | SELECT i; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq.empty + verifySqlScriptResult(commands, expected) + } + + test("nested while") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | DECLARE j = 0; + | WHILE i < 2 DO + | SET VAR j = 0; + | WHILE j < 2 DO + | SELECT i, j; + | SET VAR j = j + 1; + | END 
WHILE; + | SET VAR i = i + 1; + | END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0, 0)), // select i, j + Seq(Row(0, 1)), // select i, j + Seq(Row(1, 0)), // select i, j + Seq(Row(1, 1)) // select i, j + ) + verifySqlScriptResult(commands, expected) + } + + test("while with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |WHILE (SELECT COUNT(*) < 2 FROM t) DO + | SELECT 42; + | INSERT INTO t VALUES (1, 'a', 1.0); + |END WHILE; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(42)), // select + Seq(Row(42)) // select + ) + verifySqlScriptResult(commands, expected) + } + } + + test("repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | i = 3 + | END REPEAT; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select i + Seq(Row(1)), // select i + Seq(Row(2)) // select i + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat: enters body only once") { + val commands = + """ + |BEGIN + | DECLARE i = 3; + | REPEAT + | SELECT i; + | SET VAR i = i + 1; + | UNTIL + | 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq(Seq(Row(3))) + verifySqlScriptResult(commands, expected) + } + + test("nested repeat") { + val commands = + """ + |BEGIN + | DECLARE i = 0; + | DECLARE j = 0; + | REPEAT + | SET VAR j = 0; + | REPEAT + | SELECT i, j; + | SET VAR j = j + 1; + | UNTIL j >= 2 + | END REPEAT; + | SET VAR i = i + 1; + | UNTIL i >= 2 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(0, 0)), // select i, j + Seq(Row(0, 1)), // select i, j + Seq(Row(1, 0)), // select i, j + Seq(Row(1, 1)) // select i, j + ) + verifySqlScriptResult(commands, expected) + } + + test("repeat with count") { + withTable("t") { + val commands = + """ + |BEGIN + |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + |REPEAT + | SELECT 42; + | INSERT INTO t 
VALUES (1, 'a', 1.0); + |UNTIL (SELECT COUNT(*) >= 2 FROM t) + |END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(42)), // select + Seq(Row(42)) // select + ) + verifySqlScriptResult(commands, expected) + } + } + + test("leave compound block") { + val sqlScriptText = + """ + |BEGIN + | lbl: BEGIN + | SELECT 1; + | LEAVE lbl; + | SELECT 2; + | END; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | ITERATE lbl; + | SET x = x + 2; + | UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: REPEAT + | lbl2: REPEAT + | SELECT 1; + | LEAVE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | UNTIL 1 = 2 + | END REPEAT; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested 
while loop") { + val sqlScriptText = + """ + |BEGIN + | lbl: WHILE 1 = 1 DO + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | LEAVE lbl; + | END WHILE; + | END WHILE; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested while loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | lbl2: WHILE 2 = 2 DO + | SELECT 1; + | ITERATE lbl; + | END WHILE; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + Seq(Row(2)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested compounds in loop - leave in inner compound") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: WHILE x < 2 DO + | SET x = x + 1; + | BEGIN + | SELECT 1; + | lbl2: BEGIN + | SELECT 2; + | LEAVE lbl2; + | SELECT 3; + | END; + | END; + | END WHILE; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(1)), // select 1 + Seq(Row(2)), // select 2 + Seq(Row(2)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested repeat loop") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: REPEAT + | SET x = x + 1; + | lbl2: REPEAT + | SELECT 1; + | ITERATE lbl; + | UNTIL 1 = 2 + | END REPEAT; + | UNTIL x > 1 + | END REPEAT; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + Seq(Row(2)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("loop statement with leave") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | SELECT x; + | IF x > 2 + | THEN + | LEAVE lbl; + | END IF; + | END LOOP; + | SELECT x; + |END""".stripMargin + val 
expected = Seq( + Seq(Row(1)), // select x + Seq(Row(2)), // select x + Seq(Row(3)), // select x + Seq(Row(3)) // select x + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested loop statement with leave") { + val commands = + """ + |BEGIN + | DECLARE x = 0; + | DECLARE y = 0; + | lbl1: LOOP + | SET VAR y = 0; + | lbl2: LOOP + | SELECT x, y; + | SET VAR y = y + 1; + | IF y >= 2 THEN + | LEAVE lbl2; + | END IF; + | END LOOP; + | SET VAR x = x + 1; + | IF x >= 2 THEN + | LEAVE lbl1; + | END IF; + | END LOOP; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(0, 0)), // select x, y + Seq(Row(0, 1)), // select x, y + Seq(Row(1, 0)), // select x, y + Seq(Row(1, 1)) // select x, y + ) + verifySqlScriptResult(commands, expected) + } + + test("iterate loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 1 THEN + | LEAVE lbl; + | END IF; + | ITERATE lbl; + | SET x = x + 2; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq(Seq(Row(2))) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val expected = Seq(Seq(Row(1))) Review Comment: Perhaps add a comment explaining that since we immediately leave in the first iteration, we expect only a single row in the result set. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
