This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5e17b07b88e  [SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for Scala client UDF tests
5e17b07b88e is described below

commit 5e17b07b88ea32aa9e6e56df4a901032db6c7919
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Wed Apr 26 12:07:29 2023 -0400

    [SPARK-42953][CONNECT][FOLLOWUP] Fix maven test build for Scala client UDF tests

    ### What changes were proposed in this pull request?
    Moved `UdfUtils` to the common module. Fixed the client `UserDefinedFunctionE2ETestSuite` and `KeyValueGroupedDatasetE2ETestSuite` tests so that they can run under Maven. Verified that the code works with the following Maven commands:
    ```
    build/mvn clean install -pl connector/connect/server -am -DskipTests
    build/mvn clean install -pl assembly -am -DskipTests
    build/mvn clean install -pl connector/connect/client/jvm
    ```

    ### Why are the changes needed?
    Fix the failing Maven tests.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Tests. Fixed the tests from https://github.com/apache/spark/pull/40581 and https://github.com/apache/spark/pull/40729.

    Closes #40762 from zhenlineo/udf-mvn.

    Authored-by: Zhen Li <zhenli...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  2 +-
 .../sql/UserDefinedFunctionE2ETestSuite.scala      |  3 +-
 .../connect/client/util/IntegrationTestUtils.scala | 28 ++++++---
 .../connect/client/util/RemoteSparkSession.scala   | 73 ++++++++++++++++------
 .../spark/sql/connect/common}/UdfUtils.scala       | 11 ++--
 6 files changed, 84 insertions(+), 37 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 26dc0074bc2..3301b483b5e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -28,8 +28,8 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.connect.client.{SparkResult, UdfUtils}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
+import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils}
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.streaming.DataStreamWriter
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 8f9722789fc..2d712bc4c51 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.connect.client.UdfUtils
+import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.col
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index afe9ce5f751..b07d1459df5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -30,8 +30,7 @@ import org.apache.spark.sql.connect.client.util.RemoteSparkSession
 import org.apache.spark.sql.functions.{col, udf}
 
 /**
- * All tests in this class requires client UDF artifacts synced with the server. TODO: It means
- * these tests only works with SBT for now.
+ * All tests in this class require the client UDFs defined in this test class to be synced with the server.
  */
 class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
 
   test("Dataset typed filter") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 9196db175d2..7e34726b48e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -46,7 +46,7 @@ object IntegrationTestUtils {
     sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
   }
 
-  private[connect] lazy val debugConfig: Seq[String] = {
+  private[connect] def debugConfigs: Seq[String] = {
     val log4j2 = s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
     if (isDebug) {
       Seq(
@@ -90,7 +90,19 @@ object IntegrationTestUtils {
       sbtName: String,
       mvnName: String,
       test: Boolean = false): File = {
-    val targetDir = new File(new File(sparkHome, path), "target")
+    val jar = tryFindJar(path, sbtName, mvnName, test).getOrElse(
+      throw new RuntimeException(
+        s"Failed to find the jar inside folder: ${getTargetFilePath(path)}"))
+    debug("Using jar: " + jar.getCanonicalPath)
+    jar
+  }
+
+  private[sql] def tryFindJar(
+      path: String,
+      sbtName: String,
+      mvnName: String,
+      test: Boolean = false): Option[File] = {
+    val targetDir = getTargetFilePath(path).toFile
     assert(
       targetDir.exists(),
       s"Fail to locate the target folder: '${targetDir.getCanonicalPath}'. " +
@@ -98,7 +110,9 @@ object IntegrationTestUtils {
         "Make sure the spark project jars has been built (e.g. using build/sbt package)" +
         "and the env variable `SPARK_HOME` is set correctly.")
     val suffix = if (test) "-tests.jar" else ".jar"
-    val jars = recursiveListFiles(targetDir).filter { f =>
+    // It is possible there are more than one: one built by maven, and another by SBT.
+    // Return the first one found.
+    recursiveListFiles(targetDir).find { f =>
       // SBT jar
       (f.getParentFile.getName == scalaDir &&
         f.getName.startsWith(sbtName) &&
@@ -107,10 +121,10 @@ object IntegrationTestUtils {
         f.getName.endsWith(suffix)) ||
       // Maven Jar
       (f.getName.startsWith(mvnName) &&
         f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix"))
     }
-    // It is possible we found more than one: one built by maven, and another by SBT
-    assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}")
-    debug("Using jar: " + jars(0).getCanonicalPath)
-    jars(0) // return the first jar found
+  }
+
+  private def getTargetFilePath(path: String): java.nio.file.Path = {
+    Paths.get(sparkHome, path, "target").toAbsolutePath
   }
 
   private def recursiveListFiles(f: File): Array[File] = {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 2d8cc6d3298..235605e3121 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -61,31 +61,14 @@ object SparkConnectServerUtils {
       "connector/connect/server",
       "spark-connect-assembly",
       "spark-connect").getCanonicalPath
-    val driverClassPath = connectJar + ":" +
-      findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath
-    val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
-      "hive"
-    } else {
-      // scalastyle:off println
-      println(
-        "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
-          "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
-          "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
-          "2. Test with sbt: run test with `-Phive` profile")
-      // scalastyle:on println
-      "in-memory"
-    }
+
     val builder = Process(
       Seq(
         "bin/spark-submit",
         "--driver-class-path",
-        driverClassPath,
-        "--conf",
-        s"spark.connect.grpc.binding.port=$port",
-        "--conf",
-        "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
+        connectJar,
         "--conf",
-        s"spark.sql.catalogImplementation=$catalogImplementation") ++ debugConfig ++ Seq(
+        s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ debugConfigs ++ Seq(
         "--class",
         "org.apache.spark.sql.connect.SimpleSparkConnectService",
         connectJar),
@@ -102,6 +85,56 @@ object SparkConnectServerUtils {
     process
   }
 
+  /**
+   * As one shared spark will be started for all E2E tests, for tests that need some special
+   * configs, we add them here
+   */
+  private def testConfigs: Seq[String] = {
+    // Use InMemoryTableCatalog for V2 writer tests
+    val writerV2Configs = {
+      val catalystTestJar = findJar( // To find InMemoryTableCatalog for V2 writer tests
+        "sql/catalyst",
+        "spark-catalyst",
+        "spark-catalyst",
+        test = true).getCanonicalPath
+      Seq(
+        "--jars",
+        catalystTestJar,
+        "--conf",
+        "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
+    }
+
+    // Run tests using hive
+    val hiveTestConfigs = {
+      val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
+        "hive"
+      } else {
+        // scalastyle:off println
+        println(
+          "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
+            "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
+            "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
+            "2. Test with sbt: run test with `-Phive` profile")
+        // scalastyle:on println
+        "in-memory"
+      }
+      Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation")
+    }
+
+    // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
+    val udfTestConfigs = tryFindJar(
+      "connector/connect/client/jvm",
+      // SBT passes the client & test jars to the server process automatically.
+      // So we skip building or finding this jar for SBT.
+      "sbt-tests-do-not-need-this-jar",
+      "spark-connect-client-jvm",
+      test = true)
+      .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
+      .getOrElse(Seq.empty)
+
+    writerV2Configs ++ hiveTestConfigs ++ udfTestConfigs
+  }
+
   def start(): Unit = {
     assert(!stopped)
     sparkConnect
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
similarity index 91%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
rename to connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index db47c3bf681..7cd251b245f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.connect.common
 
 import scala.collection.JavaConverters._
 
@@ -23,11 +23,12 @@ import org.apache.spark.sql.Row
 
 /**
  * Util functions to help convert input functions between typed filter, map, flatMap,
- * mapPartitions etc. These functions cannot be defined inside the client Dataset class as it will
- * cause Dataset sync conflicts when used together with UDFs. Thus we define them outside, in the
- * client package.
+ * mapPartitions etc. This class is shared between the client and the server so that when the
+ * methods are used in client UDFs, the server will be able to find them when actually executing
+ * the UDFs.
  */
-private[sql] object UdfUtils {
+@SerialVersionUID(8464839273647598302L)
+private[sql] object UdfUtils extends Serializable {
 
   def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] => Iterator[U] = _.map(f(_))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
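Two patterns in this patch are worth spelling out. The sketches below are standalone Scala illustrations with made-up names and paths, not the actual Spark sources.

First, the jar-lookup refactor: `findJar`, which used to assert on an empty result, becomes a thin wrapper over an `Option`-returning `tryFindJar`. That lets `testConfigs` treat the client test jar as optional (needed under Maven, unnecessary under SBT) instead of failing hard. A minimal sketch of the pattern, assuming a flat `target` directory:

```
import java.io.File

object JarLookupSketch {
  // Optional lookup: None simply means "artifact not built in this environment".
  def tryFindJar(targetDir: File, prefix: String, suffix: String): Option[File] =
    Option(targetDir.listFiles())
      .getOrElse(Array.empty[File])
      .find(f => f.getName.startsWith(prefix) && f.getName.endsWith(suffix))

  // Strict lookup: fails loudly, like findJar after this change.
  def findJar(targetDir: File, prefix: String, suffix: String): File =
    tryFindJar(targetDir, prefix, suffix).getOrElse(
      throw new RuntimeException(s"Failed to find the jar inside folder: $targetDir"))

  def main(args: Array[String]): Unit = {
    // Optional artifact: skipped when absent instead of aborting the test setup.
    val extraJars = tryFindJar(new File("target"), "spark-connect-client-jvm", "-tests.jar")
      .map(jar => Seq("--jars", jar.getCanonicalPath))
      .getOrElse(Seq.empty)
    println(extraJars)
  }
}
```

Second, the `UdfUtils` move: the object now lives in `connect/common`, which both the client and the server load, so client-side lambdas that reference its adaptors can be deserialized on the server; the added `@SerialVersionUID` and `Serializable` pin a stable serialized form across the two processes. The adaptor idea itself, using the `mapFuncToMapPartitionsAdaptor` signature shown in the diff:

```
object UdfAdaptorSketch {
  // Lift a per-element function into a mapPartitions-style function, so both
  // typed APIs can go through the same Iterator => Iterator UDF plumbing.
  def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] => Iterator[U] =
    _.map(f(_))

  def main(args: Array[String]): Unit = {
    val adapted = mapFuncToMapPartitionsAdaptor[Int, Int](_ + 1)
    println(adapted(Iterator(1, 2, 3)).toList) // List(2, 3, 4)
  }
}
```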