This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c32aee117b60 [SPARK-55243][CONNECT] Allow setting binary headers via 
the -bin suffix in the Scala Connect client
c32aee117b60 is described below

commit c32aee117b60370e69ce5271c4efbe64d1982d3a
Author: Robert Dillitz <[email protected]>
AuthorDate: Thu Jan 29 15:56:40 2026 +0100

    [SPARK-55243][CONNECT] Allow setting binary headers via the -bin suffix in 
the Scala Connect client
    
    ### What changes were proposed in this pull request?
    Automatically use the `Metadata.BINARY_BYTE_MARSHALLER` for `-bin` suffixed 
header keys, assuming base64-encoded header value strings set through the Scala 
Spark Connect client builder.
    
    ### Why are the changes needed?
    The Scala Spark Connect client currently only allows setting 
`Metadata.ASCII_STRING_MARSHALLER` headers and fails if one tries to put a 
(binary) header with `-bin` key suffix:
    ```
    [info]   java.lang.IllegalArgumentException: ASCII header is named 
test-bin.  Only binary headers may end with -bin
    [info]   at 
com.google.common.base.Preconditions.checkArgument(Preconditions.java:445)
    [info]   at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:972)
    [info]   at io.grpc.Metadata$AsciiKey.<init>(Metadata.java:966)
    [info]   at io.grpc.Metadata$Key.of(Metadata.java:708)
    [info]   at io.grpc.Metadata$Key.of(Metadata.java:704)
    [info]   at 
org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1(SparkConnectClient.scala:1112)
    [info]   at 
org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.$anonfun$start$1$adapted(SparkConnectClient.scala:1106)
    [info]   at scala.collection.immutable.Map$Map1.foreach(Map.scala:278)
    [info]   at 
org.apache.spark.sql.connect.client.SparkConnectClient$MetadataHeaderClientInterceptor$$anon$2.start(SparkConnectClient.scala:1106)
    [info]   at io.grpc.stub.ClientCalls.startCall(ClientCalls.java:435)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Current behaviour: Fails with an `IllegalArgumentException` for any header 
key-value pair whose key has the `-bin` suffix.
    
    New behaviour: Adds a `Metadata.BINARY_BYTE_MARSHALLER` header if the key 
has a `-bin` suffix and the value string is base64-encoded.
    
    ### How was this patch tested?
    Added a test to `SparkConnectClientSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #54016 from dillitz/fix-bin-header.
    
    Authored-by: Robert Dillitz <[email protected]>
    Signed-off-by: Herman van Hövell <[email protected]>
---
 .../connect/client/SparkConnectClientSuite.scala   | 78 ++++++++++++++++++++--
 .../sql/connect/client/SparkConnectClient.scala    | 13 +++-
 2 files changed, 84 insertions(+), 7 deletions(-)

diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 98fc5dd78ee4..fc4a590716b4 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -16,13 +16,14 @@
  */
 package org.apache.spark.sql.connect.client
 
-import java.util.UUID
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.{Base64, UUID}
 import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
-import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, 
MethodDescriptor, Server, Status, StatusRuntimeException}
+import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, 
MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, 
Status, StatusRuntimeException}
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
 import org.scalatest.concurrent.Eventually
@@ -42,12 +43,13 @@ class SparkConnectClientSuite extends ConnectFunSuite {
   private var service: DummySparkConnectService = _
   private var server: Server = _
 
-  private def startDummyServer(port: Int): Unit = {
+  private def startDummyServer(port: Int, interceptors: Seq[ServerInterceptor] 
= Seq()): Unit = {
     service = new DummySparkConnectService
-    server = NettyServerBuilder
+    val serverBuilder = NettyServerBuilder
       .forPort(port)
       .addService(service)
-      .build()
+    interceptors.foreach(serverBuilder.intercept)
+    server = serverBuilder.build()
     server.start()
   }
 
@@ -622,6 +624,72 @@ class SparkConnectClientSuite extends ConnectFunSuite {
     // The client should try to fetch the config only once.
     assert(service.getAndClearLatestConfigRequests().size == 1)
   }
+
+  test("SPARK-55243: Binary headers use the correct marshaller") {
+    class HeadersInterceptor extends ServerInterceptor {
+      var headers: Option[Metadata] = None
+
+      override def interceptCall[ReqT, RespT](
+          call: ServerCall[ReqT, RespT],
+          headers: Metadata,
+          next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = {
+        this.headers = Some(headers)
+        next.startCall(call, headers)
+      }
+    }
+
+    def buildClientWithHeader(key: String, value: String): SparkConnectClient 
= {
+      SparkConnectClient
+        .builder()
+        .connectionString(s"sc://localhost:${server.getPort}")
+        .option(key, value)
+        .build()
+    }
+
+    val headerInterceptor = new HeadersInterceptor()
+    startDummyServer(0, Seq(headerInterceptor))
+
+    val keyName = "test-bin"
+    val key = Metadata.Key.of(keyName, Metadata.BINARY_BYTE_MARSHALLER)
+    val binaryData = "test-binary-data"
+    val base64EncodedValue = 
Base64.getEncoder.encodeToString(binaryData.getBytes(UTF_8))
+
+    val plan = buildPlan("select * from range(10)")
+
+    // Successfully set and use base64-encoded -bin key.
+    client = buildClientWithHeader(keyName, base64EncodedValue)
+    client.execute(plan)
+
+    Eventually.eventually(timeout(5.seconds)) {
+      assert(headerInterceptor.headers.exists(_.containsKey(key)))
+      val bytes = headerInterceptor.headers.get.get(key)
+      assert(new String(bytes, UTF_8) == binaryData)
+    }
+
+    // Non base64-encoded -bin header throws IllegalArgumentException.
+    client = buildClientWithHeader(keyName, binaryData)
+
+    assertThrows[IllegalArgumentException] {
+      client.execute(plan)
+    }
+
+    // Non -bin headers keep using the ASCII marshaller.
+    val asciiKeyName = "test"
+    val asciiKey = Metadata.Key.of(asciiKeyName, 
Metadata.ASCII_STRING_MARSHALLER)
+
+    headerInterceptor.headers = None // Reset captured headers.
+
+    client = buildClientWithHeader(asciiKeyName, base64EncodedValue)
+    client.execute(plan)
+
+    Eventually.eventually(timeout(5.seconds)) {
+      assert(headerInterceptor.headers.exists(_.containsKey(asciiKey)))
+      val value = headerInterceptor.headers.get.get(asciiKey)
+      assert(value == base64EncodedValue)
+      // No BINARY_BYTE_MARSHALLER header.
+      assert(!headerInterceptor.headers.exists(_.containsKey(key)))
+    }
+  }
 }
 
 class DummySparkConnectService() extends 
SparkConnectServiceGrpc.SparkConnectServiceImplBase {
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 5d36fc45f948..cac43c2cb67c 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.connect.client
 
 import java.net.URI
-import java.util.{Locale, UUID}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.{Base64, Locale, UUID}
 import java.util.concurrent.Executor
 
 import scala.collection.mutable
@@ -1093,6 +1094,8 @@ object SparkConnectClient {
    */
   private[client] class MetadataHeaderClientInterceptor(metadata: Map[String, 
String])
       extends ClientInterceptor {
+    metadata.foreach { case (key, value) => assert(key != null && value != 
null) }
+
     override def interceptCall[ReqT, RespT](
         method: MethodDescriptor[ReqT, RespT],
         callOptions: CallOptions,
@@ -1103,7 +1106,13 @@ object SparkConnectClient {
             responseListener: ClientCall.Listener[RespT],
             headers: Metadata): Unit = {
           metadata.foreach { case (key, value) =>
-            headers.put(Metadata.Key.of(key, 
Metadata.ASCII_STRING_MARSHALLER), value)
+            if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+              // Expects a base64-encoded value string.
+              val valueByteArray = 
Base64.getDecoder.decode(value.getBytes(UTF_8))
+              headers.put(Metadata.Key.of(key, 
Metadata.BINARY_BYTE_MARSHALLER), valueByteArray)
+            } else {
+              headers.put(Metadata.Key.of(key, 
Metadata.ASCII_STRING_MARSHALLER), value)
+            }
           }
           super.start(responseListener, headers)
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to