This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 33ecfab405a5 [SPARK-55015][SS][SQL] Fix decodeRemainingKey numFields calculation in PrefixKeyScanStateEncoder
33ecfab405a5 is described below

commit 33ecfab405a59b915219872b81d051c974db7b43
Author: ericm-db <[email protected]>
AuthorDate: Fri Jan 16 13:49:01 2026 +0900

    [SPARK-55015][SS][SQL] Fix decodeRemainingKey numFields calculation in PrefixKeyScanStateEncoder
    
    ### What changes were proposed in this pull request?
    
    Fix a bug in UnsafeRowDataEncoder.decodeRemainingKey (RocksDBStateEncoder.scala), which incorrectly used numColsPrefixKey instead of (keySchema.length - numColsPrefixKey) as the number of fields for the PrefixKeyScanStateEncoderSpec case. The remaining key contains the non-prefix columns, not the prefix columns.
    
    ### Why are the changes needed?
    
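    On the decode path, the remaining-key row was built with numColsPrefixKey fields instead of the number of non-prefix columns, so decoding would hit an exception.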
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added tests to verify decodeRemainingKey correctly decodes with the proper number of fields.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53775 from ericm-db/fix-decode-remaining-key.
    
    Authored-by: ericm-db <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
    (cherry picked from commit af058d59f3008163fd8c72ce80fa5ca24706165a)
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../streaming/state/RocksDBStateEncoder.scala      |   2 +-
 .../execution/streaming/state/RocksDBSuite.scala   | 124 +++++++++++++++++++++
 2 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index cf5f8ba5f2eb..fb5e623bdfec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -627,7 +627,7 @@ class UnsafeRowDataEncoder(
   override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
     keyStateEncoderSpec match {
       case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
-        decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+        decodeToUnsafeRow(bytes, numFields = keySchema.length - numColsPrefixKey)
       case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) =>
         decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length)
       case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 5d1ed9b8622a..10f14d5655f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -579,6 +579,130 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
       assert(decodedValue.getBoolean(2) === true)
     }
   }
+
+  test("verify PrefixKeyScanStateEncoder full encode/decode cycle with multi-key session window") {
+    // Simulate session window state with multiple grouping keys
+    // Key schema: [userId, deviceId, sessionStartTime] - mimics session window with 2 grouping keys
+    val keySchema = StructType(Seq(
+      StructField("userId", IntegerType),
+      StructField("deviceId", StringType),
+      StructField("sessionStartTime", LongType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("count", LongType)
+    ))
+
+    // Session window uses first N columns as prefix (the grouping keys)
+    val numColsPrefixKey = 2
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
+    val dataEncoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+    val keyEncoder = new PrefixKeyScanStateEncoder(
+      dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies = false)
+
+    // Create a full key row
+    val keyProj = UnsafeProjection.create(keySchema)
+    val fullKey = keyProj.apply(InternalRow(123, UTF8String.fromString("device1"), 1000000L))
+
+    // Encode the full key (this is what happens when putting to state store)
+    val encodedKey = keyEncoder.encodeKey(fullKey)
+
+    // Decode the key (this is what happens during prefix scan)
+    val decodedKey = keyEncoder.decodeKey(encodedKey)
+
+    // Verify the decoded key matches the original
+    assert(decodedKey.numFields === 3,
+      s"Expected 3 fields in decoded key, but got ${decodedKey.numFields}")
+    assert(decodedKey.getInt(0) === 123, "userId not preserved")
+    assert(decodedKey.getString(1) === "device1", "deviceId not preserved")
+    assert(decodedKey.getLong(2) === 1000000L, "sessionStartTime not preserved")
+  }
+
+  test("verify decodeRemainingKey correctly decodes with fix") {
+    // This test verifies the fix prevents garbage data reads
+    val keySchema = StructType(Seq(
+      StructField("k1", IntegerType),
+      StructField("k2", StringType),
+      StructField("k3", LongType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("v1", IntegerType)
+    ))
+
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 2)
+    val encoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+
+    // Create and encode a remaining key with just the last column (k3)
+    val remainingKeySchema = StructType(Seq(StructField("k3", LongType)))
+    val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+    val remainingKeyRow = remainingKeyProj.apply(InternalRow(999999L))
+    val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+    // Decode the remaining key
+    val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+    // With the FIX: numFields should be keySchema.length - numColsPrefixKey = 3 - 2 = 1
+    assert(decodedRemainingKey.numFields === 1,
+      s"Expected 1 field but got ${decodedRemainingKey.numFields}")
+
+    // Field 0 should read correctly
+    assert(decodedRemainingKey.getLong(0) === 999999L,
+      "Field 0 value incorrect")
+
+    // Trying to read field 1 should throw exception (doesn't exist)
+    intercept[AssertionError] {
+      decodedRemainingKey.getLong(1)
+    }
+  }
+
+  test("verify AvroStateEncoder decodeRemainingKey with PrefixKeyScanStateEncoder") {
+    // This test verifies that AvroStateEncoder correctly decodes remaining keys
+    // AvroStateEncoder uses remainingKeySchema = keySchema.drop(numColsPrefixKey)
+    // which is the correct calculation (unlike the bug in UnsafeRowDataEncoder)
+    val keySchema = StructType(Seq(
+      StructField("k1", IntegerType),
+      StructField("k2", StringType),
+      StructField("k3", LongType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("v1", IntegerType)
+    ))
+
+    // Create test state schema provider
+    val testProvider = new TestStateSchemaProvider()
+    testProvider.captureSchema(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      keySchema,
+      valueSchema,
+      keySchemaId = 0,
+      valueSchemaId = 0
+    )
+
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 2)
+    val encoder = new AvroStateEncoder(prefixKeySpec, valueSchema, Some(testProvider),
+      StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Create and encode a remaining key with just the last column (k3)
+    val remainingKeySchema = StructType(Seq(StructField("k3", LongType)))
+    val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+    val remainingKeyRow = remainingKeyProj.apply(InternalRow(999999L))
+    val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+    // Decode the remaining key
+    val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+    // Should have 1 field (keySchema.length - numColsPrefixKey = 3 - 2 = 1)
+    assert(decodedRemainingKey.numFields === 1,
+      s"Expected 1 field but got ${decodedRemainingKey.numFields}")
+
+    // Field 0 should read correctly
+    assert(decodedRemainingKey.getLong(0) === 999999L,
+      "Field 0 value incorrect")
+
+    // Trying to read field 1 should throw exception (doesn't exist)
+    intercept[AssertionError] {
+      decodedRemainingKey.getLong(1)
+    }
+  }
 }
 
 @SlowSQLTest


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to