This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c32aee117b60 [SPARK-55243][CONNECT] Allow setting binary headers via the -bin suffix in the Scala Connect client
c32aee117b60 is described below
commit c32aee117b60370e69ce5271c4efbe64d1982d3a
Author: Robert Dillitz <[email protected]>
AuthorDate: Thu Jan 29 15:56:40 2026 +0100
[SPARK-55243][CONNECT] Allow setting binary headers via the -bin suffix in the Scala Connect client
### What changes were proposed in this pull request?
Automatically use the `Metadata.BINARY_BYTE_MARSHALLER` for header keys with the `-bin` suffix, assuming that header value strings set through the Scala Spark Connect client builder are base64-encoded.
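For illustration, here is a minimal sketch of what this enables through the client builder. The header name, payload, and connect string are made-up examples, not part of this patch:
```scala
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64

import org.apache.spark.sql.connect.client.SparkConnectClient

// Hypothetical raw bytes to send as a binary gRPC header.
val rawBytes: Array[Byte] = "trace-payload".getBytes(UTF_8)

// Builder options are plain strings, so the bytes are passed base64-encoded.
val encoded: String = Base64.getEncoder.encodeToString(rawBytes)

// A key ending in "-bin" is now attached with Metadata.BINARY_BYTE_MARSHALLER;
// the client interceptor base64-decodes the value back into bytes before sending.
val client = SparkConnectClient
  .builder()
  .connectionString("sc://localhost:15002")
  .option("trace-context-bin", encoded)
  .build()
```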
### Why are the changes needed?
The Scala Spark Connect client currently only allows setting `Metadata.ASCII_STRING_MARSHALLER` headers and fails if one tries to put a (binary) header with a `-bin` key suffix:
```
[info] java.lang.IllegalArgumentException: ASCII header is named test-bin. Only binary headers may end with -bin
[info] at com.google.common.base.Preconditions.checkArgument(Preconditions.java:445)
[info] at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:972)
[info] at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:966)
[info] at io.grpc.Metadata$Key.of(Metadata.java:708)
[info] at io.grpc.Metadata$Key.of(Metadata.java:704)
[info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1(SparkConnectClient.scala:1112)
[info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1$adapted(SparkConnectClient.scala:1106)
[info] at scala.collection.immutable.Map$Map1.foreach(Map.scala:278)
[info] at org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.start(SparkConnectClient.scala:1106)
[info] at io.grpc.stub.ClientCalls.startCall(ClientCalls.java:435)
```
### Does this PR introduce _any_ user-facing change?
Current behaviour: Fails with an `IllegalArgumentException` for any header key-value pair whose key has the `-bin` suffix.
New behaviour: Adds a `Metadata.BINARY_BYTE_MARSHALLER` header if the key
has a `-bin` suffix and the value string is base64-encoded.
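As a rough sketch of the receiving side (mirroring the server interceptor used in the test below; the interceptor class and header name are invented for the example), a gRPC server can now read such a header back as raw bytes:
```scala
import java.nio.charset.StandardCharsets.UTF_8

import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor}

// Hypothetical server interceptor that prints the decoded bytes of a "-bin" header.
class BinaryHeaderLogger extends ServerInterceptor {
  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      headers: Metadata,
      next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
    val key = Metadata.Key.of("trace-context-bin", Metadata.BINARY_BYTE_MARSHALLER)
    // The client already decoded the base64 string, so these are the raw bytes.
    Option(headers.get(key)).foreach(bytes => println(new String(bytes, UTF_8)))
    next.startCall(call, headers)
  }
}
```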
### How was this patch tested?
Added a test to `SparkConnectClientSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54016 from dillitz/fix-bin-header.
Authored-by: Robert Dillitz <[email protected]>
Signed-off-by: Herman van Hövell <[email protected]>
---
.../connect/client/SparkConnectClientSuite.scala | 78 ++++++++++++++++++++--
.../sql/connect/client/SparkConnectClient.scala | 13 +++-
2 files changed, 84 insertions(+), 7 deletions(-)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 98fc5dd78ee4..fc4a590716b4 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -16,13 +16,14 @@
*/
package org.apache.spark.sql.connect.client
-import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.{Base64, UUID}
import java.util.concurrent.TimeUnit
import scala.collection.mutable
import scala.jdk.CollectionConverters._
-import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Server, Status, StatusRuntimeException}
+import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, Status, StatusRuntimeException}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
import org.scalatest.concurrent.Eventually
@@ -42,12 +43,13 @@ class SparkConnectClientSuite extends ConnectFunSuite {
private var service: DummySparkConnectService = _
private var server: Server = _
- private def startDummyServer(port: Int): Unit = {
+ private def startDummyServer(port: Int, interceptors: Seq[ServerInterceptor] = Seq()): Unit = {
service = new DummySparkConnectService
- server = NettyServerBuilder
+ val serverBuilder = NettyServerBuilder
.forPort(port)
.addService(service)
- .build()
+ interceptors.foreach(serverBuilder.intercept)
+ server = serverBuilder.build()
server.start()
}
@@ -622,6 +624,72 @@ class SparkConnectClientSuite extends ConnectFunSuite {
// The client should try to fetch the config only once.
assert(service.getAndClearLatestConfigRequests().size == 1)
}
+
+ test("SPARK-55243: Binary headers use the correct marshaller") {
+ class HeadersInterceptor extends ServerInterceptor {
+ var headers: Option[Metadata] = None
+
+ override def interceptCall[ReqT, RespT](
+ call: ServerCall[ReqT, RespT],
+ headers: Metadata,
+ next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+ this.headers = Some(headers)
+ next.startCall(call, headers)
+ }
+ }
+
+ def buildClientWithHeader(key: String, value: String): SparkConnectClient = {
+ SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .option(key, value)
+ .build()
+ }
+
+ val headerInterceptor = new HeadersInterceptor()
+ startDummyServer(0, Seq(headerInterceptor))
+
+ val keyName = "test-bin"
+ val key = Metadata.Key.of(keyName, Metadata.BINARY_BYTE_MARSHALLER)
+ val binaryData = "test-binary-data"
+ val base64EncodedValue = Base64.getEncoder.encodeToString(binaryData.getBytes(UTF_8))
+
+ val plan = buildPlan("select * from range(10)")
+
+ // Successfully set and use base64-encoded -bin key.
+ client = buildClientWithHeader(keyName, base64EncodedValue)
+ client.execute(plan)
+
+ Eventually.eventually(timeout(5.seconds)) {
+ assert(headerInterceptor.headers.exists(_.containsKey(key)))
+ val bytes = headerInterceptor.headers.get.get(key)
+ assert(new String(bytes, UTF_8) == binaryData)
+ }
+
+ // Non base64-encoded -bin header throws IllegalArgumentException.
+ client = buildClientWithHeader(keyName, binaryData)
+
+ assertThrows[IllegalArgumentException] {
+ client.execute(plan)
+ }
+
+ // Non -bin headers keep using the ASCII marshaller.
+ val asciiKeyName = "test"
+ val asciiKey = Metadata.Key.of(asciiKeyName, Metadata.ASCII_STRING_MARSHALLER)
+
+ headerInterceptor.headers = None // Reset captured headers.
+
+ client = buildClientWithHeader(asciiKeyName, base64EncodedValue)
+ client.execute(plan)
+
+ Eventually.eventually(timeout(5.seconds)) {
+ assert(headerInterceptor.headers.exists(_.containsKey(asciiKey)))
+ val value = headerInterceptor.headers.get.get(asciiKey)
+ assert(value == base64EncodedValue)
+ // No BINARY_BYTE_MARSHALLER header.
+ assert(!headerInterceptor.headers.exists(_.containsKey(key)))
+ }
+ }
}
class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 5d36fc45f948..cac43c2cb67c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -18,7 +18,8 @@
package org.apache.spark.sql.connect.client
import java.net.URI
-import java.util.{Locale, UUID}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.{Base64, Locale, UUID}
import java.util.concurrent.Executor
import scala.collection.mutable
@@ -1093,6 +1094,8 @@ object SparkConnectClient {
*/
private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, String])
extends ClientInterceptor {
+ metadata.foreach { case (key, value) => assert(key != null && value != null) }
+
override def interceptCall[ReqT, RespT](
method: MethodDescriptor[ReqT, RespT],
callOptions: CallOptions,
@@ -1103,7 +1106,13 @@ object SparkConnectClient {
responseListener: ClientCall.Listener[RespT],
headers: Metadata): Unit = {
metadata.foreach { case (key, value) =>
- headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value)
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ // Expects a base64-encoded value string.
+ val valueByteArray = Base64.getDecoder.decode(value.getBytes(UTF_8))
+ headers.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), valueByteArray)
+ } else {
+ headers.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value)
+ }
}
super.start(responseListener, headers)
}