This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 83fe9b16ab5a [SPARK-47694][CONNECT] Make max message size configurable on the client side 83fe9b16ab5a is described below commit 83fe9b16ab5a2eec5f844d1e30488fe48223e29b Author: Robert Dillitz <robert.dill...@databricks.com> AuthorDate: Mon Apr 15 14:52:52 2024 -0400 [SPARK-47694][CONNECT] Make max message size configurable on the client side ### What changes were proposed in this pull request? Follow up to https://github.com/apache/spark/pull/40447. Allows configuring the currently hardcoded max message size of 128MB on the client side for both the Scala and Python clients. Adds the option to the Scala client and improves the way we handle `channelOptions` in Python's `ChannelBuilder`. ### Why are the changes needed? Usability - I am aware of two different cases where these limits are hit: 1. The user is trying to create a large dataframe from local data. We either hit the `grpc.max_send_message_length` in the Python client ([currently hardcoded](https://github.com/apache/spark/pull/40447/files)) or the `maxInboundMessageSize` on the cluster side ([now configurable](https://github.com/apache/spark/pull/40447/files)). 2. The result from the cluster has a single row that is larger than 128MB, causing an `ExecutePlanResponse` that is larger than the client's `grpc.max_receive_message_length` (Python) or `channel.maxInboundMessageSize` (Scala) ([both hardcoded](https://github.com/apache/spark/pull/40447/files)). This gives the option to increase these limits on the client side. ### Does this PR introduce _any_ user-facing change? Scala: Adds an option to set `grpcMaxMessageSize` to `SparkConnectClient.Builder` Python: No. ### How was this patch tested? Tests added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45842 from dillitz/SPARK-47694. 
Authored-by: Robert Dillitz <robert.dill...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../connect/client/SparkConnectClientSuite.scala | 9 +++++++- .../sql/connect/client/SparkConnectClient.scala | 14 +++++++++++-- .../connect/client/SparkConnectClientParser.scala | 24 +++++++++++++--------- python/pyspark/sql/connect/client/core.py | 20 +++++++++++++----- .../sql/tests/connect/test_connect_session.py | 20 ++++++++++++++++++ 5 files changed, 69 insertions(+), 18 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 5a43cf014bdc..55f962b2a52c 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -310,7 +310,14 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(client.userAgent.contains("scala/")) assert(client.userAgent.contains("jvm/")) assert(client.userAgent.contains("os/")) - })) + }), + TestPackURI( + "sc://SPARK-47694:123/;grpc_max_message_size=1860", + isCorrect = true, + client => { + assert(client.configuration.grpcMaxMessageSize == 1860) + }), + TestPackURI("sc://SPARK-47694:123/;grpc_max_message_size=abc", isCorrect = false)) private def checkTestPack(testPack: TestPackURI): Unit = { val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build() diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 746aaca6f559..d9d51c15a880 100644 --- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -510,6 +510,7 @@ object SparkConnectClient { val PARAM_TOKEN = "token" val PARAM_USER_AGENT = "user_agent" val PARAM_SESSION_ID = "session_id" + val PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" } private def verifyURI(uri: URI): Unit = { @@ -558,6 +559,13 @@ object SparkConnectClient { def userAgent: String = _configuration.userAgent + def grpcMaxMessageSize(messageSize: Int): Builder = { + _configuration = _configuration.copy(grpcMaxMessageSize = messageSize) + this + } + + def grpcMaxMessageSize: Int = _configuration.grpcMaxMessageSize + def option(key: String, value: String): Builder = { _configuration = _configuration.copy(metadata = _configuration.metadata + ((key, value))) this @@ -584,6 +592,7 @@ object SparkConnectClient { case URIParams.PARAM_USE_SSL => if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl() case URIParams.PARAM_SESSION_ID => sessionId(value) + case URIParams.PARAM_GRPC_MAX_MESSAGE_SIZE => grpcMaxMessageSize(value.toInt) case _ => option(key, value) } } @@ -693,7 +702,8 @@ object SparkConnectClient { retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies(), useReattachableExecute: Boolean = true, interceptors: List[ClientInterceptor] = List.empty, - sessionId: Option[String] = None) { + sessionId: Option[String] = None, + grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) { def userContext: proto.UserContext = { val builder = proto.UserContext.newBuilder() @@ -731,7 +741,7 @@ object SparkConnectClient { interceptors.foreach(channelBuilder.intercept(_)) - channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) + channelBuilder.maxInboundMessageSize(grpcMaxMessageSize) channelBuilder.build() } diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala index cfb5823ee43e..7e137a6a3e05 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala @@ -32,16 +32,17 @@ private[sql] object SparkConnectClientParser { def usage(): String = s""" |Options: - | --remote REMOTE URI of the Spark Connect Server to connect to. - | --host HOST Host where the Spark Connect Server is running. - | --port PORT Port where the Spark Connect Server is running. - | --use_ssl Connect to the server using SSL. - | --token TOKEN Token to use for authentication. - | --user_id USER_ID Id of the user connecting. - | --user_name USER_NAME Name of the user connecting. - | --user_agent USER_AGENT The User-Agent Client information (only intended for logging purposes by the server). - | --session_id SESSION_ID Session Id of the user connecting. - | --option KEY=VALUE Key-value pair that is used to further configure the session. + | --remote REMOTE URI of the Spark Connect Server to connect to. + | --host HOST Host where the Spark Connect Server is running. + | --port PORT Port where the Spark Connect Server is running. + | --use_ssl Connect to the server using SSL. + | --token TOKEN Token to use for authentication. + | --user_id USER_ID Id of the user connecting. + | --user_name USER_NAME Name of the user connecting. + | --user_agent USER_AGENT The User-Agent Client information (only intended for logging purposes by the server). + | --session_id SESSION_ID Session Id of the user connecting. + | --grpc_max_message_size SIZE Maximum message size allowed for gRPC messages in bytes. + | --option KEY=VALUE Key-value pair that is used to further configure the session. 
""".stripMargin // scalastyle:on line.size.limit @@ -88,6 +89,9 @@ private[sql] object SparkConnectClientParser { s"--option should contain key=value, found ${tail.head} instead") } parse(tail.tail, builder.option(key, value)) + case "--grpc_max_message_size" :: tail => + val (value, remainder) = extract("--grpc_max_message_size", tail) + parse(remainder, builder.grpcMaxMessageSize(value.toInt)) case unsupported :: _ => throw new IllegalArgumentException(s"$unsupported is an unsupported argument.") } diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 532d490d925e..667b93596c5f 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -115,11 +115,12 @@ class ChannelBuilder: PARAM_USER_ID = "user_id" PARAM_USER_AGENT = "user_agent" PARAM_SESSION_ID = "session_id" - MAX_MESSAGE_LENGTH = 128 * 1024 * 1024 + + GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024 GRPC_DEFAULT_OPTIONS = [ - ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), - ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT), + ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT), ] def __init__( @@ -129,10 +130,11 @@ class ChannelBuilder: ): self._interceptors: List[grpc.UnaryStreamClientInterceptor] = [] self._params: Dict[str, str] = params or dict() - self._channel_options: List[Tuple[str, Any]] = ChannelBuilder.GRPC_DEFAULT_OPTIONS + self._channel_options: List[Tuple[str, Any]] = ChannelBuilder.GRPC_DEFAULT_OPTIONS.copy() if channelOptions is not None: - self._channel_options = self._channel_options + channelOptions + for key, value in channelOptions: + self.setChannelOption(key, value) def get(self, key: str) -> Any: """ @@ -152,6 +154,14 @@ class ChannelBuilder: def set(self, key: str, value: Any) -> None: self._params[key] = value + def setChannelOption(self, key: str, value: Any) -> None: + # overwrite 
option if it exists already else append it + for i, option in enumerate(self._channel_options): + if option[0] == key: + self._channel_options[i] = (key, value) + return + self._channel_options.append((key, value)) + def add_interceptor(self, interceptor: grpc.UnaryStreamClientInterceptor) -> None: self._interceptors.append(interceptor) diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py index 4d6127b5be8b..1caf3525cfbb 100644 --- a/python/pyspark/sql/tests/connect/test_connect_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -498,6 +498,26 @@ class ChannelBuilderTests(unittest.TestCase): chan = DefaultChannelBuilder("sc://host/") self.assertIsNone(chan.session_id) + def test_channel_options(self): + # SPARK-47694 + chan = DefaultChannelBuilder( + "sc://host", [("grpc.max_send_message_length", 1860), ("test", "robert")] + ) + options = chan._channel_options + self.assertEqual( + [k for k, _ in options].count("grpc.max_send_message_length"), + 1, + "only one occurrence for defaults", + ) + self.assertEqual( + next(v for k, v in options if k == "grpc.max_send_message_length"), + 1860, + "overwrites defaults", + ) + self.assertEqual( + next(v for k, v in options if k == "test"), "robert", "new values are picked up" + ) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_session import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org