HeartSaVioR commented on code in PR #53775:
URL: https://github.com/apache/spark/pull/53775#discussion_r2689363806


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -584,6 +584,114 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
       assert(decodedValue.getBoolean(2) === true)
     }
   }
+
+  test("verify decodeRemainingKey with PrefixKeyScanStateEncoder uses correct 
numFields") {
+    val keySchema = StructType(Seq(
+      StructField("k1", IntegerType),
+      StructField("k2", LongType),
+      StructField("k3", StringType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("v1", IntegerType)
+    ))
+
+    // Create encoder with 2 prefix columns, so remaining key should have 1 
column (k3)
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, 
numColsPrefixKey = 2)
+    val encoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+
+    // Create a remaining key row with just the last column
+    val remainingKeySchema = StructType(Seq(StructField("k3", StringType)))
+    val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+    val remainingKeyRow = 
remainingKeyProj.apply(InternalRow(UTF8String.fromString("test")))
+
+    // Encode the remaining key
+    val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+    // Decode the remaining key - this should create a row with 1 field, not 2
+    val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+    // Verify the decoded row has correct number of fields (should be 1, not 2)
+    assert(decodedRemainingKey.numFields === 1,
+      s"Expected 1 field in decoded remaining key, but got 
${decodedRemainingKey.numFields}")
+
+    // Verify the value is preserved correctly
+    assert(decodedRemainingKey.getString(0) === "test",
+      "Value not preserved in remaining key encoding/decoding")
+  }
+
+  test("verify PrefixKeyScanStateEncoder full encode/decode cycle with 
multi-key session window") {
+    // Simulate session window state with multiple grouping keys
+    // Key schema: [userId, deviceId, sessionStartTime] - mimics session 
window with 2 grouping keys
+    val keySchema = StructType(Seq(
+      StructField("userId", IntegerType),
+      StructField("deviceId", StringType),
+      StructField("sessionStartTime", LongType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("count", LongType)
+    ))
+
+    // Session window uses first N columns as prefix (the grouping keys)
+    val numColsPrefixKey = 2
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, 
numColsPrefixKey)
+    val dataEncoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+    val keyEncoder = new PrefixKeyScanStateEncoder(
+      dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies = false)
+
+    // Create a full key row
+    val keyProj = UnsafeProjection.create(keySchema)
+    val fullKey = keyProj.apply(InternalRow(123, 
UTF8String.fromString("device1"), 1000000L))
+
+    // Encode the full key (this is what happens when putting to state store)
+    val encodedKey = keyEncoder.encodeKey(fullKey)
+
+    // Decode the key (this is what happens during prefix scan)
+    val decodedKey = keyEncoder.decodeKey(encodedKey)
+
+    // Verify the decoded key matches the original
+    assert(decodedKey.numFields === 3,
+      s"Expected 3 fields in decoded key, but got ${decodedKey.numFields}")
+    assert(decodedKey.getInt(0) === 123, "userId not preserved")
+    assert(decodedKey.getString(1) === "device1", "deviceId not preserved")
+    assert(decodedKey.getLong(2) === 1000000L, "sessionStartTime not 
preserved")
+  }
+
+  test("verify decodeRemainingKey correctly decodes with fix") {

Review Comment:
   While we are here, do we have the same test coverage for the Avro encoder path?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -584,6 +584,114 @@ class RocksDBStateEncoderSuite extends SparkFunSuite {
       assert(decodedValue.getBoolean(2) === true)
     }
   }
+
+  test("verify decodeRemainingKey with PrefixKeyScanStateEncoder uses correct 
numFields") {
+    val keySchema = StructType(Seq(
+      StructField("k1", IntegerType),
+      StructField("k2", LongType),
+      StructField("k3", StringType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("v1", IntegerType)
+    ))
+
+    // Create encoder with 2 prefix columns, so remaining key should have 1 
column (k3)
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, 
numColsPrefixKey = 2)
+    val encoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+
+    // Create a remaining key row with just the last column
+    val remainingKeySchema = StructType(Seq(StructField("k3", StringType)))
+    val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+    val remainingKeyRow = 
remainingKeyProj.apply(InternalRow(UTF8String.fromString("test")))
+
+    // Encode the remaining key
+    val encodedRemainingKey = encoder.encodeRemainingKey(remainingKeyRow)
+
+    // Decode the remaining key - this should create a row with 1 field, not 2
+    val decodedRemainingKey = encoder.decodeRemainingKey(encodedRemainingKey)
+
+    // Verify the decoded row has correct number of fields (should be 1, not 2)
+    assert(decodedRemainingKey.numFields === 1,
+      s"Expected 1 field in decoded remaining key, but got 
${decodedRemainingKey.numFields}")
+
+    // Verify the value is preserved correctly
+    assert(decodedRemainingKey.getString(0) === "test",
+      "Value not preserved in remaining key encoding/decoding")
+  }
+
+  test("verify PrefixKeyScanStateEncoder full encode/decode cycle with 
multi-key session window") {
+    // Simulate session window state with multiple grouping keys
+    // Key schema: [userId, deviceId, sessionStartTime] - mimics session 
window with 2 grouping keys
+    val keySchema = StructType(Seq(
+      StructField("userId", IntegerType),
+      StructField("deviceId", StringType),
+      StructField("sessionStartTime", LongType)
+    ))
+    val valueSchema = StructType(Seq(
+      StructField("count", LongType)
+    ))
+
+    // Session window uses first N columns as prefix (the grouping keys)
+    val numColsPrefixKey = 2
+    val prefixKeySpec = PrefixKeyScanStateEncoderSpec(keySchema, 
numColsPrefixKey)
+    val dataEncoder = new UnsafeRowDataEncoder(prefixKeySpec, valueSchema)
+    val keyEncoder = new PrefixKeyScanStateEncoder(
+      dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies = false)
+
+    // Create a full key row
+    val keyProj = UnsafeProjection.create(keySchema)
+    val fullKey = keyProj.apply(InternalRow(123, 
UTF8String.fromString("device1"), 1000000L))
+
+    // Encode the full key (this is what happens when putting to state store)
+    val encodedKey = keyEncoder.encodeKey(fullKey)
+
+    // Decode the key (this is what happens during prefix scan)
+    val decodedKey = keyEncoder.decodeKey(encodedKey)
+
+    // Verify the decoded key matches the original
+    assert(decodedKey.numFields === 3,
+      s"Expected 3 fields in decoded key, but got ${decodedKey.numFields}")
+    assert(decodedKey.getInt(0) === 123, "userId not preserved")
+    assert(decodedKey.getString(1) === "device1", "deviceId not preserved")
+    assert(decodedKey.getLong(2) === 1000000L, "sessionStartTime not 
preserved")
+  }
+
+  test("verify decodeRemainingKey correctly decodes with fix") {

Review Comment:
   Looks like this test supersedes the test "verify decodeRemainingKey with 
PrefixKeyScanStateEncoder uses correct numFields"? If one supersedes the other, 
let's retain only the one with more coverage.



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala:
##########
@@ -805,4 +805,29 @@ class StreamingSessionWindowSuite extends StreamTest
         "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS 
durationMs",
         "numEvents")
   }
+
+  testWithAllOptions("complete mode - session window - multiple grouping 
keys") {

Review Comment:
   Does this test fail without the fix? If so, could you please check the 
existing test "complete mode - session window - multiple col key" and explain 
why it didn't trigger the issue?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -629,7 +629,7 @@ class UnsafeRowDataEncoder(
   override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
     keyStateEncoderSpec match {
       case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
-        decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+        decodeToUnsafeRow(bytes, numFields = keySchema.length - 
numColsPrefixKey)

Review Comment:
   Nice finding!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to