This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5e17b07b88e [SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for Scala client UDF tests
5e17b07b88e is described below
commit 5e17b07b88ea32aa9e6e56df4a901032db6c7919
Author: Zhen Li <[email protected]>
AuthorDate: Wed Apr 26 12:07:29 2023 -0400
[SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for Scala client UDF tests
### What changes were proposed in this pull request?
Moved `UdfUtils` to the common module.
Fixed the client `UserDefinedFunctionE2ETestSuite` and `KeyValueGroupedDatasetE2ETestSuite` tests so that they can run with Maven.
Verified that the code works with the following Maven commands:
```
build/mvn clean install -pl connector/connect/server -am -DskipTests
build/mvn clean install -pl assembly -am -DskipTests
build/mvn clean install -pl connector/connect/client/jvm
```
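For context, the sketch below (not part of this commit; the `UdfUtilsSketch` object is purely illustrative) shows why `UdfUtils` must be visible to the server: the Connect client rewrites typed operations such as `Dataset.map` into adaptors built by `UdfUtils`, so the serialized UDF payload references that class, and the server can only deserialize the payload if the class is on its classpath.
```
// Illustrative sketch, assuming spark-connect-common is on the classpath.
// UdfUtils is private[sql], so the demo sits in the org.apache.spark.sql package.
package org.apache.spark.sql

import org.apache.spark.sql.connect.common.UdfUtils

object UdfUtilsSketch {
  def main(args: Array[String]): Unit = {
    val double: Int => Int = _ * 2
    // The client turns a typed map into a mapPartitions adaptor like this one
    // before serializing it and shipping it to the server.
    val adaptor: Iterator[Int] => Iterator[Int] =
      UdfUtils.mapFuncToMapPartitionsAdaptor(double)
    println(adaptor(Iterator(1, 2, 3)).toList) // prints List(2, 4, 6)
  }
}
```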
### Why are the changes needed?
Fixes the failing Maven tests.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tests.
Fixed tests for https://github.com/apache/spark/pull/40581 and https://github.com/apache/spark/pull/40729.
Closes #40762 from zhenlineo/udf-mvn.
Authored-by: Zhen Li <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 4 +-
.../apache/spark/sql/KeyValueGroupedDataset.scala | 2 +-
.../sql/UserDefinedFunctionE2ETestSuite.scala | 3 +-
.../connect/client/util/IntegrationTestUtils.scala | 28 ++++++---
.../connect/client/util/RemoteSparkSession.scala | 73 ++++++++++++++++------
.../spark/sql/connect/common}/UdfUtils.scala | 11 ++--
6 files changed, 84 insertions(+), 37 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 26dc0074bc2..3301b483b5e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -28,8 +28,8 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.connect.client.{SparkResult, UdfUtils}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
+import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 8f9722789fc..2d712bc4c51 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
import org.apache.spark.sql.functions.col
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index afe9ce5f751..b07d1459df5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -30,8 +30,7 @@ import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.functions.{col, udf}
/**
- * All tests in this class requires client UDF artifacts synced with the server. TODO: It means
- * these tests only works with SBT for now.
+ * All tests in this class require the client UDFs defined in this test class to be synced with the server.
*/
class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
test("Dataset typed filter") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 9196db175d2..7e34726b48e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -46,7 +46,7 @@ object IntegrationTestUtils {
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
}
- private[connect] lazy val debugConfig: Seq[String] = {
+ private[connect] def debugConfigs: Seq[String] = {
val log4j2 =
s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
if (isDebug) {
Seq(
@@ -90,7 +90,19 @@ object IntegrationTestUtils {
sbtName: String,
mvnName: String,
test: Boolean = false): File = {
- val targetDir = new File(new File(sparkHome, path), "target")
+ val jar = tryFindJar(path, sbtName, mvnName, test).getOrElse(
+ throw new RuntimeException(
+ s"Failed to find the jar inside folder: ${getTargetFilePath(path)}"))
+ debug("Using jar: " + jar.getCanonicalPath)
+ jar
+ }
+
+ private[sql] def tryFindJar(
+ path: String,
+ sbtName: String,
+ mvnName: String,
+ test: Boolean = false): Option[File] = {
+ val targetDir = getTargetFilePath(path).toFile
assert(
targetDir.exists(),
s"Fail to locate the target folder: '${targetDir.getCanonicalPath}'. " +
@@ -98,7 +110,9 @@ object IntegrationTestUtils {
"Make sure the spark project jars has been built (e.g. using build/sbt
package)" +
"and the env variable `SPARK_HOME` is set correctly.")
val suffix = if (test) "-tests.jar" else ".jar"
- val jars = recursiveListFiles(targetDir).filter { f =>
+ // It is possible there are more than one: one built by maven, and another by SBT.
+ // Return the first one found.
+ recursiveListFiles(targetDir).find { f =>
// SBT jar
(f.getParentFile.getName == scalaDir &&
f.getName.startsWith(sbtName) && f.getName.endsWith(suffix)) ||
@@ -107,10 +121,10 @@ object IntegrationTestUtils {
f.getName.startsWith(mvnName) &&
f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix"))
}
- // It is possible we found more than one: one built by maven, and another by SBT
- assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}")
- debug("Using jar: " + jars(0).getCanonicalPath)
- jars(0) // return the first jar found
+ }
+
+ private def getTargetFilePath(path: String): java.nio.file.Path = {
+ Paths.get(sparkHome, path, "target").toAbsolutePath
}
private def recursiveListFiles(f: File): Array[File] = {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 2d8cc6d3298..235605e3121 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -61,31 +61,14 @@ object SparkConnectServerUtils {
"connector/connect/server",
"spark-connect-assembly",
"spark-connect").getCanonicalPath
- val driverClassPath = connectJar + ":" +
- findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath
- val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
- "hive"
- } else {
- // scalastyle:off println
- println(
- "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
- "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
- "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
- "2. Test with sbt: run test with `-Phive` profile")
- // scalastyle:on println
- "in-memory"
- }
+
val builder = Process(
Seq(
"bin/spark-submit",
"--driver-class-path",
- driverClassPath,
- "--conf",
- s"spark.connect.grpc.binding.port=$port",
- "--conf",
- "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
+ connectJar,
"--conf",
- s"spark.sql.catalogImplementation=$catalogImplementation") ++ debugConfig ++ Seq(
+ s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ debugConfigs ++ Seq(
"--class",
"org.apache.spark.sql.connect.SimpleSparkConnectService",
connectJar),
@@ -102,6 +85,56 @@ object SparkConnectServerUtils {
process
}
+ /**
+ * As one shared Spark server will be started for all E2E tests, tests that need special
+ * configs add them here.
+ */
+ private def testConfigs: Seq[String] = {
+ // Use InMemoryTableCatalog for V2 writer tests
+ val writerV2Configs = {
+ val catalystTestJar = findJar( // To find InMemoryTableCatalog for V2 writer tests
+ "sql/catalyst",
+ "spark-catalyst",
+ "spark-catalyst",
+ test = true).getCanonicalPath
+ Seq(
+ "--jars",
+ catalystTestJar,
+ "--conf",
+ "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
+ }
+
+ // Run tests using hive
+ val hiveTestConfigs = {
+ val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
+ "hive"
+ } else {
+ // scalastyle:off println
+ println(
+ "Will start Spark Connect server with
`spark.sql.catalogImplementation=in-memory`, " +
+ "some tests that rely on Hive will be ignored. If you don't want
to skip them:\n" +
+ "1. Test with maven: run `build/mvn install -DskipTests -Phive`
before testing\n" +
+ "2. Test with sbt: run test with `-Phive` profile")
+ // scalastyle:on println
+ "in-memory"
+ }
+ Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation")
+ }
+
+ // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
+ val udfTestConfigs = tryFindJar(
+ "connector/connect/client/jvm",
+ // SBT passes the client & test jars to the server process automatically.
+ // So we skip building or finding this jar for SBT.
+ "sbt-tests-do-not-need-this-jar",
+ "spark-connect-client-jvm",
+ test = true)
+ .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
+ .getOrElse(Seq.empty)
+
+ writerV2Configs ++ hiveTestConfigs ++ udfTestConfigs
+ }
+
def start(): Unit = {
assert(!stopped)
sparkConnect
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
similarity index 91%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
rename to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index db47c3bf681..7cd251b245f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.connect.common
import scala.collection.JavaConverters._
@@ -23,11 +23,12 @@ import org.apache.spark.sql.Row
/**
 * Util functions to help convert input functions between typed filter, map, flatMap,
- * mapPartitions etc. These functions cannot be defined inside the client Dataset class as it will
- * cause Dataset sync conflicts when used together with UDFs. Thus we define them outside, in the
- * client package.
+ * mapPartitions etc. This class is shared between the client and the server so that when the
+ * methods are used in client UDFs, the server will be able to find them when actually executing
+ * the UDFs.
*/
-private[sql] object UdfUtils {
+@SerialVersionUID(8464839273647598302L)
+private[sql] object UdfUtils extends Serializable {
  def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] => Iterator[U] = _.map(f(_))
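A side note on the `@SerialVersionUID` added in the last hunk: the sketch below (hypothetical `SharedAdaptors` and `RoundTripDemo`, not Spark code; assumes Scala 2.12+, where lambdas are serializable) shows how a shared, UID-pinned helper object lets a function serialized on one side be deserialized on the other, which is the same constraint that put `UdfUtils` into `connect/common`.
```
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-in for UdfUtils: a pinned serialVersionUID keeps
// separately compiled client and server builds of this class compatible
// under Java serialization.
@SerialVersionUID(1L)
object SharedAdaptors extends Serializable {
  def toPartitionsAdaptor[T, U](f: T => U): Iterator[T] => Iterator[U] = _.map(f)
}

object RoundTripDemo {
  def main(args: Array[String]): Unit = {
    // "Client" side: serialize a function value whose body lives in SharedAdaptors.
    val fn = SharedAdaptors.toPartitionsAdaptor((x: Int) => x + 1)
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(fn)
    oos.close()

    // "Server" side: deserializing works only because SharedAdaptors is on
    // this classpath too, the reason UdfUtils moved to connect/common.
    val restored = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      .readObject()
      .asInstanceOf[Iterator[Int] => Iterator[Int]]
    println(restored(Iterator(1, 2, 3)).toList) // prints List(2, 3, 4)
  }
}
```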
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]