LuciferYang opened a new pull request, #42512:
URL: https://github.com/apache/spark/pull/42512
### What changes were proposed in this pull request?
This PR adds `ammoniteOut.reset()` to the `afterEach` method of
`ReplE2ESuite` so that the `output` used for assertions in each test case
contains only the output of that case, rather than the accumulated output of all cases.
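
A minimal sketch of the change, assuming `ammoniteOut` is the `java.io.ByteArrayOutputStream` that captures the REPL output (the class below is only a simplified stand-in for the suite, not its actual code):
```scala
import java.io.ByteArrayOutputStream

import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

// Stand-in for ReplE2ESuite: only the parts relevant to this change are shown.
class ReplOutputIsolationSketch extends AnyFunSuite with BeforeAndAfterEach {
  private val ammoniteOut = new ByteArrayOutputStream()

  override def afterEach(): Unit = {
    super.afterEach()
    // Discard the output captured for the test that just finished, so the
    // next test only asserts against its own REPL output.
    ammoniteOut.reset()
  }
}
```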
### Why are the changes needed?
The current `ammoniteOut` records the output of every test that has been executed,
with no isolation between test cases, which can lead to unexpected assertion results.
For example, adding `assertContains("""String = "[MyTestClass(1),
MyTestClass(3)]"""", output)` to the following test case would still pass,
because that string was already printed to `ammoniteOut` by a previous
test case:
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L283-L290
Hence, we need to clear `ammoniteOut` after each test to isolate the test
cases from one another.
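
For illustration, the standalone sketch below (with simplified, hypothetical stand-ins for the suite's `ammoniteOut` and `assertContains`, not its actual code) shows why a shared `ByteArrayOutputStream` lets an assertion in a later test match output printed by an earlier one, and how `reset()` restores isolation:
```scala
import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

object OutputIsolationDemo extends App {
  val ammoniteOut = new ByteArrayOutputStream()
  def assertContains(value: String, output: String): Unit =
    assert(output.contains(value), s"Not found: $value")

  // "Test 1" prints its result to the shared stream.
  ammoniteOut.write(
    """res53: String = "[MyTestClass(1), MyTestClass(3)]"""".getBytes(StandardCharsets.UTF_8))

  // "Test 2" prints something unrelated, yet an assertion about test 1's
  // result still passes because the stream has accumulated all prior output.
  ammoniteOut.write(
    "res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))".getBytes(StandardCharsets.UTF_8))
  assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", ammoniteOut.toString) // passes, but should not

  // With a reset between tests, the stale content is gone and only test 2's
  // own output remains visible to its assertions.
  ammoniteOut.reset()
  ammoniteOut.write(
    "res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))".getBytes(StandardCharsets.UTF_8))
  assert(!ammoniteOut.toString.contains("""String = "[MyTestClass(1), MyTestClass(3)]""""))
}
```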
### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.
### How was this patch tested?
- Pass GitHub Actions
- Manual check
Print the `output` after `val output = runCommandsInShell(input)` in the
test case `streaming works with REPL generated code`
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L313-L318
and run:
```
build/sbt "connect-client-jvm/testOnly
org.apache.spark.sql.application.ReplE2ESuite" -Phive
```
**Before**: the printed `output` contains the content of all test cases
executed so far in `ReplE2ESuite`:
```
Spark session available as 'spark'.
   _____                  __      ______                            __
  / ___/____  ____ ______/ /__   / ____/___  ____  ____  ___  _____/ /_
  \__ \/ __ \/ __ `/ ___/ //_/  / /   / __ \/ __ \/ __ \/ _ \/ ___/ __/
 ___/ / /_/ / /_/ / /  / ,<    / /___/ /_/ / / / / / / /  __/ /__/ /_
/____/ .___/\__,_/_/  /_/|_|   \____/\____/_/ /_/_/ /_/\___/\___/\__/
    /_/
@
@ spark.sql("select 1").collect()
res0: Array[org.apache.spark.sql.Row] = Array([1])
@
@ semaphore.release()
@
@ class A(x: Int) { def get = x * 5 + 19 }
defined class A
@ def dummyUdf(x: Int): Int = new A(x).get
defined function dummyUdf
@ val myUdf = udf(dummyUdf _)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res5: Array[Int] = Array(19, 24, 29, 34, 39)
@
@ semaphore.release()
@
@ class A(x: Int) { def get = x * 42 + 5 }
defined class A
@ val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res9: Array[Int] = Array(5, 47, 89, 131, 173)
@
@ semaphore.release()
@
@ class A(x: Int) { def get = x * 7 }
defined class A
@ val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
modifiedUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
res14: Array[Int] = Array(0, 7, 14, 21, 28)
@
@ semaphore.release()
@
@ spark.range(10).filter(n => n % 2 == 0).collect()
res16: Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)
@
@ semaphore.release()
@
@ import java.nio.file.Paths
import java.nio.file.Paths
@ def classLoadingTest(x: Int): Int = {
val classloader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
val cls = Class.forName("com.example.Hello$", true, classloader)
val module = cls.getField("MODULE$").get(null)
cls.getMethod("test").invoke(module).asInstanceOf[Int]
}
defined function classLoadingTest
@ val classLoaderUdf = udf(classLoadingTest _)
classLoaderUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@
@ val jarPath =
Paths.get("/Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar").toUri
jarPath: java.net.URI =
file:///Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
@ spark.addArtifact(jarPath)
@
@ spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
res23: Array[Int] = Array(2, 2, 2, 2, 2)
@
@ semaphore.release()
@
@ import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
@ import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
@
@ val javaUdf = udf(new UDF1[Long, Long] {
override def call(num: Long): Long = num * num + 25L
}, LongType).asNondeterministic()
javaUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.range(5).select(javaUdf(col("id"))).as[Long].collect()
res28: Array[Long] = Array(25L, 26L, 29L, 34L, 41L)
@
@ semaphore.release()
@
@ import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
@ import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
@
@ spark.udf.register("javaUdf", new UDF1[Long, Long] {
override def call(num: Long): Long = num * num * num + 250L
}, LongType)
@ spark.sql("select javaUdf(id) from range(5)").as[Long].collect()
res33: Array[Long] = Array(250L, 251L, 258L, 277L, 314L)
@
@ semaphore.release()
@
@ class A(x: Int) { def get = x * 100 }
defined class A
@ val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.udf.register("dummyUdf", myUdf)
res37: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
res38: Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
@
@ semaphore.release()
@
@ class A(x: Int) { def get = x * 15 }
defined class A
@ spark.udf.register("directUdf", (x: Int) => new A(x).get)
res41: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ spark.sql("select directUdf(id) from range(5)").as[Long].collect()
res42: Array[Long] = Array(0L, 15L, 30L, 45L, 60L)
@
@ semaphore.release()
@
@ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
@ spark.udf.register("simpleUDF", (v: Int) => v * v)
res45: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ df.select($"id", call_udf("simpleUDF", $"value")).collect()
res46: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
@
semaphore.release()
@
@ val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
@ spark.udf.register("simpleUDF", (v: Int) => v * v)
res49: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ df.select($"id", call_function("simpleUDF", $"value")).collect()
res50: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
@
@ semaphore.release()
@
@ case class MyTestClass(value: Int)
defined class MyTestClass
@ spark.range(4).
filter($"id" % 2 === 1).
select($"id".cast("int").as("value")).
as[MyTestClass].
collect().
map(mtc => s"MyTestClass(${mtc.value})").
mkString("[", ", ", "]")
res53: String = "[MyTestClass(1), MyTestClass(3)]"
@
@ semaphore.release()
@
@ case class MyTestClass(value: Int)
defined class MyTestClass
@ spark.range(2).map(i => MyTestClass(i.toInt)).collect()
res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))
@
@ semaphore.release()
@
@ val add1 = udf((i: Long) => i + 1)
add1: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ val query = {
spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.withColumn("value", add1($"value"))
.writeStream
.format("memory")
.queryName("my_sink")
.start()
}
query: org.apache.spark.sql.streaming.StreamingQuery =
org.apache.spark.sql.streaming.RemoteStreamingQuery@79cdf37e
@ var progress = query.lastProgress
progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
@ while (query.isActive && (progress == null || progress.numInputRows == 0))
{
query.awaitTermination(100)
progress = query.lastProgress
}
@ val noException = query.exception.isEmpty
noException: Boolean = true
@ query.stop()
@
@ semaphore.release()
```
**After**: only the content related to the test case
`streaming works with REPL generated code` is printed:
```
@
@ val add1 = udf((i: Long) => i + 1)
add1: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
@ val query = {
spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.withColumn("value", add1($"value"))
.writeStream
.format("memory")
.queryName("my_sink")
.start()
}
query: org.apache.spark.sql.streaming.StreamingQuery =
org.apache.spark.sql.streaming.RemoteStreamingQuery@5429e19b
@ var progress = query.lastProgress
progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
@ while (query.isActive && (progress == null || progress.numInputRows == 0))
{
query.awaitTermination(100)
progress = query.lastProgress
}
@ val noException = query.exception.isEmpty
noException: Boolean = true
@ query.stop()
@
@ semaphore.release()
```