This is an automated email from the ASF dual-hosted git repository.
sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b5ed5220267d [SPARK-57066][SECURITY] Use constant-time comparison for
authentication secrets
b5ed5220267d is described below
commit b5ed5220267d639f0fae73d1fb1b4de7e84adecc
Author: Kousuke Saruta <[email protected]>
AuthorDate: Wed May 27 15:10:56 2026 +0900
[SPARK-57066][SECURITY] Use constant-time comparison for authentication
secrets
### What changes were proposed in this pull request?
Use constant-time comparison wherever authentication secrets/tokens are
validated:
- `SocketAuthHelper.scala`: socket authentication between Spark processes
(`MessageDigest.isEqual`)
- `PreSharedKeyAuthenticationInterceptor.scala`: Spark Connect pre-shared
key authentication (`MessageDigest.isEqual`)
- `RBackendAuthHandler.scala`: R backend authentication
(`MessageDigest.isEqual`)
- `python/pyspark/accumulators.py`: accumulator server token authentication
(`hmac.compare_digest`)
All of these previously used standard string equality (`==` / `!=`), which is
vulnerable to timing attacks: an attacker can infer the correct secret one
character at a time by measuring differences in response time.
### Why are the changes needed?
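Standard string comparison short-circuits on the first mismatched
character, leaking information about how many leading characters are correct.
Because each character can then be guessed independently (at most C attempts
per position) instead of guessing the whole secret at once, this reduces the
brute-force complexity from O(C^N) to O(C*N), where C is the size of the
character set and N is the secret length.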
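`java.security.MessageDigest.isEqual()` (Scala/Java) and
`hmac.compare_digest()` (Python) take time that does not depend on the
contents of the inputs. At most, the timing may reveal the inputs' lengths,
never their values. This eliminates the per-character timing side channel.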
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
GA (GitHub Actions CI).
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude (via Kiro CLI, auto model selection)
Closes #56108 from sarutak/fix-timing-attack-auth.
Authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Kousuke Saruta <[email protected]>
---
.../main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala | 4 +++-
.../main/scala/org/apache/spark/security/SocketAuthHelper.scala | 3 ++-
python/pyspark/accumulators.py | 3 ++-
.../connect/service/PreSharedKeyAuthenticationInterceptor.scala | 7 ++++++-
4 files changed, 13 insertions(+), 4 deletions(-)
diff --git
a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
index 8cd95ee653eb..7c704c3d2b37 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.r
import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
+import java.security.MessageDigest
import io.netty.channel.{Channel, ChannelHandlerContext,
SimpleChannelInboundHandler}
@@ -34,7 +35,8 @@ private class RBackendAuthHandler(secret: String)
// The R code adds a null terminator to serialized strings, so ignore it
here.
val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
try {
- require(secret == clientSecret, "Auth secret mismatch.")
+ require(MessageDigest.isEqual(secret.getBytes(UTF_8),
clientSecret.getBytes(UTF_8)),
+ "Auth secret mismatch.")
ctx.pipeline().remove(this)
writeReply("ok", ctx.channel())
} catch {
diff --git
a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index ecebb97ecfc1..d2a81e56265c 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.channels.SocketChannel
import java.nio.charset.StandardCharsets.UTF_8
+import java.security.MessageDigest
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR,
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
@@ -65,7 +66,7 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
try {
s.setSoTimeout(10000)
val clientSecret = readUtf8(s)
- if (secret == clientSecret) {
+ if (MessageDigest.isEqual(secret.getBytes(UTF_8),
clientSecret.getBytes(UTF_8))) {
writeUtf8("ok", s)
shouldClose = false
} else {
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 0bce2934cb81..fcfa347092ee 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -17,6 +17,7 @@
import os
import sys
+import hmac
import select
import struct
import socketserver
@@ -310,7 +311,7 @@ class
UpdateRequestHandler(socketserver.StreamRequestHandler):
received_token: Union[bytes, str] =
self.rfile.read(len(auth_token))
if isinstance(received_token, bytes):
received_token = received_token.decode("utf-8")
- if received_token == auth_token:
+ if hmac.compare_digest(received_token, auth_token):
accum_updates()
# we've authenticated, we can break out of the first loop now
return True
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
index 5d7cc65358eb..b997f9d0d910 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/PreSharedKeyAuthenticationInterceptor.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.connect.service
+import java.nio.charset.StandardCharsets.UTF_8
+import java.security.MessageDigest
+
import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor,
Status}
class PreSharedKeyAuthenticationInterceptor(token: String) extends
ServerInterceptor {
@@ -36,7 +39,9 @@ class PreSharedKeyAuthenticationInterceptor(token: String)
extends ServerInterce
val status = Status.UNAUTHENTICATED.withDescription("No authentication
token provided")
call.close(status, new Metadata())
new ServerCall.Listener[ReqT]() {}
- } else if (authHeaderValue != expectedValue) {
+ } else if (!MessageDigest.isEqual(
+ authHeaderValue.getBytes(UTF_8),
+ expectedValue.getBytes(UTF_8))) {
val status = Status.UNAUTHENTICATED.withDescription("Invalid
authentication token")
call.close(status, new Metadata())
new ServerCall.Listener[ReqT]() {}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]