This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 86c6548697ed [SPARK-55015][SS][SQL] Fix decodeRemainingKey numFields
calculation in PrefixKeyScanStateEncoder
86c6548697ed is described below
commit 86c6548697ed1a93bef0139b17d2ffe4edcbfdc9
Author: ericm-db <[email protected]>
AuthorDate: Fri Jan 16 13:49:01 2026 +0900
[SPARK-55015][SS][SQL] Fix decodeRemainingKey numFields calculation in
PrefixKeyScanStateEncoder
### What changes were proposed in this pull request?
Fix a bug in UnsafeRowDataEncoder.decodeRemainingKey (in RocksDBStateEncoder.scala)
where it incorrectly used numColsPrefixKey instead of
(keySchema.length - numColsPrefixKey) as the number of fields for the
PrefixKeyScanStateEncoderSpec case. The remaining key should contain the
non-prefix columns, not the prefix columns.
### Why are the changes needed?
Without this fix, the remaining key is decoded with the wrong number of fields
on the decode path, which results in an exception.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added tests to verify that decodeRemainingKey decodes the remaining key with
the correct number of fields.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53775 from ericm-db/fix-decode-remaining-key.
Authored-by: ericm-db <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
(cherry picked from commit af058d59f3008163fd8c72ce80fa5ca24706165a)
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../streaming/state/RocksDBStateEncoder.scala | 2 +-
.../execution/streaming/state/RocksDBSuite.scala | 124 +++++++++++++++++++++
2 files changed, 125 insertions(+), 1 deletion(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index f49c79f96b9c..f6fe4dbea576 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -628,7 +628,7 @@ class UnsafeRowDataEncoder(
override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
- decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+ decodeToUnsafeRow(bytes, numFields = keySchema.length -
numColsPrefixKey)
case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) =>
decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length)
case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 6c22436c29a0..21ce069f3b3b 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -584,6 +584,130 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
assert(decodedValue.getBoolean(2) === true)
}
}
+
+ test("verify PrefixKeyScanStateEncoder full encode/decode cycle with
multi-key session window") {
+ // Simulate session window state with multiple grouping keys
+ // Key schema: [userId, deviceId, sessionStartTime] - mimics session
window with 2 grouping keys
+ val keySchema = StructType(Seq(
+ StructField("userId", IntegerType),
+ StructField("deviceId", StringType),
+ StructField("sessionStartTime", LongType)
+ ))
+ val valueSchema = StructType(Seq(
+ StructField("count", LongType)
+ ))
+
+ // Session window uses first N columns as prefix (the grouping keys)
+ val numColsPrefixKey = 2
+ val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema,
numColsPrefixKey)
+ val dataEncoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+ val keyEncoder = new PrefixKeyScanStateEncoder(
+ dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies = false)
+
+ // Create a full key row
+ val keyProj = UnsafeProjection.create(keySchema)
+ val fullKey = keyProj.apply(InternalRow(123,
UTF8String.fromString("device1"), 1000000L))
+
+ // Encode the full key (this is what happens when putting to state store)
+ val encodedKey = keyEncoder.encodeKey(fullKey)
+
+ // Decode the key (this is what happens during prefix scan)
+ val decodedKey = keyEncoder.decodeKey(encodedKey)
+
+ // Verify the decoded key matches the original
+ assert(decodedKey.numFields === 3,
+ s"Expected 3 fields in decoded key, but got ${decodedKey.numFields}")
+ assert(decodedKey.getInt(0) === 123, "userId not preserved")
+ assert(decodedKey.getString(1) === "device1", "deviceId not preserved")
+ assert(decodedKey.getLong(2) === 1000000L, "sessionStartTime not
preserved")
+ }
+
+ test("verify decodeRemainingKey correctly decodes with fix") {
+ // This test verifies the fix prevents garbage data reads
+ val keySchema = StructType(Seq(
+ StructField("k1", IntegerType),
+ StructField("k2", StringType),
+ StructField("k3", LongType)
+ ))
+ val valueSchema = StructType(Seq(
+ StructField("v1", IntegerType)
+ ))
+
+ val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema,
numColsPrefixKey = 2)
+ val encoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+
+ // Create and encode a remaining key with just the last column (k3)
+ val remainingKeySchema = StructType(Seq(StructField("k3", LongType)))
+ val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+ val remainingKeyRow = remainingKeyProj.apply(InternalRow(999999L))
+ val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+ // Decode the remaining key
+ val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+ // With the FIX: numFields should be keySchema.length - numColsPrefixKey =
3 - 2 = 1
+ assert(decodedRemainingKey.numFields === 1,
+ s"Expected 1 field but got ${decodedRemainingKey.numFields}")
+
+ // Field 0 should read correctly
+ assert(decodedRemainingKey.getLong(0) === 999999L,
+ "Field 0 value incorrect")
+
+ // Trying to read field 1 should throw exception (doesn't exist)
+ intercept[AssertionError] {
+ decodedRemainingKey.getLong(1)
+ }
+ }
+
+ test("verify AvroStateEncoder decodeRemainingKey with
PrefixKeyScanStateEncoder") {
+ // This test verifies that AvroStateEncoder correctly decodes remaining
keys
+ // AvroStateEncoder uses remainingKeySchema =
keySchema.drop(numColsPrefixKey)
+ // which is the correct calculation (unlike the bug in
UnsafeRowDataEncoder)
+ val keySchema = StructType(Seq(
+ StructField("k1", IntegerType),
+ StructField("k2", StringType),
+ StructField("k3", LongType)
+ ))
+ val valueSchema = StructType(Seq(
+ StructField("v1", IntegerType)
+ ))
+
+ // Create test state schema provider
+ val testProvider = new TestStateSchemaProvider()
+ testProvider.captureSchema(
+ StateStore.DEFAULT_COL_FAMILY_NAME,
+ keySchema,
+ valueSchema,
+ keySchemaId = 0,
+ valueSchemaId = 0
+ )
+
+ val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema,
numColsPrefixKey = 2)
+ val encoder = new AvroStateEncoder(prefixKeySpec, valueSchema,
Some(testProvider),
+ StateStore.DEFAULT_COL_FAMILY_NAME)
+
+ // Create and encode a remaining key with just the last column (k3)
+ val remainingKeySchema = StructType(Seq(StructField("k3", LongType)))
+ val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+ val remainingKeyRow = remainingKeyProj.apply(InternalRow(999999L))
+ val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+ // Decode the remaining key
+ val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+ // Should have 1 field (keySchema.length - numColsPrefixKey = 3 - 2 = 1)
+ assert(decodedRemainingKey.numFields === 1,
+ s"Expected 1 field but got ${decodedRemainingKey.numFields}")
+
+ // Field 0 should read correctly
+ assert(decodedRemainingKey.getLong(0) === 999999L,
+ "Field 0 value incorrect")
+
+ // Trying to read field 1 should throw exception (doesn't exist)
+ intercept[AssertionError] {
+ decodedRemainingKey.getLong(1)
+ }
+ }
}
@SlowSQLTest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]