This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7d96f12784d [SPARK-44824][CONNECT][TESTS] Reset `ammoniteOut` in the
`afterEach` method of `ReplE2ESuite`
7d96f12784d is described below
commit 7d96f12784de7f49eaf2bb0a0d8b5cb17d2ecf63
Author: yangjie01 <[email protected]>
AuthorDate: Wed Aug 16 17:48:36 2023 +0800
[SPARK-44824][CONNECT][TESTS] Reset `ammoniteOut` in the `afterEach` method
of `ReplE2ESuite`
### What changes were proposed in this pull request?
This PR adds `ammoniteOut.reset()` to the `afterEach` method of
`ReplE2ESuite` to ensure that the `output` used for assertions in each test
case reflects only the current case, not the accumulated output of every case
run so far.
### Why are the changes needed?
The current `ammoniteOut` accumulates the output of all executed tests,
with no isolation between cases. This can lead to unexpected assertion results.
For example, adding `assertContains("""String = "[MyTestClass(1),
MyTestClass(3)]"""", output)` to the following test case would still pass,
because that string was printed to `ammoniteOut` by a previous test case.
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L283-L290
Hence, we need to clear the content of `ammoniteOut` after each test to
achieve isolation between test cases.
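For illustration, here is a minimal, self-contained sketch of the failure mode and the fix, assuming `ammoniteOut` is a `java.io.ByteArrayOutputStream` (as the `reset()` call in the diff below suggests); the standalone object and the captured strings are hypothetical:
```scala
import java.io.{ByteArrayOutputStream, PrintStream}

object AmmoniteOutIsolationSketch extends App {
  val ammoniteOut = new ByteArrayOutputStream()
  val replOut = new PrintStream(ammoniteOut, /* autoFlush = */ true)

  // "Test 1" runs in the shared REPL; its result is captured in the buffer.
  replOut.println("""res53: String = "[MyTestClass(1), MyTestClass(3)]"""")
  assert(ammoniteOut.toString.contains("MyTestClass(3)")) // passes, as expected

  // Without a reset, "Test 2" still sees Test 1's output, so an assertion
  // that has nothing to do with Test 2 passes spuriously.
  assert(ammoniteOut.toString.contains("MyTestClass(3)")) // false positive

  // The fix: reset the buffer in afterEach so each test starts clean.
  ammoniteOut.reset()
  assert(!ammoniteOut.toString.contains("MyTestClass(3)")) // isolated again
}
```
The actual patch (see the diff below) applies exactly this `reset()` in `afterEach`, guarded by a null check for the case where the suite never initialized the stream.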
### Does this PR introduce _any_ user-facing change?
No, this change is test-only.
### How was this patch tested?
- Pass GitHub Actions
- Manual check
Print the `output` after `val output = runCommandsInShell(input)` in the
test case `streaming works with REPL generated code`:
https://github.com/apache/spark/blob/2be20e54a2222f6cdf64e8486d1910133b43665f/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala#L313-L318
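The temporary debug tweak (illustrative only, not part of the patch) amounts to printing the captured output at that point:
```scala
val output = runCommandsInShell(input)
println(output) // temporary: dump the captured REPL output for manual inspection
```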
then run:
```
build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.application.ReplE2ESuite" -Phive
```
**Before**: the output contains the content of all test cases executed so
far in `ReplE2ESuite`:
```
Spark session available as 'spark'.
_____ __ ______ __
/ ___/____ ____ ______/ /__ / ____/___ ____ ____ ___ _____/ /_
\__ \/ __ \/ __ `/ ___/ //_/ / / / __ \/ __ \/ __ \/ _ \/ ___/ __/
___/ / /_/ / /_/ / / / ,< / /___/ /_/ / / / / / / / __/ /__/ /_
/____/ .___/\__,_/_/ /_/|_| \____/\____/_/ /_/_/ /_/\___/\___/\__/
/_/
spark.sql("select 1").collect()
res0: Array[org.apache.spark.sql.Row] = Array([1])
semaphore.release()
class A(x: Int) { def get = x * 5 + 19 }
defined class A
def dummyUdf(x: Int): Int = new A(x).get
defined function dummyUdf
val myUdf = udf(dummyUdf _)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res5: Array[Int] = Array(19, 24, 29, 34, 39)
semaphore.release()
class A(x: Int) { def get = x * 42 + 5 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.range(5).select(myUdf(col("id"))).as[Int].collect()
res9: Array[Int] = Array(5, 47, 89, 131, 173)
semaphore.release()
class A(x: Int) { def get = x * 7 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
modifiedUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
res14: Array[Int] = Array(0, 7, 14, 21, 28)
semaphore.release()
spark.range(10).filter(n => n % 2 == 0).collect()
res16: Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)
semaphore.release()
import java.nio.file.Paths
import java.nio.file.Paths
def classLoadingTest(x: Int): Int = {
  val classloader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
  val cls = Class.forName("com.example.Hello$", true, classloader)
  val module = cls.getField("MODULE$").get(null)
  cls.getMethod("test").invoke(module).asInstanceOf[Int]
}
defined function classLoadingTest
val classLoaderUdf = udf(classLoadingTest _)
classLoaderUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
val jarPath =
Paths.get("/Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar").toUri
jarPath: java.net.URI =
file:///Users/yangjie01/SourceCode/git/spark-mine-sbt/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
spark.addArtifact(jarPath)
spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
res23: Array[Int] = Array(2, 2, 2, 2, 2)
semaphore.release()
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
val javaUdf = udf(new UDF1[Long, Long] {
  override def call(num: Long): Long = num * num + 25L
}, LongType).asNondeterministic()
javaUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.range(5).select(javaUdf(col("id"))).as[Long].collect()
res28: Array[Long] = Array(25L, 26L, 29L, 34L, 41L)
semaphore.release()
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.LongType
spark.udf.register("javaUdf", new UDF1[Long, Long] {
override def call(num: Long): Long = num * num * num + 250L
}, LongType)
spark.sql("select javaUdf(id) from range(5)").as[Long].collect()
res33: Array[Long] = Array(250L, 251L, 258L, 277L, 314L)
semaphore.release()
class A(x: Int) { def get = x * 100 }
defined class A
val myUdf = udf((x: Int) => new A(x).get)
myUdf: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.udf.register("dummyUdf", myUdf)
res37: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
res38: Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
semaphore.release()
class A(x: Int) { def get = x * 15 }
defined class A
spark.udf.register("directUdf", (x: Int) => new A(x).get)
res41: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
spark.sql("select directUdf(id) from range(5)").as[Long].collect()
res42: Array[Long] = Array(0L, 15L, 30L, 45L, 60L)
semaphore.release()
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
spark.udf.register("simpleUDF", (v: Int) => v * v)
res45: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
df.select($"id", call_udf("simpleUDF", $"value")).collect()
res46: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
semaphore.release()
val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
df: org.apache.spark.sql.package.DataFrame = [id: string, value: int]
spark.udf.register("simpleUDF", (v: Int) => v * v)
res49: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
df.select($"id", call_function("simpleUDF", $"value")).collect()
res50: Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])
semaphore.release()
case class MyTestClass(value: Int)
defined class MyTestClass
spark.range(4).
  filter($"id" % 2 === 1).
  select($"id".cast("int").as("value")).
  as[MyTestClass].
  collect().
  map(mtc => s"MyTestClass(${mtc.value})").
  mkString("[", ", ", "]")
res53: String = "[MyTestClass(1), MyTestClass(3)]"
semaphore.release()
case class MyTestClass(value: Int)
defined class MyTestClass
spark.range(2).map(i => MyTestClass(i.toInt)).collect()
res56: Array[MyTestClass] = Array(MyTestClass(0), MyTestClass(1))
semaphore.release()
val add1 = udf((i: Long) => i + 1)
add1: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
val query = {
  spark.readStream
    .format("rate")
    .option("rowsPerSecond", "10")
    .option("numPartitions", "1")
    .load()
    .withColumn("value", add1($"value"))
    .writeStream
    .format("memory")
    .queryName("my_sink")
    .start()
}
query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery@79cdf37e
var progress = query.lastProgress
progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
while (query.isActive && (progress == null || progress.numInputRows == 0)) {
  query.awaitTermination(100)
  progress = query.lastProgress
}
val noException = query.exception.isEmpty
noException: Boolean = true
query.stop()
semaphore.release()
```
**After**: the output contains only the content related to the test case
`streaming works with REPL generated code`:
```
val add1 = udf((i: Long) => i + 1)
add1: org.apache.spark.sql.expressions.UserDefinedFunction =
ScalarUserDefinedFunction(
Array(
...
val query = {
  spark.readStream
    .format("rate")
    .option("rowsPerSecond", "10")
    .option("numPartitions", "1")
    .load()
    .withColumn("value", add1($"value"))
    .writeStream
    .format("memory")
    .queryName("my_sink")
    .start()
}
query: org.apache.spark.sql.streaming.StreamingQuery = org.apache.spark.sql.streaming.RemoteStreamingQuery@5429e19b
var progress = query.lastProgress
progress: org.apache.spark.sql.streaming.StreamingQueryProgress = null
while (query.isActive && (progress == null || progress.numInputRows == 0)) {
  query.awaitTermination(100)
  progress = query.lastProgress
}
val noException = query.exception.isEmpty
noException: Boolean = true
query.stop()
semaphore.release()
```
Closes #42509 from LuciferYang/SPARK-44824.
Authored-by: yangjie01 <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala | 3 +++
1 file changed, 3 insertions(+)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 13ca5caf0af..b26483edd09 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -83,6 +83,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
   override def afterEach(): Unit = {
     semaphore.drainPermits()
+    if (ammoniteOut != null) {
+      ammoniteOut.reset()
+    }
   }
 
   def runCommandsInShell(input: String): String = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]