This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e17b07b88e [SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for 
Scala client UDF tests
5e17b07b88e is described below

commit 5e17b07b88ea32aa9e6e56df4a901032db6c7919
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Wed Apr 26 12:07:29 2023 -0400

    [SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for Scala client UDF 
tests
    
    ### What changes were proposed in this pull request?
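    Moved `UdfUtils` from the client module to the connect common module so that the server can
    also load it when executing client UDFs.
    Fixed the client `UserDefinedFunctionE2ETestSuite` and `KeyValueGroupedDatasetE2ETestSuite`
    tests so that they can run with Maven.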
    
    Verified that the tests pass when built and run with the following Maven commands:
    ```
    build/mvn clean install -pl connector/connect/server -am -DskipTests
    build/mvn clean install -pl assembly -am -DskipTests
    build/mvn clean install -pl connector/connect/client/jvm
    ```
    
    ### Why are the changes needed?
    Fix the Scala client UDF E2E tests, which failed when run with Maven.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests. This change fixes the tests introduced in
    https://github.com/apache/spark/pull/40581 and https://github.com/apache/spark/pull/40729.
    
    Closes #40762 from zhenlineo/udf-mvn.
    
    Authored-by: Zhen Li <zhenli...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  2 +-
 .../sql/UserDefinedFunctionE2ETestSuite.scala      |  3 +-
 .../connect/client/util/IntegrationTestUtils.scala | 28 ++++++---
 .../connect/client/util/RemoteSparkSession.scala   | 73 ++++++++++++++++------
 .../spark/sql/connect/common}/UdfUtils.scala       | 11 ++--
 6 files changed, 84 insertions(+), 37 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 26dc0074bc2..3301b483b5e 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -28,8 +28,8 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.connect.client.{SparkResult, UdfUtils}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
StorageLevelProtoConverter}
+import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
StorageLevelProtoConverter, UdfUtils}
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.streaming.DataStreamWriter
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 8f9722789fc..2d712bc4c51 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.col
 
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index afe9ce5f751..b07d1459df5 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -30,8 +30,7 @@ import 
org.apache.spark.sql.connect.client.util.RemoteSparkSession
 import org.apache.spark.sql.functions.{col, udf}
 
 /**
- * All tests in this class requires client UDF artifacts synced with the 
server. TODO: It means
- * these tests only works with SBT for now.
+ * All tests in this class requires client UDF defined in this test class 
synced with the server.
  */
 class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
   test("Dataset typed filter") {
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 9196db175d2..7e34726b48e 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -46,7 +46,7 @@ object IntegrationTestUtils {
     sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
   }
 
-  private[connect] lazy val debugConfig: Seq[String] = {
+  private[connect] def debugConfigs: Seq[String] = {
     val log4j2 = 
s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
     if (isDebug) {
       Seq(
@@ -90,7 +90,19 @@ object IntegrationTestUtils {
       sbtName: String,
       mvnName: String,
       test: Boolean = false): File = {
-    val targetDir = new File(new File(sparkHome, path), "target")
+    val jar = tryFindJar(path, sbtName, mvnName, test).getOrElse(
+      throw new RuntimeException(
+        s"Failed to find the jar inside folder: ${getTargetFilePath(path)}"))
+    debug("Using jar: " + jar.getCanonicalPath)
+    jar
+  }
+
+  private[sql] def tryFindJar(
+      path: String,
+      sbtName: String,
+      mvnName: String,
+      test: Boolean = false): Option[File] = {
+    val targetDir = getTargetFilePath(path).toFile
     assert(
       targetDir.exists(),
       s"Fail to locate the target folder: '${targetDir.getCanonicalPath}'. " +
@@ -98,7 +110,9 @@ object IntegrationTestUtils {
         "Make sure the spark project jars has been built (e.g. using build/sbt 
package)" +
         "and the env variable `SPARK_HOME` is set correctly.")
     val suffix = if (test) "-tests.jar" else ".jar"
-    val jars = recursiveListFiles(targetDir).filter { f =>
+    // It is possible there are more than one: one built by maven, and another 
by SBT,
+    // Return the first one found.
+    recursiveListFiles(targetDir).find { f =>
       // SBT jar
       (f.getParentFile.getName == scalaDir &&
         f.getName.startsWith(sbtName) && f.getName.endsWith(suffix)) ||
@@ -107,10 +121,10 @@ object IntegrationTestUtils {
         f.getName.startsWith(mvnName) &&
         f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix"))
     }
-    // It is possible we found more than one: one built by maven, and another 
by SBT
-    assert(jars.nonEmpty, s"Failed to find the jar inside folder: 
${targetDir.getCanonicalPath}")
-    debug("Using jar: " + jars(0).getCanonicalPath)
-    jars(0) // return the first jar found
+  }
+
+  private def getTargetFilePath(path: String): java.nio.file.Path = {
+    Paths.get(sparkHome, path, "target").toAbsolutePath
   }
 
   private def recursiveListFiles(f: File): Array[File] = {
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 2d8cc6d3298..235605e3121 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -61,31 +61,14 @@ object SparkConnectServerUtils {
       "connector/connect/server",
       "spark-connect-assembly",
       "spark-connect").getCanonicalPath
-    val driverClassPath = connectJar + ":" +
-      findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = 
true).getCanonicalPath
-    val catalogImplementation = if 
(IntegrationTestUtils.isSparkHiveJarAvailable) {
-      "hive"
-    } else {
-      // scalastyle:off println
-      println(
-        "Will start Spark Connect server with 
`spark.sql.catalogImplementation=in-memory`, " +
-          "some tests that rely on Hive will be ignored. If you don't want to 
skip them:\n" +
-          "1. Test with maven: run `build/mvn install -DskipTests -Phive` 
before testing\n" +
-          "2. Test with sbt: run test with `-Phive` profile")
-      // scalastyle:on println
-      "in-memory"
-    }
+
     val builder = Process(
       Seq(
         "bin/spark-submit",
         "--driver-class-path",
-        driverClassPath,
-        "--conf",
-        s"spark.connect.grpc.binding.port=$port",
-        "--conf",
-        
"spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
+        connectJar,
         "--conf",
-        s"spark.sql.catalogImplementation=$catalogImplementation") ++ 
debugConfig ++ Seq(
+        s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ 
debugConfigs ++ Seq(
         "--class",
         "org.apache.spark.sql.connect.SimpleSparkConnectService",
         connectJar),
@@ -102,6 +85,56 @@ object SparkConnectServerUtils {
     process
   }
 
+  /**
+   * As one shared spark will be started for all E2E tests, for tests that 
needs some special
+   * configs, we add them here
+   */
+  private def testConfigs: Seq[String] = {
+    // Use InMemoryTableCatalog for V2 writer tests
+    val writerV2Configs = {
+      val catalystTestJar = findJar( // To find InMemoryTableCatalog for V2 
writer tests
+        "sql/catalyst",
+        "spark-catalyst",
+        "spark-catalyst",
+        test = true).getCanonicalPath
+      Seq(
+        "--jars",
+        catalystTestJar,
+        "--conf",
+        
"spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
+    }
+
+    // Run tests using hive
+    val hiveTestConfigs = {
+      val catalogImplementation = if 
(IntegrationTestUtils.isSparkHiveJarAvailable) {
+        "hive"
+      } else {
+        // scalastyle:off println
+        println(
+          "Will start Spark Connect server with 
`spark.sql.catalogImplementation=in-memory`, " +
+            "some tests that rely on Hive will be ignored. If you don't want 
to skip them:\n" +
+            "1. Test with maven: run `build/mvn install -DskipTests -Phive` 
before testing\n" +
+            "2. Test with sbt: run test with `-Phive` profile")
+        // scalastyle:on println
+        "in-memory"
+      }
+      Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation")
+    }
+
+    // For UDF maven E2E tests, the server needs the client code to find the 
UDFs defined in tests.
+    val udfTestConfigs = tryFindJar(
+      "connector/connect/client/jvm",
+      // SBT passes the client & test jars to the server process automatically.
+      // So we skip building or finding this jar for SBT.
+      "sbt-tests-do-not-need-this-jar",
+      "spark-connect-client-jvm",
+      test = true)
+      .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
+      .getOrElse(Seq.empty)
+
+    writerV2Configs ++ hiveTestConfigs ++ udfTestConfigs
+  }
+
   def start(): Unit = {
     assert(!stopped)
     sparkConnect
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
similarity index 91%
rename from 
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
rename to 
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index db47c3bf681..7cd251b245f 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.connect.common
 
 import scala.collection.JavaConverters._
 
@@ -23,11 +23,12 @@ import org.apache.spark.sql.Row
 
 /**
  * Util functions to help convert input functions between typed filter, map, 
flatMap,
- * mapPartitions etc. These functions cannot be defined inside the client 
Dataset class as it will
- * cause Dataset sync conflicts when used together with UDFs. Thus we define 
them outside, in the
- * client package.
+ * mapPartitions etc. This class is shared between the client and the server 
so that when the
+ * methods are used in client UDFs, the server will be able to find them when 
actually executing
+ * the UDFs.
  */
-private[sql] object UdfUtils {
+@SerialVersionUID(8464839273647598302L)
+private[sql] object UdfUtils extends Serializable {
 
   def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] => 
Iterator[U] = _.map(f(_))
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to