LuciferYang opened a new pull request, #42512:
URL: https://github.com/apache/spark/pull/42512

   ### What changes were proposed in this pull request?
   This PR adds `ammoniteOut.reset()` to the `afterEach` method of 
`ReplE2ESuite` so that the `output` used for assertions in each test 
case contains only that case's output, not everything printed so far.
   
   ### Why are the changes needed?
   Currently, `ammoniteOut` accumulates the output of every executed test, 
with no isolation between cases. This can lead to unexpected assertion results. 
   For example, adding `assertContains("""String = "[MyTestClass(1), 
MyTestClass(3)]"""", output)` to the following test case would still pass, 
because that string was printed to `ammoniteOut` by the previous 
test case. 
   
   
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L283-L290
   
   Hence, we need to clear the content in `ammoniteOut` after each test to 
achieve isolation between test cases.
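   
   The leak described above can be reproduced outside the suite. The following 
is only a minimal sketch, assuming the shared buffer behaves like a 
`java.io.ByteArrayOutputStream`; `OutputIsolationSketch`, `runCase`, and its 
`afterEach` are hypothetical names for illustration, not the code in 
`ReplE2ESuite`.
   
   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.charset.StandardCharsets

   // Hypothetical stand-in for a suite that shares one output buffer across cases.
   object OutputIsolationSketch {
     // Assumption: the real buffer is (or behaves like) a ByteArrayOutputStream.
     val out = new ByteArrayOutputStream()

     // Simulates one test case: appends REPL-like text to the shared buffer
     // and returns everything the buffer currently holds.
     def runCase(text: String): String = {
       out.write(text.getBytes(StandardCharsets.UTF_8))
       out.toString(StandardCharsets.UTF_8.name())
     }

     // Mirrors the idea of the fix: drop the buffered bytes after each case.
     def afterEach(): Unit = out.reset()
   }
   ```
   
   Without a call to `afterEach()` between two `runCase` calls, the second 
result still contains the text of the first, which is how a stale assertion 
can pass. With the reset in between, the second result holds only its own text.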
   
   ### Does this PR introduce _any_ user-facing change?
   No, this is a test-only change.
   
   
   ### How was this patch tested?
   - Pass GitHub Actions
   - Manual check
   
   Print `output` after `val output = runCommandsInShell(input)` in the 
test case `streaming works with REPL generated code`:
   
   
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L313-L318
   
   Then run:
   
   ```
   build/sbt "connect-client-jvm/testOnly 
org.apache.spark.sql.application.ReplE2ESuite" -Phive 
   ```
   
   **Before**: `output` contains the output of all test cases executed so far 
in `ReplE2ESuite`
   
   
   ```
   Spark session available as 'spark'.
      _____                  __      ______                            __
     / ___/____  ____ ______/ /__   / ____/___  ____  ____  ___  _____/ /_
     \__ \/ __ \/ __ `/ ___/ //_/  / /   / __ \/ __ \/ __ \/ _ \/ ___/ __/
    ___/ / /_/ / /_/ / /  / ,<    / /___/ /_/ / / / / / / /  __/ /__/ /_
   /____/ .___/\__,_/_/  /_/|_|   \____/\____/_/ /_/_/ /_/\___/\___/\__/
       /_/
   
   @  
   
   @ spark.sql("select 1").collect() 
   res0: Array[org.apache.spark.sql.Row] = Array([1])
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ class A(x: Int) { def get = x * 5 + 19 } 
   defined class A
   
   @ def dummyUdf(x: Int): Int = new A(x).get 
   defined function dummyUdf
   
   @ val myUdf = udf(dummyUdf _) 
   myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.range(5).select(myUdf(col("id"))).as[Int].collect() 
   res5: Array[Int] = Array(19, 24, 29, 34, 39)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ class A(x: Int) { def get = x * 42 + 5 } 
   defined class A
   
   @ val myUdf = udf((x: Int) => new A(x).get) 
   myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.range(5).select(myUdf(col("id"))).as[Int].collect() 
   res9: Array[Int] = Array(5, 47, 89, 131, 173)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ class A(x: Int) { def get = x * 7 } 
   defined class A
   
   @ val myUdf = udf((x: Int) => new A(x).get) 
   myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ val modifiedUdf = myUdf.withName("myUdf").asNondeterministic() 
   modifiedUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect() 
   res14: Array[Int] = Array(0, 7, 14, 21, 28)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ spark.range(10).filter(n => n % 2 == 0).collect() 
   res16: Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ import java.nio.file.Paths 
   import java.nio.file.Paths
   
   @ def classLoadingTest(x: Int): Int = {
       val classloader =
         
Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
       val cls = Class.forName("com.example.Hello$", true, classloader)
       val module = cls.getField("MODULE$").get(null)
       cls.getMethod("test").invoke(module).asInstanceOf[Int]
     } 
   defined function classLoadingTest
   
   @ val classLoaderUdf = udf(classLoadingTest _) 
   classLoaderUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @  
   
   @ val jarPath = 
Paths.get("/Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar").toUri
 
   jarPath: java.net.URI = 
file:///Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
   
   @ spark.addArtifact(jarPath) 
   
   
   @  
   
   @ spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect() 
   res23: Array[Int] = Array(2, 2, 2, 2, 2)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ import org.apache.spark.sql.api.java._ 
   import org.apache.spark.sql.api.java._
   
   @ import org.apache.spark.sql.types.LongType 
   import org.apache.spark.sql.types.LongType
   
   @  
   
   @ val javaUdf = udf(new UDF1[Long, Long] {
       override def call(num: Long): Long = num * num + 25L
     }, LongType).asNondeterministic() 
   javaUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.range(5).select(javaUdf(col("id"))).as[Long].collect() 
   res28: Array[Long] = Array(25L, 26L, 29L, 34L, 41L)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ import org.apache.spark.sql.api.java._ 
   import org.apache.spark.sql.api.java._
   
   @ import org.apache.spark.sql.types.LongType 
   import org.apache.spark.sql.types.LongType
   
   @  
   
   @ spark.udf.register("javaUdf", new UDF1[Long, Long] {
       override def call(num: Long): Long = num * num * num + 250L
     }, LongType) 
   
   
   @ spark.sql("select javaUdf(id) from range(5)").as[Long].collect() 
   res33: Array[Long] = Array(250L, 251L, 258L, 277L, 314L)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ class A(x: Int) { def get = x * 100 } 
   defined class A
   
   @ val myUdf = udf((x: Int) => new A(x).get) 
   myUdf: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.udf.register("dummyUdf", myUdf) 
   res37: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() 
   res38: Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ class A(x: Int) { def get = x * 15 } 
   defined class A
   
   @ spark.udf.register("directUdf", (x: Int) => new A(x).get) 
   res41: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ spark.sql("select directUdf(id) from range(5)").as[Long].collect() 
   res42: Array[Long] = Array(0L, 15L, 30L, 45L, 60L)
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") 
   df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
   
   @ spark.udf.register("simpleUDF", (v: Int) => v * v) 
   res45: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ df.select($"id", call_udf("simpleUDF", $"value")).collect() 
   res46: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
   
   @        
   
   semaphore.release() 
   
   
   @  
   
   @ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") 
   df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
   
   @ spark.udf.register("simpleUDF", (v: Int) => v * v) 
   res49: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ df.select($"id", call_function("simpleUDF", $"value")).collect() 
   res50: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ case class MyTestClass(value: Int) 
   defined class MyTestClass
   
   @ spark.range(4).
       filter($"id" % 2 === 1).
       select($"id".cast("int").as("value")).
       as[MyTestClass].
       collect().
       map(mtc => s"MyTestClass(${mtc.value})").
       mkString("[", ", ", "]") 
   res53: String = "[MyTestClass(1), MyTestClass(3)]"
   
   @            
   
   @ semaphore.release() 
   
   
   @  
   
   @ case class MyTestClass(value: Int) 
   defined class MyTestClass
   
   @ spark.range(2).map(i => MyTestClass(i.toInt)).collect() 
   res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))
   
   @        
   
   @ semaphore.release() 
   
   
   @  
   
   @ val add1 = udf((i: Long) => i + 1) 
   add1: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ val query = {
       spark.readStream
           .format("rate")
           .option("rowsPerSecond", "10")
           .option("numPartitions", "1")
           .load()
           .withColumn("value", add1($"value"))
           .writeStream
           .format("memory")
           .queryName("my_sink")
           .start()
     } 
   query: org.apache.spark.sql.streaming.StreamingQuery = 
org.apache.spark.sql.streaming.RemoteStreamingQuery@79cdf37e
   
   @ var progress = query.lastProgress 
   progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
   
   @ while (query.isActive && (progress == null || progress.numInputRows == 0)) 
{
       query.awaitTermination(100)
       progress = query.lastProgress
     } 
   
   
   @ val noException = query.exception.isEmpty 
   noException: Boolean = true
   
   @ query.stop() 
   
   
   @  
   
   @ semaphore.release() 
   ```
   
   **After**: `output` contains only the output of the test case 
`streaming works with REPL generated code`
   
   ```
   @  
   
   @ val add1 = udf((i: Long) => i + 1) 
   add1: org.apache.spark.sql.expressions.UserDefinedFunction = 
ScalarUserDefinedFunction(
     Array(
   ...
   
   @ val query = {
       spark.readStream
           .format("rate")
           .option("rowsPerSecond", "10")
           .option("numPartitions", "1")
           .load()
           .withColumn("value", add1($"value"))
           .writeStream
           .format("memory")
           .queryName("my_sink")
           .start()
     } 
   query: org.apache.spark.sql.streaming.StreamingQuery = 
org.apache.spark.sql.streaming.RemoteStreamingQuery@5429e19b
   
   @ var progress = query.lastProgress 
   progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
   
   @ while (query.isActive && (progress == null || progress.numInputRows == 0)) 
{
       query.awaitTermination(100)
       progress = query.lastProgress
     } 
   
   
   @ val noException = query.exception.isEmpty 
   noException: Boolean = true
   
   @ query.stop() 
   
   
   @  
   
   @ semaphore.release() 
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

