This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 02d1f09df0d [SPARK-44824][CONNECT][TESTS][3.5] Reset `ammoniteOut` in the `afterEach` method of `ReplE2ESuite`
02d1f09df0d is described below

commit 02d1f09df0da202e3996cdcfbca44525862528b9
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Wed Aug 16 21:03:01 2023 +0200

    [SPARK-44824][CONNECT][TESTS][3.5] Reset `ammoniteOut` in the `afterEach` method of `ReplE2ESuite`

### What changes were proposed in this pull request?
This PR adds `ammoniteOut.reset()` to the `afterEach` method of `ReplE2ESuite` so that the `output` used for assertions in each test case contains only that case's content, not the accumulated output of every case run so far.

### Why are the changes needed?
Currently, `ammoniteOut` records the output of all executed tests with no isolation between cases, which can lead to unexpected assertion results. For example, adding `assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)` to the following test case would still pass, because that string was already printed to `ammoniteOut` by a previous test case (a minimal sketch of this failure mode appears after the test output below).

https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L283-L290

Hence, we need to clear the content of `ammoniteOut` after each test to achieve isolation between test cases.

### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.

### How was this patch tested?
- Pass GitHub Actions
- Manual check

Print the `output` after `val output = runCommandsInShell(input)` in the case `streaming works with REPL generated code`:

https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L313-L318

then run

```
build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.application.ReplE2ESuite" -Phive
```

**Before**: the output contains content from all test cases executed so far in `ReplE2ESuite`:

```
Spark session available as 'spark'.
   _____                  __      ______                            __
  / ___/____  ____ ______/ /__   / ____/___  ____  ____  ___  _____/ /_
  \__ \/ __ \/ __ `/ ___/ //_/  / /   / __ \/ __ \/ __ \/ _ \/ ___/ __/
 ___/ / /_/ / /_/ / /  / ,<    / /___/ /_/ / / / / / / /  __/ /__/ /_
/____/ .___/\__,_/_/  /_/|_|   \____/\____/_/ /_/_/ /_/\___/\___/\__/
    /_/

spark.sql("select 1").collect()
res0: Array[org.apache.spark.sql.Row] = Array([1])
semaphore.release()
class A(x: Int) { def get = x * 5 + 19 }
defined class A
def dummyUdf(x: Int): Int = new A(x).get
defined function dummyUdf
val myUdf = udf(dummyUdf _)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res5: Array[Int] = Array(19, 24, 29, 34, 39)
semaphore.release()
class A(x: Int) { def get = x * 42 + 5 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res9: Array[Int] = Array(5, 47, 89, 131, 173)
semaphore.release()
class A(x: Int) { def get = x * 7 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
modifiedUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
res14: Array[Int] = Array(0, 7, 14, 21, 28)
semaphore.release()
spark.range(10).filter(n => n % 2 == 0).collect()
res16: Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)
semaphore.release()
import java.nio.file.Paths
import java.nio.file.Paths
def classLoadingTest(x: Int): Int = {
  val classloader = Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
  val cls = Class.forName("com.example.Hello$", true, classloader)
  val module = cls.getField("MODULE$").get(null)
  cls.getMethod("test").invoke(module).asInstanceOf[Int]
}
defined function classLoadingTest
val classLoaderUdf = udf(classLoadingTest _)
classLoaderUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
val jarPath = Paths.get("/Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar").toUri
jarPath: java.net.URI = file:///Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
spark.addArtifact(jarPath)
spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
res23: Array[Int] = Array(2, 2, 2, 2, 2)
semaphore.release()
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
val javaUdf = udf(new UDF1[Long, Long] {
  override def call(num: Long): Long = num * num + 25L
}, LongType).asNondeterministic()
javaUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.range(5).select(javaUdf(col("id"))).as[Long].collect()
res28: Array[Long] = Array(25L, 26L, 29L, 34L, 41L)
semaphore.release()
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
spark.udf.register("javaUdf", new UDF1[Long, Long] {
  override def call(num: Long): Long = num * num * num + 250L
}, LongType)
spark.sql("select javaUdf(id) from range(5)").as[Long].collect()
res33: Array[Long] = Array(250L, 251L, 258L, 277L, 314L)
semaphore.release()
class A(x: Int) { def get = x * 100 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.udf.register("dummyUdf", myUdf)
res37: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
res38: Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
semaphore.release()
class A(x: Int) { def get = x * 15 }
defined class A
spark.udf.register("directUdf", (x: Int) => new A(x).get)
res41: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
spark.sql("select directUdf(id) from range(5)").as[Long].collect()
res42: Array[Long] = Array(0L, 15L, 30L, 45L, 60L)
semaphore.release()
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
spark.udf.register("simpleUDF", (v: Int) => v * v)
res45: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction(
  Array(
...
df.select($"id", call_udf("simpleUDF", $"value")).collect() res46: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25]) semaphore.release() val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") df: org.apache.spark.sql.package.DataFrame = [id: string, value: int] spark.udf.register("simpleUDF", (v: Int) => v * v) res49: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction( Array( ... df.select($"id", call_function("simpleUDF", $"value")).collect() res50: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25]) semaphore.release() case class MyTestClass(value: Int) defined class MyTestClass spark.range(4). filter($"id" % 2 === 1). select($"id".cast("int").as("value")). as[MyTestClass]. collect(). map(mtc => s"MyTestClass(${mtc.value})"). mkString("[", ", ", "]") res53: String = "[MyTestClass(1), MyTestClass(3)]" semaphore.release() case class MyTestClass(value: Int) defined class MyTestClass spark.range(2).map(i => MyTestClass(i.toInt)).collect() res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1)) semaphore.release() val add1 = udf((i: Long) => i + 1) add1: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction( Array( ... val query = { spark.readStream .format("rate") .option("rowsPerSecond", "10") .option("numPartitions", "1") .load() .withColumn("value", add1($"value")) .writeStream .format("memory") .queryName("my_sink") .start() } query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery79cdf37e var progress = query.lastProgress progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null while (query.isActive && (progress == null || progress.numInputRows == 0)) { query.awaitTermination(100) progress = query.lastProgress } val noException = query.exception.isEmpty noException: Boolean = true query.stop() semaphore.release() ``` **After**: we can only see the content that is related to the test case `streaming works with REPL generated code` ``` val add1 = udf((i: Long) => i + 1) add1: org.apache.spark.sql.expressions.UserDefinedFunction = ScalarUserDefinedFunction( Array( ... val query = { spark.readStream .format("rate") .option("rowsPerSecond", "10") .option("numPartitions", "1") .load() .withColumn("value", add1($"value")) .writeStream .format("memory") .queryName("my_sink") .start() } query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery5429e19b var progress = query.lastProgress progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null while (query.isActive && (progress == null || progress.numInputRows == 0)) { query.awaitTermination(100) progress = query.lastProgress } val noException = query.exception.isEmpty noException: Boolean = true query.stop() semaphore.release() ``` Closes #42512 from LuciferYang/SPARK-44824-35. 
Closes #42512 from LuciferYang/SPARK-44824-35.

Authored-by: yangjie01 <yangji...@baidu.com>
Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index f467aee73f2..b2971236147 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -83,6 +83,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
   override def afterEach(): Unit = {
     semaphore.drainPermits()
+    if (ammoniteOut != null) {
+      ammoniteOut.reset()
+    }
   }
 
   def runCommandsInShell(input: String): String = {
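For the suite-level view, here is a minimal ScalaTest-style sketch of the same pattern (assuming ScalaTest 3.x on the classpath; `BufferIsolationSuite` and `capturedOut` are illustrative names, not the actual `ReplE2ESuite` code), including the null guard from the patch for a lazily created buffer:

```scala
import java.io.ByteArrayOutputStream

import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

// Illustrative suite: a shared buffer captures per-test output and is
// cleared after every test so assertions never match stale content.
class BufferIsolationSuite extends AnyFunSuite with BeforeAndAfterEach {

  // Created lazily, like a shell that only starts on first use;
  // hence the null guard in afterEach below.
  @volatile private var capturedOut: ByteArrayOutputStream = _

  private def writeOutput(s: String): Unit = {
    if (capturedOut == null) capturedOut = new ByteArrayOutputStream()
    capturedOut.write(s.getBytes("UTF-8"))
  }

  private def output: String =
    if (capturedOut == null) "" else capturedOut.toString("UTF-8")

  override def afterEach(): Unit = {
    // Mirrors the patch: reset only if the buffer was ever created.
    if (capturedOut != null) {
      capturedOut.reset()
    }
  }

  test("first case writes to the shared buffer") {
    writeOutput("""String = "[MyTestClass(1), MyTestClass(3)]"""")
    assert(output.contains("MyTestClass(1)"))
  }

  test("second case starts from a clean buffer") {
    assert(output.isEmpty) // would fail without the reset in afterEach
  }
}
```

The guard matters because `afterEach` runs even for tests that never touched the buffer; an unconditional `reset()` would throw a `NullPointerException` if the buffer had not been created yet.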