This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a8e35c407bc5 [SPARK-54106][SS] Recheckin State store row checksum 
implementation
a8e35c407bc5 is described below

commit a8e35c407bc5340f83b35e5a2f0b0767c6baadb0
Author: micheal-o <[email protected]>
AuthorDate: Tue Nov 4 18:06:59 2025 -0800

    [SPARK-54106][SS] Recheckin State store row checksum implementation
    
    ### What changes were proposed in this pull request?
    
    The original [PR](https://github.com/apache/spark/pull/52809) was reverted because the `AlsoTestWithStateStoreRowChecksum` test dimension added to existing tests made them slow enough to time out. This PR instead creates a separate suite for testing with row checksums.
    
    This introduces row checksum creation and verification for the state store, covering both the HDFS-backed and RocksDB state store providers. It helps detect corruption at the row level and prevents corrupt rows from being written to the remote checkpoint. The checksum is also verified when loading from the checkpoint, to detect whether a row in the checkpoint is corrupt.
    
    Since this adds overhead, it is disabled by default. It can only be enabled for a new checkpoint: the conf is also written to the offset log, so the setting cannot change between restarts from the same checkpoint location.
    
    This also introduces a `readVerificationRatio` conf that enables checksum verification when a row is read by the client. This makes corruption detection more immediate: the query task fails as soon as a corrupt row is read, rather than when it is too late. Because verifying reads frequently can be expensive, the verification frequency is configurable.
    
    ### Why are the changes needed?
    
    Integrity verification
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52829 from micheal-o/row_checksum_2.
    
    Authored-by: micheal-o <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   7 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  30 ++
 .../streaming/checkpointing/OffsetSeq.scala        |   6 +-
 .../streaming/state/HDFSBackedStateStoreMap.scala  |  89 +++-
 .../state/HDFSBackedStateStoreProvider.scala       | 173 +++++--
 .../sql/execution/streaming/state/RocksDB.scala    | 212 ++++++--
 .../streaming/state/RocksDBStateEncoder.scala      |  40 ++
 .../state/RocksDBStateStoreProvider.scala          |  54 ++-
 .../sql/execution/streaming/state/StateStore.scala |   1 -
 .../streaming/state/StateStoreChangelog.scala      |   8 +
 .../execution/streaming/state/StateStoreConf.scala |   6 +
 .../streaming/state/StateStoreErrors.scala         |  19 +
 .../execution/streaming/state/StateStoreRow.scala  |  56 +++
 .../streaming/state/StateStoreRowChecksum.scala    | 536 +++++++++++++++++++++
 .../StateDataSourceTransformWithStateSuite.scala   |   9 +-
 .../execution/streaming/OffsetSeqLogSuite.scala    |  22 +
 .../execution/streaming/state/ListStateSuite.scala |   5 +
 .../execution/streaming/state/MapStateSuite.scala  |   5 +
 .../RocksDBStateStoreCheckpointFormatV2Suite.scala |   7 +
 .../state/RocksDBStateStoreIntegrationSuite.scala  |   6 +
 .../RocksDBStateStoreLockHardeningSuite.scala      |   6 +
 .../streaming/state/RocksDBStateStoreSuite.scala   |   7 +
 .../execution/streaming/state/RocksDBSuite.scala   |   7 +
 .../state/StateStoreInstanceMetricSuite.scala      |   6 +
 .../state/StateStoreRowChecksumSuite.scala         | 445 +++++++++++++++++
 .../streaming/state/StateStoreSuite.scala          |  24 +
 .../state/StatefulProcessorHandleSuite.scala       |   6 +
 .../sql/execution/streaming/state/TimerSuite.scala |   5 +
 .../streaming/state/ValueStateSuite.scala          |   4 +
 .../streaming/TransformWithListStateSuite.scala    |   9 +-
 .../sql/streaming/TransformWithMapStateSuite.scala |   9 +-
 .../TransformWithStateChainingSuite.scala          |   9 +-
 .../TransformWithStateInitialStateSuite.scala      |   9 +-
 33 files changed, 1725 insertions(+), 112 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index a990221221a4..fd7002e86f67 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5433,6 +5433,13 @@
     ],
     "sqlState" : "42K06"
   },
+  "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED" : {
+    "message" : [
+      "Row checksum verification failed for stateStore=<stateStoreId>. The row 
may be corrupted.",
+      "Expected checksum: <expectedChecksum>, Computed checksum: 
<computedChecksum>."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED" : {
     "message" : [
       "The number of state schema files <numStateSchemaFiles> exceeds the 
maximum number of state schema files for this query: <maxStateSchemaFiles>.",
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d2d7edc65121..38e823a96cbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2670,6 +2670,31 @@ object SQLConf {
       .checkValue(k => k >= 0, "Must be greater than or equal to 0")
       .createWithDefault(5)
 
+  val STATE_STORE_ROW_CHECKSUM_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.rowChecksum.enabled")
+      .internal()
+      .doc("When true, checksum would be generated and verified for each state 
store row. " +
+        "This is used to detect row level corruption. " +
+        "Note: This configuration cannot be changed between query restarts " +
+        "from the same checkpoint location.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO =
+    buildConf("spark.sql.streaming.stateStore.rowChecksum.readVerificationRatio")
+      .internal()
+      .doc("When specified, Spark will do row checksum verification for every specified " +
+        "number of rows read from state store. The check is to ensure the row read from " +
+        "state store is not corrupt. Default is 0, which means no verification during read " +
+        "but we will still do verification when loading from checkpoint location. " +
+        "Example, if you set to 1, it will do the check for every row read from the state store. " +
+        "If set to 10, it will do the check for every 10th row read from the state store.")
+      .version("4.1.0")
+      .longConf
+      .checkValue(k => k >= 0, "Must be greater than or equal to 0")
+      .createWithDefault(if (Utils.isTesting) 1 else 0)
+
   val STATEFUL_SHUFFLE_PARTITIONS_INTERNAL =
     buildConf("spark.sql.streaming.internal.stateStore.partitions")
       .doc("WARN: This config is used internally and is not intended to be 
user-facing. This " +
@@ -6795,6 +6820,11 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
   def stateStoreCoordinatorMaxLaggingStoresToReport: Int =
     getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)
 
+  def stateStoreRowChecksumEnabled: Boolean = 
getConf(STATE_STORE_ROW_CHECKSUM_ENABLED)
+
+  def stateStoreRowChecksumReadVerificationRatio: Long =
+    getConf(STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO)
+
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
 
   def checkpointFileChecksumEnabled: Boolean = 
getConf(STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index 62c903cb689a..888dc0cdb912 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -113,7 +113,8 @@ object OffsetSeqMetadata extends Logging {
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, 
STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
     STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC,
     STATE_STORE_ROCKSDB_FORMAT_VERSION, 
STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
-    PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, 
STREAMING_STATE_STORE_ENCODING_FORMAT
+    PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, 
STREAMING_STATE_STORE_ENCODING_FORMAT,
+    STATE_STORE_ROW_CHECKSUM_ENABLED
   )
 
   /**
@@ -159,7 +160,8 @@ object OffsetSeqMetadata extends Logging {
     STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4,
     STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false",
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true",
-    STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow"
+    STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow",
+    STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false"
   )
 
   def readValue[T](metadataLog: OffsetSeqMetadata, confKey: ConfigEntry[T]): 
String = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
index fe59703a1f45..f2290f8569b8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.Map.Entry
+
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
@@ -27,10 +29,13 @@ import org.apache.spark.sql.types.{StructField, StructType}
 trait HDFSBackedStateStoreMap {
   def size(): Int
   def get(key: UnsafeRow): UnsafeRow
-  def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow
+  def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper
   def putAll(map: HDFSBackedStateStoreMap): Unit
-  def remove(key: UnsafeRow): UnsafeRow
+  def remove(key: UnsafeRow): UnsafeRowWrapper
   def iterator(): Iterator[UnsafeRowPair]
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]]
   def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
 }
 
@@ -40,42 +45,69 @@ object HDFSBackedStateStoreMap {
   //   the map when the iterator was created
   // - Any updates to the map while iterating through the filtered iterator 
does not throw
   //   java.util.ConcurrentModificationException
-  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
+  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, 
UnsafeRowWrapper]
 
-  def create(keySchema: StructType, numColsPrefixKey: Int): 
HDFSBackedStateStoreMap = {
+  def create(
+      keySchema: StructType,
+      numColsPrefixKey: Int,
+      readVerifier: Option[KeyValueIntegrityVerifier]): 
HDFSBackedStateStoreMap = {
     if (numColsPrefixKey > 0) {
-      new PrefixScannableHDFSBackedStateStoreMap(keySchema, numColsPrefixKey)
+      new PrefixScannableHDFSBackedStateStoreMap(keySchema, numColsPrefixKey, 
readVerifier)
     } else {
-      new NoPrefixHDFSBackedStateStoreMap()
+      new NoPrefixHDFSBackedStateStoreMap(readVerifier)
+    }
+  }
+
+  /** Get the value row from the value wrapper and verify it */
+  def getAndVerifyValueRow(
+      key: UnsafeRow,
+      valueWrapper: UnsafeRowWrapper,
+      readVerifier: Option[KeyValueIntegrityVerifier]): UnsafeRow = {
+    Option(valueWrapper) match {
+      case Some(value) =>
+        readVerifier.foreach(_.verify(key, value))
+        value.unsafeRow()
+      case None => null
     }
   }
 }
 
-class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap {
+class NoPrefixHDFSBackedStateStoreMap(private val readVerifier: 
Option[KeyValueIntegrityVerifier])
+    extends HDFSBackedStateStoreMap {
   private val map = new HDFSBackedStateStoreMap.MapType()
 
   override def size(): Int = map.size()
 
-  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+  override def get(key: UnsafeRow): UnsafeRow = {
+    HDFSBackedStateStoreMap.getAndVerifyValueRow(key, map.get(key), 
readVerifier)
+  }
 
-  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = map.put(key, 
value)
+  override def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper 
= map.put(key, value)
 
   def putAll(other: HDFSBackedStateStoreMap): Unit = {
     other match {
       case o: NoPrefixHDFSBackedStateStoreMap => map.putAll(o.map)
-      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+      case _ => other.entryIterator().foreach { pair => put(pair.getKey, 
pair.getValue) }
     }
   }
 
-  override def remove(key: UnsafeRow): UnsafeRow = map.remove(key)
+  override def remove(key: UnsafeRow): UnsafeRowWrapper = map.remove(key)
 
   override def iterator(): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
-    map.entrySet.asScala.iterator.map { entry =>
-      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    entryIterator().map { entry =>
+      val valueRow = HDFSBackedStateStoreMap
+        .getAndVerifyValueRow(entry.getKey, entry.getValue, readVerifier)
+      unsafeRowPair.withRows(entry.getKey, valueRow)
     }
   }
 
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  override def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]] = 
{
+    map.entrySet.asScala.iterator
+  }
+
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw SparkUnsupportedOperationException()
   }
@@ -83,7 +115,8 @@ class NoPrefixHDFSBackedStateStoreMap extends 
HDFSBackedStateStoreMap {
 
 class PrefixScannableHDFSBackedStateStoreMap(
     keySchema: StructType,
-    numColsPrefixKey: Int) extends HDFSBackedStateStoreMap {
+    numColsPrefixKey: Int,
+    private val readVerifier: Option[KeyValueIntegrityVerifier]) extends 
HDFSBackedStateStoreMap {
 
   private val map = new HDFSBackedStateStoreMap.MapType()
 
@@ -103,9 +136,11 @@ class PrefixScannableHDFSBackedStateStoreMap(
 
   override def size(): Int = map.size()
 
-  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+  override def get(key: UnsafeRow): UnsafeRow = {
+    HDFSBackedStateStoreMap.getAndVerifyValueRow(key, map.get(key), 
readVerifier)
+  }
 
-  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
+  override def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper 
= {
     val ret = map.put(key, value)
 
     val prefixKey = prefixKeyProjection(key).copy()
@@ -136,11 +171,11 @@ class PrefixScannableHDFSBackedStateStoreMap(
           prefixKeyToKeysMap.put(prefixKey, newSet)
         }
 
-      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+      case _ => other.entryIterator().foreach { pair => put(pair.getKey, 
pair.getValue) }
     }
   }
 
-  override def remove(key: UnsafeRow): UnsafeRow = {
+  override def remove(key: UnsafeRow): UnsafeRowWrapper = {
     val ret = map.remove(key)
 
     if (ret != null) {
@@ -156,15 +191,27 @@ class PrefixScannableHDFSBackedStateStoreMap(
 
   override def iterator(): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
-    map.entrySet.asScala.iterator.map { entry =>
-      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    entryIterator().map { entry =>
+      val valueRow = HDFSBackedStateStoreMap
+        .getAndVerifyValueRow(entry.getKey, entry.getValue, readVerifier)
+      unsafeRowPair.withRows(entry.getKey, valueRow)
     }
   }
 
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  override def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]] = 
{
+    map.entrySet.asScala.iterator
+  }
+
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
     prefixKeyToKeysMap.getOrDefault(prefixKey, mutable.Set.empty[UnsafeRow])
       .iterator
-      .map { key => unsafeRowPair.withRows(key, map.get(key)) }
+      .map { keyRow =>
+        val valueRow = HDFSBackedStateStoreMap
+          .getAndVerifyValueRow(keyRow, map.get(keyRow), readVerifier)
+        unsafeRowPair.withRows(keyRow, valueRow)
+      }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index f1c9c94e7bf8..a0ace7976edd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -107,7 +107,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   }
 
   /** Implementation of [[StateStore]] API which is backed by an 
HDFS-compatible file system */
-  class HDFSBackedStateStore(val version: Long, mapToUpdate: 
HDFSBackedStateStoreMap)
+  class HDFSBackedStateStore(
+      val version: Long,
+      private val mapToUpdate: HDFSBackedStateStoreMap)
     extends StateStore {
 
     /** Trait and classes representing the internal state of the store */
@@ -163,8 +165,16 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       verify(state == UPDATING, "Cannot put after already committed or 
aborted")
       val keyCopy = key.copy()
       val valueCopy = value.copy()
-      mapToUpdate.put(keyCopy, valueCopy)
-      writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy)
+
+      val valueWrapper = if (storeConf.rowChecksumEnabled) {
+        // Add the key-value checksum to the value row
+        StateStoreRowWithChecksum(valueCopy, KeyValueChecksum.create(keyCopy, 
Some(valueCopy)))
+      } else {
+        StateStoreRow(valueCopy)
+      }
+
+      mapToUpdate.put(keyCopy, valueWrapper)
+      writeUpdateToDeltaFile(compressedStream, keyCopy, valueWrapper)
     }
 
     override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
@@ -172,7 +182,13 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
       val prevValue = mapToUpdate.remove(key)
       if (prevValue != null) {
-        writeRemoveToDeltaFile(compressedStream, key)
+        val keyWrapper = if (storeConf.rowChecksumEnabled) {
+          // Add checksum for only the removed key
+          StateStoreRowWithChecksum(key, KeyValueChecksum.create(key, None))
+        } else {
+          StateStoreRow(key)
+        }
+        writeRemoveToDeltaFile(compressedStream, keyWrapper)
       }
     }
 
@@ -328,7 +344,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       }
 
       performedSnapshotAutoRepair.set(false)
-      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val newMap = createHDFSBackedStateStoreMap()
       if (version > 0) {
         newMap.putAll(loadMap(version))
       }
@@ -613,7 +629,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       } else {
         // Load all the deltas from the version after the loadedVersion up to 
the target version.
         // The loadedVersion is the one with a full snapshot, so it doesn't 
need deltas.
-        val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
+        val resultMap = createHDFSBackedStateStoreMap()
         resultMap.putAll(loadedMap)
         for (deltaVersion <- loadedVersion + 1 to version) {
           updateFromDeltaFile(deltaVersion, resultMap)
@@ -652,7 +668,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       override protected def loadSnapshotFromCheckpoint(snapshotVersion: 
Long): Unit = {
         loadedMap = if (snapshotVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
+          Some(createHDFSBackedStateStoreMap())
         } else {
           // first try to get the map from the cache
           synchronized { Option(loadedMaps.get(snapshotVersion)) }
@@ -684,17 +700,30 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   private def writeUpdateToDeltaFile(
       output: DataOutputStream,
       key: UnsafeRow,
-      value: UnsafeRow): Unit = {
+      value: UnsafeRowWrapper): Unit = {
     val keyBytes = key.getBytes()
-    val valueBytes = value.getBytes()
+    val valueBytes = value match {
+      case v: StateStoreRowWithChecksum =>
+        // If it has checksum, encode the value bytes with the checksum.
+        KeyValueChecksumEncoder.encodeSingleValueRowWithChecksum(
+          v.unsafeRow().getBytes(), v.checksum)
+      case _ => value.unsafeRow().getBytes()
+    }
+
     output.writeInt(keyBytes.size)
     output.write(keyBytes)
     output.writeInt(valueBytes.size)
     output.write(valueBytes)
   }
 
-  private def writeRemoveToDeltaFile(output: DataOutputStream, key: 
UnsafeRow): Unit = {
-    val keyBytes = key.getBytes()
+  private def writeRemoveToDeltaFile(output: DataOutputStream, key: 
UnsafeRowWrapper): Unit = {
+    val keyBytes = key match {
+      case k: StateStoreRowWithChecksum =>
+        // If it has checksum, encode the key bytes with the checksum.
+        KeyValueChecksumEncoder.encodeKeyRowWithChecksum(
+          k.unsafeRow().getBytes(), k.checksum)
+      case _ => key.unsafeRow().getBytes()
+    }
     output.writeInt(keyBytes.size)
     output.write(keyBytes)
     output.writeInt(-1)
@@ -718,6 +747,10 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       input = decompressStream(sourceStream)
       var eof = false
 
+      // If row checksum is enabled, verify every record in the file to detect 
corrupt rows.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (!eof) {
         val keySize = input.readInt()
         if (keySize == -1) {
@@ -730,26 +763,46 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
           Utils.readFully(input, keyRowBuffer, 0, keySize)
 
           val keyRow = new UnsafeRow(keySchema.fields.length)
-          keyRow.pointTo(keyRowBuffer, keySize)
 
           val valueSize = input.readInt()
           if (valueSize < 0) {
+            val originalKeyBytes = if (storeConf.rowChecksumEnabled) {
+              // For deleted row, we added checksum to the key side.
+              // Decode the original key and remove the checksum part
+              
KeyValueChecksumEncoder.decodeAndVerifyKeyRowWithChecksum(verifier, 
keyRowBuffer)
+            } else {
+              keyRowBuffer
+            }
+            keyRow.pointTo(originalKeyBytes, originalKeyBytes.length)
             map.remove(keyRow)
           } else {
+            keyRow.pointTo(keyRowBuffer, keySize)
+
             val valueRowBuffer = new Array[Byte](valueSize)
             Utils.readFully(input, valueRowBuffer, 0, valueSize)
             val valueRow = new UnsafeRow(valueSchema.fields.length)
+
+            val (originalValueBytes, valueWrapper) = if 
(storeConf.rowChecksumEnabled) {
+              // checksum is on the value side
+              val (valueBytes, checksum) = KeyValueChecksumEncoder
+                .decodeSingleValueRowWithChecksum(valueRowBuffer)
+              verifier.foreach(_.verify(keyRowBuffer, Some(valueBytes), 
checksum))
+              (valueBytes, StateStoreRowWithChecksum(valueRow, checksum))
+            } else {
+              (valueRowBuffer, StateStoreRow(valueRow))
+            }
+
             // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
             // This is a workaround for the following:
             // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
             // `RowBasedKeyValueBatch`, which gets persisted into the 
checkpoint data
-            valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+            valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 
8) * 8)
             if (!isValidated) {
               StateStoreProvider.validateStateRowFormat(
                 keyRow, keySchema, valueRow, valueSchema, stateStoreId, 
storeConf)
               isValidated = true
             }
-            map.put(keyRow, valueRow)
+            map.put(keyRow, valueWrapper)
           }
         }
       }
@@ -770,11 +823,28 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     try {
       rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true)
       output = compressStream(rawOutput)
-      val iter = map.iterator()
+      // Entry iterator doesn't do verification, we will do it ourselves
+      // Using this instead of iterator() since we want the UnsafeRowWrapper
+      val iter = map.entryIterator()
+
+      // If row checksum is enabled, we will verify every entry in the map 
before writing snapshot,
+      // to prevent writing corrupt rows since the map might have been created 
a while ago.
+      // This is fine since write snapshot is typically done in the background.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (iter.hasNext) {
         val entry = iter.next()
-        val keyBytes = entry.key.getBytes()
-        val valueBytes = entry.value.getBytes()
+        verifier.foreach(_.verify(entry.getKey, entry.getValue))
+
+        val keyBytes = entry.getKey.getBytes()
+        val valueBytes = entry.getValue match {
+          case v: StateStoreRowWithChecksum =>
+            // If it has checksum, encode it with the checksum.
+            KeyValueChecksumEncoder.encodeSingleValueRowWithChecksum(
+              v.unsafeRow().getBytes(), v.checksum)
+          case o => o.unsafeRow().getBytes()
+        }
         output.writeInt(keyBytes.size)
         output.write(keyBytes)
         output.writeInt(valueBytes.size)
@@ -831,13 +901,17 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   */
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] 
= {
     val fileToRead = snapshotFile(version)
-    val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val map = createHDFSBackedStateStoreMap()
     var input: DataInputStream = null
 
     try {
       input = decompressStream(fm.open(fileToRead))
       var eof = false
 
+      // If row checksum is enabled, verify every record in the file to detect 
corrupt rows.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (!eof) {
         val keySize = input.readInt()
         if (keySize == -1) {
@@ -860,17 +934,28 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
             val valueRowBuffer = new Array[Byte](valueSize)
             Utils.readFully(input, valueRowBuffer, 0, valueSize)
             val valueRow = new UnsafeRow(valueSchema.fields.length)
+
+            val (originalValueBytes, valueWrapper) = if 
(storeConf.rowChecksumEnabled) {
+              // Checksum is on the value side
+              val (valueBytes, checksum) = KeyValueChecksumEncoder
+                .decodeSingleValueRowWithChecksum(valueRowBuffer)
+              verifier.foreach(_.verify(keyRowBuffer, Some(valueBytes), 
checksum))
+              (valueBytes, StateStoreRowWithChecksum(valueRow, checksum))
+            } else {
+              (valueRowBuffer, StateStoreRow(valueRow))
+            }
+
             // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
             // This is a workaround for the following:
             // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
             // `RowBasedKeyValueBatch`, which gets persisted into the 
checkpoint data
-            valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+            valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 
8) * 8)
             if (!isValidated) {
               StateStoreProvider.validateStateRowFormat(
                 keyRow, keySchema, valueRow, valueSchema, stateStoreId, 
storeConf)
               isValidated = true
             }
-            map.put(keyRow, valueRow)
+            map.put(keyRow, valueWrapper)
           }
         }
       }
@@ -887,7 +972,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
 
   /** Perform a snapshot of the store to allow delta files to be consolidated 
*/
-  private def doSnapshot(opType: String): Unit = {
+  private def doSnapshot(opType: String, throwEx: Boolean = false): Unit = {
     try {
       val ((files, _), e1) = Utils.timeTakenMs(fetchFiles())
       logDebug(s"fetchFiles() took $e1 ms.")
@@ -909,6 +994,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     } catch {
       case NonFatal(e) =>
         logWarning(log"Error doing snapshots", e)
+        if (throwEx) {
+          throw e
+        }
     }
   }
 
@@ -1128,7 +1216,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }
 
-      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val newMap = createHDFSBackedStateStoreMap()
       newMap.putAll(constructMapFromSnapshot(snapshotVersion, endVersion))
 
       newMap
@@ -1156,7 +1244,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       }
 
       // Load all the deltas from the version after the start version up to 
the end version.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
+      val resultMap = createHDFSBackedStateStoreMap()
       resultMap.putAll(startVersionMap.get)
       for (deltaVersion <- snapshotVersion + 1 to endVersion) {
         updateFromDeltaFile(deltaVersion, resultMap)
@@ -1171,6 +1259,15 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     result
   }
 
+  private def createHDFSBackedStateStoreMap(): HDFSBackedStateStoreMap = {
+    val readVerifier = KeyValueIntegrityVerifier.create(
+      stateStoreId_.toString,
+      storeConf.rowChecksumEnabled,
+      storeConf.rowChecksumReadVerificationRatio)
+
+    HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey, readVerifier)
+  }
+
   override def getStateStoreChangeDataReader(
       startVersion: Long,
       endVersion: Long,
@@ -1189,9 +1286,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
     }
 
-    new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, 
endVersion,
+    new HDFSBackedStateStoreChangeDataReader(stateStoreId_, fm, baseDir, 
startVersion, endVersion,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
-      keySchema, valueSchema)
+      keySchema, valueSchema, storeConf)
   }
 
   /** Reports to the coordinator the store's latest snapshot version */
@@ -1207,15 +1304,17 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[HDFSBackedStateStoreProvider]] */
 class HDFSBackedStateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
     compressionCodec: CompressionCodec,
     keySchema: StructType,
-    valueSchema: StructType)
+    valueSchema: StructType,
+    storeConf: StateStoreConf)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec) {
+    storeId, fm, stateLocation, startVersion, endVersion, compressionCodec, 
storeConf) {
 
   override protected val changelogSuffix: String = "delta"
 
@@ -1226,16 +1325,32 @@ class HDFSBackedStateStoreChangeDataReader(
     }
     val (recordType, keyArray, valueArray) = reader.next()
     val keyRow = new UnsafeRow(keySchema.fields.length)
-    keyRow.pointTo(keyArray, keyArray.length)
     if (valueArray == null) {
+      val originalKeyBytes = if (storeConf.rowChecksumEnabled) {
+        // Decode the original key and remove the checksum part
+        
KeyValueChecksumEncoder.decodeAndVerifyKeyRowWithChecksum(readVerifier, 
keyArray)
+      } else {
+        keyArray
+      }
+      keyRow.pointTo(originalKeyBytes, originalKeyBytes.length)
       (recordType, keyRow, null, currentChangelogVersion - 1)
     } else {
+      keyRow.pointTo(keyArray, keyArray.length)
+
       val valueRow = new UnsafeRow(valueSchema.fields.length)
+      val originalValueBytes = if (storeConf.rowChecksumEnabled) {
+        // Checksum is on the value side
+        KeyValueChecksumEncoder.decodeAndVerifySingleValueRowWithChecksum(
+          readVerifier, keyArray, valueArray)
+      } else {
+        valueArray
+      }
+
       // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
       // This is a workaround for the following:
       // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
       // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
-      valueRow.pointTo(valueArray, (valueArray.length / 8) * 8)
+      valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 8) * 8)
       (recordType, keyRow, valueRow, currentChangelogVersion - 1)
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index f8570c583387..b1c9dee5a459 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -236,6 +236,12 @@ class RocksDB(
 
   private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)
 
+  // Verifier applied on client read paths (e.g. get/multiGet and iterators); changelog replay
+  // builds its own verifier that checks every record.
+  private val readVerifier: Option[KeyValueIntegrityVerifier] =
+    KeyValueIntegrityVerifier.create(
+      loggingId, conf.rowChecksumEnabled, conf.rowChecksumReadVerificationRatio)
+
   private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
     colFamilyNameToInfoMap.get(cfName)
   }
@@ -869,31 +875,26 @@ class RocksDB(
       try {
         changelogReader = fileManager.getChangelogReader(v, uniqueId)
 
-        if (useColumnFamilies) {
-          changelogReader.foreach { case (recordType, key, value) =>
-            recordType match {
-              case RecordType.PUT_RECORD =>
-                put(key, value, includesPrefix = true, deriveCfName = true)
-
-              case RecordType.DELETE_RECORD =>
-                remove(key, includesPrefix = true, deriveCfName = true)
-
-              case RecordType.MERGE_RECORD =>
-                merge(key, value, includesPrefix = true, deriveCfName = true)
-            }
-          }
-        } else {
-          changelogReader.foreach { case (recordType, key, value) =>
-            recordType match {
-              case RecordType.PUT_RECORD =>
-                put(key, value)
-
-              case RecordType.DELETE_RECORD =>
-                remove(key)
-
-              case RecordType.MERGE_RECORD =>
-                merge(key, value)
-            }
+        // If row checksum is enabled, verify every record in the changelog 
file
+        val kvVerifier = KeyValueIntegrityVerifier
+          .create(loggingId, conf.rowChecksumEnabled, verificationRatio = 1)
+
+        changelogReader.foreach { case (recordType, key, value) =>
+          recordType match {
+            case RecordType.PUT_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, Some(value))
+              put(key, value, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
+
+            case RecordType.DELETE_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, None)
+              remove(key, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
+
+            case RecordType.MERGE_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, Some(value))
+              merge(key, value, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
           }
         }
       } finally {
@@ -908,6 +909,28 @@ class RocksDB(
     )
   }
 
+  private def verifyChangelogRecord(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Option[Array[Byte]]): Unit = {
+    verifier match {
+      case Some(v: KeyValueChecksumVerifier) =>
+        valueBytes match {
+          case Some(value) =>
+            // PUT/MERGE: checksum is on the value side; verify in place without copying bytes.
+            val (valueIndex, checksum) =
+              KeyValueChecksumEncoder.decodeOneValueRowIndexWithChecksum(value)
+            v.verify(ArrayIndexRange(keyBytes, 0, keyBytes.length), Some(valueIndex), checksum)
+          case None =>
+            // DELETE: there is no value, so the checksum is stored with the key.
+            val (keyIndex, checksum) =
+              KeyValueChecksumEncoder.decodeKeyRowIndexWithChecksum(keyBytes)
+            v.verify(keyIndex, None, checksum)
+        }
+      case _ => // No checksum verifier (row checksum disabled): nothing to verify.
+    }
+  }
+
   /**
    * Function to encode state row with virtual col family id prefix
    * @param data - passed byte array to be stored in state store
@@ -942,13 +965,41 @@ class RocksDB(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = {
     updateMemoryUsageIfNeeded()
+    val (finalKey, value) = getValue(key, cfName)
+    if (conf.rowChecksumEnabled && value != null) {
+      KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+        readVerifier, finalKey, value)
+    } else {
+      value
+    }
+  }
+
+  /**
+   * Returns the values merged (via merge) under `key` as an iterator of [[ArrayIndexRange]]s
+   * over the stored bytes, letting callers read each value in place without copying. The bytes
+   * go through the row-checksum decoder and read verifier. Requires row checksum (asserted).
+   */
+  def multiGet(
+      key: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ArrayIndexRange[Byte]] = {
+    assert(conf.rowChecksumEnabled, "multiGet is only allowed when row checksum is enabled")
+    updateMemoryUsageIfNeeded()
+    getValue(key, cfName) match {
+      case (storedKey, storedValue) =>
+        KeyValueChecksumEncoder.decodeAndVerifyMultiValueRowWithChecksum(
+          readVerifier, storedKey, storedValue)
+    }
+  }
+
+  /** Returns (key as stored in the db, incl. any column family prefix; raw db.get value or null). */
+  private def getValue(key: Array[Byte], cfName: String): (Array[Byte], 
Array[Byte]) = {
     val keyWithPrefix = if (useColumnFamilies) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
       key
     }
 
-    db.get(readOptions, keyWithPrefix)
+    (keyWithPrefix, db.get(readOptions, keyWithPrefix))
   }
 
   /**
@@ -1009,7 +1060,8 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
@@ -1024,27 +1076,45 @@ class RocksDB(
       cfName
     }
 
+    val valueWithChecksum = if (conf.rowChecksumEnabled && !includesChecksum) {
+      KeyValueChecksumEncoder.encodeValueRowWithChecksum(value,
+        KeyValueChecksum.create(keyWithPrefix, Some(value)))
+    } else {
+      value
+    }
+
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
-    db.put(writeOptions, keyWithPrefix, value)
-    changelogWriter.foreach(_.put(keyWithPrefix, value))
+    db.put(writeOptions, keyWithPrefix, valueWithChecksum)
+    changelogWriter.foreach(_.put(keyWithPrefix, valueWithChecksum))
   }
 
   /**
    * Convert the given list of value row bytes into a single byte array. The 
returned array
    * bytes supports additional values to be later merged to it.
    */
-  private def getListValuesInArrayByte(values: List[Array[Byte]]): Array[Byte] 
= {
+  private def getListValuesInArrayByte(
+      keyWithPrefix: Array[Byte],
+      values: List[Array[Byte]],
+      includesChecksum: Boolean): Array[Byte] = {
+    val valueWithChecksum = if (conf.rowChecksumEnabled && !includesChecksum) {
+      values.map { value =>
+        KeyValueChecksumEncoder.encodeValueRowWithChecksum(value,
+          KeyValueChecksum.create(keyWithPrefix, Some(value)))
+      }
+    } else {
+      values
+    }
     // Delimit each value row bytes with a single byte delimiter, the last
     // value row won't have a delimiter at the end.
-    val delimiterNum = values.length - 1
-    // The bytes in values already include the bytes length prefix
-    val totalSize = values.map(_.length).sum +
+    val delimiterNum = valueWithChecksum.length - 1
+    // The bytes in valueWithChecksum already include the bytes length prefix
+    val totalSize = valueWithChecksum.map(_.length).sum +
       delimiterNum // for each delimiter
 
     val result = new Array[Byte](totalSize)
     var pos = Platform.BYTE_ARRAY_OFFSET
 
-    values.zipWithIndex.foreach { case (rowBytes, idx) =>
+    valueWithChecksum.zipWithIndex.foreach { case (rowBytes, idx) =>
       // Write the data
       Platform.copyMemory(rowBytes, Platform.BYTE_ARRAY_OFFSET, result, pos, 
rowBytes.length)
       pos += rowBytes.length
@@ -1069,6 +1139,7 @@ class RocksDB(
       values: List[Array[Byte]],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
+      includesChecksum: Boolean = false,
       deriveCfName: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
@@ -1077,7 +1148,7 @@ class RocksDB(
       key
     }
 
-    val valuesInArrayByte = getListValuesInArrayByte(values)
+    val valuesInArrayByte = getListValuesInArrayByte(keyWithPrefix, values, 
includesChecksum)
 
     val columnFamilyName = if (deriveCfName && useColumnFamilies) {
       val (_, cfName) = decodeStateRowWithPrefix(keyWithPrefix)
@@ -1108,7 +1179,8 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
@@ -1123,9 +1195,16 @@ class RocksDB(
       cfName
     }
 
+    val valueWithChecksum = if (conf.rowChecksumEnabled && !includesChecksum) {
+      KeyValueChecksumEncoder.encodeValueRowWithChecksum(value,
+        KeyValueChecksum.create(keyWithPrefix, Some(value)))
+    } else {
+      value
+    }
+
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
-    db.merge(writeOptions, keyWithPrefix, value)
-    changelogWriter.foreach(_.merge(keyWithPrefix, value))
+    db.merge(writeOptions, keyWithPrefix, valueWithChecksum)
+    changelogWriter.foreach(_.merge(keyWithPrefix, valueWithChecksum))
   }
 
   /**
@@ -1139,6 +1218,7 @@ class RocksDB(
       values: List[Array[Byte]],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
+      includesChecksum: Boolean = false,
       deriveCfName: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
@@ -1154,7 +1234,7 @@ class RocksDB(
       cfName
     }
 
-    val valueInArrayByte = getListValuesInArrayByte(values)
+    val valueInArrayByte = getListValuesInArrayByte(keyWithPrefix, values, 
includesChecksum)
 
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
     db.merge(writeOptions, keyWithPrefix, valueInArrayByte)
@@ -1169,14 +1249,23 @@ class RocksDB(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
-    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
-      encodeStateRowWithPrefix(key, cfName)
+    val originalKey = if (conf.rowChecksumEnabled && includesChecksum) {
+      // When we are replaying changelog record, the delete key in the file 
includes checksum.
+      // Remove the checksum, so we use the original key for db.delete.
+      KeyValueChecksumEncoder.decodeKeyRowWithChecksum(key)._1
     } else {
       key
     }
 
+    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
+      encodeStateRowWithPrefix(originalKey, cfName)
+    } else {
+      originalKey
+    }
+
     val columnFamilyName = if (deriveCfName && useColumnFamilies) {
       val (_, cfName) = decodeStateRowWithPrefix(keyWithPrefix)
       cfName
@@ -1186,7 +1275,18 @@ class RocksDB(
 
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = false)
     db.delete(writeOptions, keyWithPrefix)
-    changelogWriter.foreach(_.delete(keyWithPrefix))
+    changelogWriter match {
+      case Some(writer) =>
+        val keyWithChecksum = if (conf.rowChecksumEnabled) {
+          // For delete, we will write a checksum with the key row only to the 
changelog file.
+          KeyValueChecksumEncoder.encodeKeyRowWithChecksum(keyWithPrefix,
+            KeyValueChecksum.create(keyWithPrefix, None))
+        } else {
+          keyWithPrefix
+        }
+        writer.delete(keyWithChecksum)
+      case None => // During changelog replay, there is no changelog writer.
+    }
   }
 
   /**
@@ -1213,7 +1313,14 @@ class RocksDB(
             iter.key
           }
 
-          byteArrayPair.set(key, iter.value)
+          val value = if (conf.rowChecksumEnabled) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, iter.key, iter.value)
+          } else {
+            iter.value
+          }
+
+          byteArrayPair.set(key, value)
           iter.next()
           byteArrayPair
         } else {
@@ -1304,7 +1411,14 @@ class RocksDB(
             iter.key
           }
 
-          byteArrayPair.set(key, iter.value)
+          val value = if (conf.rowChecksumEnabled) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, iter.key, iter.value)
+          } else {
+            iter.value
+          }
+
+          byteArrayPair.set(key, value)
           iter.next()
           byteArrayPair
         } else {
@@ -2104,8 +2218,10 @@ case class RocksDBConf(
     allowFAllocate: Boolean,
     compression: String,
     reportSnapshotUploadLag: Boolean,
-    fileChecksumEnabled: Boolean,
     maxVersionsToDeletePerMaintenance: Int,
+    fileChecksumEnabled: Boolean,
+    rowChecksumEnabled: Boolean,
+    rowChecksumReadVerificationRatio: Long,
     stateStoreConf: StateStoreConf)
 
 object RocksDBConf {
@@ -2305,8 +2421,10 @@ object RocksDBConf {
       getBooleanConf(ALLOW_FALLOCATE_CONF),
       getStringConf(COMPRESSION_CONF),
       storeConf.reportSnapshotUploadLag,
-      storeConf.checkpointFileChecksumEnabled,
       storeConf.maxVersionsToDeletePerMaintenance,
+      storeConf.checkpointFileChecksumEnabled,
+      storeConf.rowChecksumEnabled,
+      storeConf.rowChecksumReadVerificationRatio,
       storeConf)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index f49c79f96b9c..b4410362c4d3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -54,6 +54,7 @@ sealed trait RocksDBValueStateEncoder {
   def encodeValue(row: UnsafeRow): Array[Byte]
   def decodeValue(valueBytes: Array[Byte]): UnsafeRow
   def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
+  def decodeValues(valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): 
Iterator[UnsafeRow]
 }
 
 trait StateSchemaProvider extends Serializable {
@@ -1785,6 +1786,40 @@ class MultiValuedStateEncoder(
     }
   }
 
+  /**
+   * Decodes values passed as [[ArrayIndexRange]]s over a shared byte array. Each range starts
+   * at a value's 4-byte length prefix, followed by that many encoded value bytes. A null
+   * iterator yields no rows. Throws IllegalStateException if a length prefix does not fit.
+   */
+  override def decodeValues(
+      valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): Iterator[UnsafeRow] = {
+    if (valueBytesIterator == null) {
+      Iterator.empty
+    } else {
+      valueBytesIterator.map { range =>
+        val bytes = range.array
+        val limit = math.min(range.untilIndex, bytes.length)
+        // Validate before reading: Platform.getInt/copyMemory do no bounds checking, so a corrupt
+        // length (read verification may be sampled) would read past the range or the array.
+        if (range.fromIndex > limit - java.lang.Integer.BYTES) {
+          throw new IllegalStateException(
+            s"Value range starting at ${range.fromIndex} has no room for a length prefix")
+        }
+        val dataStart = range.fromIndex + java.lang.Integer.BYTES
+        val numBytes = Platform.getInt(bytes, Platform.BYTE_ARRAY_OFFSET + range.fromIndex)
+        if (numBytes < 0 || numBytes > limit - dataStart) {
+          throw new IllegalStateException(
+            s"Invalid value length $numBytes in range [${range.fromIndex}, ${range.untilIndex})")
+        }
+        val encodedValue = new Array[Byte](numBytes)
+        Platform.copyMemory(
+          bytes, Platform.BYTE_ARRAY_OFFSET + dataStart,
+          encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes)
+        dataEncoder.decodeValue(encodedValue)
+      }
+    }
+  }
+
   override def supportsMultipleValuesPerKey: Boolean = true
 }
 
@@ -1824,4 +1859,9 @@ class SingleValueStateEncoder(
   override def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] = {
     throw new IllegalStateException("This encoder doesn't support multiple 
values!")
   }
+
+  override def decodeValues(
+      valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): Iterator[UnsafeRow] =
+    // This encoder handles a single value per key, so multi-value decoding is rejected.
+    throw new IllegalStateException("This encoder doesn't support multiple values!")
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index e01e1e0f86ca..88dad93b5d15 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -273,8 +273,14 @@ private[sql] class RocksDBStateStoreProvider
       verify(valueEncoder.supportsMultipleValuesPerKey, "valuesIterator 
requires a encoder " +
       "that supports multiple values for a single key.")
 
-      val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), colFamilyName)
-      valueEncoder.decodeValues(encodedValues)
+      if (storeConf.rowChecksumEnabled) {
+        // multiGet provides better perf for row checksum, since it avoids 
copying values
+        val encodedValuesIterator = 
rocksDB.multiGet(keyEncoder.encodeKey(key), colFamilyName)
+        valueEncoder.decodeValues(encodedValuesIterator)
+      } else {
+        val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), 
colFamilyName)
+        valueEncoder.decodeValues(encodedValues)
+      }
     }
 
     override def merge(key: UnsafeRow, value: UnsafeRow,
@@ -943,6 +949,7 @@ private[sql] class RocksDBStateStoreProvider
     val statePath = stateStoreId.storeCheckpointLocation()
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     new RocksDBStateStoreChangeDataReader(
+      stateStoreId,
       CheckpointFileManager.create(statePath, hadoopConf),
       rocksDB,
       statePath,
@@ -951,6 +958,7 @@ private[sql] class RocksDBStateStoreProvider
       endVersionStateStoreCkptId,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
       keyValueEncoderMap,
+      storeConf,
       colFamilyNameOpt)
   }
 
@@ -1289,6 +1297,7 @@ object RocksDBStateStoreProvider {
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[RocksDBStateStoreProvider]] */
 class RocksDBStateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     rocksDB: RocksDB,
     stateLocation: Path,
@@ -1298,9 +1307,11 @@ class RocksDBStateStoreChangeDataReader(
     compressionCodec: CompressionCodec,
     keyValueEncoderMap:
       ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder, Short)],
+    storeConf: StateStoreConf,
     colFamilyNameOpt: Option[String] = None)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec, 
colFamilyNameOpt) {
+    storeId, fm, stateLocation, startVersion, endVersion, compressionCodec,
+    storeConf, colFamilyNameOpt) {
 
   override protected val versionsAndUniqueIds: Array[(Long, Option[String])] =
     if (endVersionStateStoreCkptId.isDefined) {
@@ -1336,15 +1347,30 @@ class RocksDBStateStoreChangeDataReader(
         }
 
         val nextRecord = reader.next()
+        val keyBytes = if (storeConf.rowChecksumEnabled
+          && nextRecord._1 == RecordType.DELETE_RECORD) {
+          // remove checksum and decode to the original key
+          KeyValueChecksumEncoder
+            .decodeAndVerifyKeyRowWithChecksum(readVerifier, nextRecord._2)
+        } else {
+          nextRecord._2
+        }
         val colFamilyIdBytes: Array[Byte] =
           RocksDBStateStoreProvider.getColumnFamilyIdAsBytes(currEncoder._3)
         val endIndex = colFamilyIdBytes.size
         // Function checks for byte arrays being equal
         // from index 0 to endIndex - 1 (both inclusive)
-        if (java.util.Arrays.equals(nextRecord._2, 0, endIndex,
+        if (java.util.Arrays.equals(keyBytes, 0, endIndex,
           colFamilyIdBytes, 0, endIndex)) {
-          val extractedKey = 
RocksDBStateStoreProvider.decodeStateRowWithPrefix(nextRecord._2)
-          val result = (nextRecord._1, extractedKey, nextRecord._3)
+          val valueBytes = if (storeConf.rowChecksumEnabled &&
+            nextRecord._1 != RecordType.DELETE_RECORD) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, keyBytes, nextRecord._3)
+          } else {
+            nextRecord._3
+          }
+          val extractedKey = 
RocksDBStateStoreProvider.decodeStateRowWithPrefix(keyBytes)
+          val result = (nextRecord._1, extractedKey, valueBytes)
           currRecord = result
         }
       }
@@ -1353,7 +1379,21 @@ class RocksDBStateStoreChangeDataReader(
       if (reader == null) {
         return null
       }
-      currRecord = reader.next()
+      val nextRecord = reader.next()
+      currRecord = if (storeConf.rowChecksumEnabled) {
+        nextRecord._1 match {
+          case RecordType.DELETE_RECORD =>
+            val key = KeyValueChecksumEncoder
+              .decodeAndVerifyKeyRowWithChecksum(readVerifier, nextRecord._2)
+            (nextRecord._1, key, nextRecord._3)
+          case _ =>
+            val value = 
KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, nextRecord._2, nextRecord._3)
+            (nextRecord._1, nextRecord._2, value)
+        }
+      } else {
+        nextRecord
+      }
     }
 
     val keyRow = currEncoder._1.decodeKey(currRecord._2)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 1891dd3befa4..25727b73c3d3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -1636,4 +1636,3 @@ object StateStore extends Logging {
     }
   }
 }
-
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index bd56f72489f9..2029c0988756 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -619,6 +619,7 @@ class StateStoreChangelogReaderV4(
  * store. In each iteration, it will return a tuple of (changeType: 
[[RecordType]],
  * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]])
  *
+ * @param storeId id of the state store
  * @param fm checkpoint file manager used to manage streaming query checkpoint
  * @param stateLocation location of the state store
  * @param startVersion start version of the changelog file to read
@@ -627,17 +628,24 @@ class StateStoreChangelogReaderV4(
  * @param colFamilyNameOpt optional column family name to read from
  */
 abstract class StateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
     compressionCodec: CompressionCodec,
+    storeConf: StateStoreConf,
     colFamilyNameOpt: Option[String] = None)
   extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with 
Logging {
 
   assert(startVersion >= 1)
   assert(endVersion >= startVersion)
 
+  protected val readVerifier: Option[KeyValueIntegrityVerifier] = 
KeyValueIntegrityVerifier.create(
+    storeId.toString,
+    storeConf.rowChecksumEnabled,
+    storeConf.rowChecksumReadVerificationRatio)
+
   /**
    * Iterator that iterates over the changelog files in the state store.
    */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index a765f52a2272..ebb212512ccb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -134,6 +134,12 @@ class StateStoreConf(
   /** Whether to unload the store on task completion. */
   val unloadOnCommit = sqlConf.stateStoreUnloadOnCommit
 
+  /** Whether to enable checksum for state store rows. */
+  val rowChecksumEnabled: Boolean = sqlConf.stateStoreRowChecksumEnabled
+
+  /** Controls how often row checksums are verified when rows are read from the state store. */
+  val rowChecksumReadVerificationRatio: Long = sqlConf.stateStoreRowChecksumReadVerificationRatio
+
   /** The version of the state store checkpoint format. */
   val stateStoreCheckpointFormatVersion: Int = 
sqlConf.stateStoreCheckpointFormatVersion
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 23bb54d86348..6d211fb6fc0a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -104,6 +104,13 @@ object StateStoreErrors {
     new StateStoreInvalidStamp(providedStamp, currentStamp)
   }
 
+  def rowChecksumVerificationFailed(
+      storeId: String,
+      expectedChecksum: Int,
+      computedChecksum: Int): StateStoreRowChecksumVerificationFailed =
+    // Builds (does not throw) the error; callers decide when to raise it.
+    new StateStoreRowChecksumVerificationFailed(storeId, expectedChecksum, computedChecksum)
+
   def incorrectNumOrderingColsForRangeScan(numOrderingCols: String):
     StateStoreIncorrectNumOrderingColsForRangeScan = {
     new StateStoreIncorrectNumOrderingColsForRangeScan(numOrderingCols)
@@ -586,6 +593,18 @@ class StateStoreCommitValidationFailed(
     )
   )
 
+/** Error for a state store row whose expected checksum does not match the computed one. */
+class StateStoreRowChecksumVerificationFailed(
+    storeId: String,
+    expectedChecksum: Int,
+    computedChecksum: Int)
+  extends SparkException(
+    errorClass = "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED",
+    messageParameters = Map("stateStoreId" -> storeId,
+      "expectedChecksum" -> s"$expectedChecksum",
+      "computedChecksum" -> s"$computedChecksum"),
+    cause = null)
+
 class StateStoreUnexpectedEmptyFileInRocksDBZip(fileName: String, zipFileName: 
String)
   extends SparkException(
     errorClass = "STATE_STORE_UNEXPECTED_EMPTY_FILE_IN_ROCKSDB_ZIP",
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
new file mode 100644
index 000000000000..3268beff1167
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Wrapper around a state store [[UnsafeRow]]. [[UnsafeRow]] is immutable, so implementations
+ * can attach additional information to a row (e.g. a checksum) without having to copy or
+ * modify it.
+ */
+trait UnsafeRowWrapper {
+  /** Returns the wrapped row. */
+  def unsafeRow(): UnsafeRow
+}
+
+/**
+ * Plain [[UnsafeRowWrapper]] that carries only the state store [[UnsafeRow]].
+ *
+ * @param row The wrapped row, returned as-is by [[unsafeRow]].
+ */
+class StateStoreRow(row: UnsafeRow) extends UnsafeRowWrapper {
+  override def unsafeRow(): UnsafeRow = {
+    row
+  }
+}
+
+/** Factory for [[StateStoreRow]]. */
+object StateStoreRow {
+  def apply(row: UnsafeRow): StateStoreRow = {
+    new StateStoreRow(row)
+  }
+}
+
+/**
+ * A half-open range `[fromIndex, untilIndex)` of indices into an array. Useful when we want to
+ * operate on a subset of an array without copying it.
+ *
+ * The range is validated on construction to lie within `array`. A range built from a corrupt
+ * length (e.g. a bad length prefix in an encoded row) therefore fails fast instead of later
+ * being used to read past the end of the array.
+ *
+ * @param array The underlying array.
+ * @param fromIndex The starting index (inclusive).
+ * @param untilIndex The end index (exclusive).
+ */
+case class ArrayIndexRange[T](array: Array[T], fromIndex: Int, untilIndex: Int) {
+  // When fromIndex == untilIndex, it is an empty range
+  assert(fromIndex >= 0 && fromIndex <= untilIndex,
+    s"Invalid range: fromIndex ($fromIndex) should be >= 0 and <= untilIndex ($untilIndex)")
+  // The range must not extend past the end of the underlying array
+  assert(untilIndex <= array.length,
+    s"Invalid range: untilIndex ($untilIndex) should be <= array length (${array.length})")
+
+  /** The number of elements in the range. */
+  def length: Int = untilIndex - fromIndex
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
new file mode 100644
index 000000000000..f02a598600c3
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.zip.CRC32C
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{CHECKSUM, STATE_STORE_ID}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+
+/**
+ * A state store row that also carries the checksum supplied for it. The row itself is exposed
+ * through [[UnsafeRowWrapper.unsafeRow]], the checksum through [[checksum]].
+ *
+ * @param row The UnsafeRow to be wrapped.
+ * @param rowChecksum The checksum of the row.
+ */
+class StateStoreRowWithChecksum(row: UnsafeRow, rowChecksum: Int) extends StateStoreRow(row) {
+  /** The checksum that was supplied together with the row. */
+  def checksum: Int = {
+    rowChecksum
+  }
+}
+
+/** Factory for [[StateStoreRowWithChecksum]]. */
+object StateStoreRowWithChecksum {
+  def apply(row: UnsafeRow, rowChecksum: Int): StateStoreRowWithChecksum =
+    new StateStoreRowWithChecksum(row, rowChecksum)
+}
+
+/**
+ * Integrity verifier for a state store key and value pair, used to check that the key and
+ * value in the state store aren't corrupt. [[KeyValueChecksumVerifier]] is the checksum-based
+ * implementation.
+ */
+trait KeyValueIntegrityVerifier {
+  /**
+   * Verifies a key row together with a value wrapper. The checksum implementation expects the
+   * wrapper to be a [[StateStoreRowWithChecksum]] carrying the expected checksum.
+   */
+  def verify(key: UnsafeRow, value: UnsafeRowWrapper): Unit
+
+  /** Verifies key bytes and optional value bytes against `expectedChecksum`. */
+  def verify(keyBytes: Array[Byte], valueBytes: Option[Array[Byte]], expectedChecksum: Int): Unit
+
+  /**
+   * Same as the byte-array variant, but over index ranges of the underlying arrays, so callers
+   * do not have to copy the key/value bytes out first.
+   */
+  def verify(
+      keyBytes: ArrayIndexRange[Byte],
+      valueBytes: Option[ArrayIndexRange[Byte]],
+      expectedChecksum: Int): Unit
+}
+
+/** Factory for [[KeyValueIntegrityVerifier]] instances. */
+object KeyValueIntegrityVerifier {
+  /**
+   * Returns a checksum-based verifier when row checksum is enabled and the verification ratio
+   * is non-zero. Otherwise returns None, meaning no verification should be performed.
+   *
+   * @param storeId Store identifier, used by the verifier for logging and errors.
+   * @param rowChecksumEnabled Whether row checksum is enabled for the store.
+   * @param verificationRatio Verify once every `verificationRatio` calls; 0 disables it.
+   */
+  def create(
+      storeId: String,
+      rowChecksumEnabled: Boolean,
+      verificationRatio: Long): Option[KeyValueIntegrityVerifier] = {
+    assert(verificationRatio >= 0, "Verification ratio must be non-negative")
+    val verificationDisabled = !rowChecksumEnabled || verificationRatio == 0
+    if (verificationDisabled) None else Some(KeyValueChecksumVerifier(storeId, verificationRatio))
+  }
+}
+
+/**
+ * Checksum based key value verifier. It recomputes the checksum of the key and value bytes,
+ * compares it with the expected checksum, and throws
+ * [[StateStoreRowChecksumVerificationFailed]] on a mismatch.
+ *
+ * Verification can be rate limited via the [[verificationRatio]] parameter. Only every N-th
+ * verify call (N = verificationRatio) actually computes and compares a checksum; the other
+ * calls just return. This reduces the cost when verify is called for many key-value pairs.
+ *
+ * NOTE: Not thread safe and expected to be accessed one thread at a time (the counter updates
+ * are plain read-modify-write operations).
+ *
+ * @param storeId The id of the state store, used for logging and in the raised error.
+ * @param verificationRatio Verify once for every `verificationRatio` verify calls; must be > 0.
+ */
+class KeyValueChecksumVerifier(
+    storeId: String,
+    verificationRatio: Long) extends KeyValueIntegrityVerifier with Logging {
+  assert(verificationRatio > 0, "Verification ratio must be greater than 0")
+
+  // Total number of verify calls received
+  @volatile private var numRequests = 0L
+  // Number of verify calls that actually compared checksums
+  @volatile private var numVerified = 0L
+  def getNumRequests: Long = numRequests
+  def getNumVerified: Long = numVerified
+
+  /** Records one verify request and reports whether this request must be verified. */
+  private def shouldVerify(): Boolean = {
+    numRequests += 1
+    val selected = numRequests % verificationRatio == 0
+    if (selected) {
+      numVerified += 1
+    }
+    selected
+  }
+
+  /**
+   * Compares the checksum of the key and value rows with the checksum carried by `value`,
+   * which must be a [[StateStoreRowWithChecksum]]. The comparison may be skipped, depending on
+   * the [[verificationRatio]].
+   *
+   * @param key The key row.
+   * @param value The value row, wrapped together with its expected checksum.
+   */
+  override def verify(key: UnsafeRow, value: UnsafeRowWrapper): Unit = {
+    assert(value.isInstanceOf[StateStoreRowWithChecksum],
+      s"Expected StateStoreRowWithChecksum, but got ${value.getClass.getName}")
+
+    if (shouldVerify()) {
+      val rowWithChecksum = value.asInstanceOf[StateStoreRowWithChecksum]
+      verifyChecksum(
+        expected = rowWithChecksum.checksum,
+        computed = KeyValueChecksum.create(key, Some(rowWithChecksum.unsafeRow())))
+    }
+  }
+
+  /**
+   * Compares the checksum of the key bytes and the value bytes (if present) with
+   * `expectedChecksum`. The comparison may be skipped, depending on the [[verificationRatio]].
+   *
+   * @param keyBytes The key bytes.
+   * @param valueBytes Optional value bytes.
+   * @param expectedChecksum The expected checksum value.
+   */
+  override def verify(
+      keyBytes: Array[Byte],
+      valueBytes: Option[Array[Byte]],
+      expectedChecksum: Int): Unit = {
+    if (shouldVerify()) {
+      verifyChecksum(expectedChecksum, KeyValueChecksum.create(keyBytes, valueBytes))
+    }
+  }
+
+  /**
+   * Index-range variant of the byte-array `verify`. It avoids copying the key/value bytes out
+   * of their underlying arrays. The comparison may be skipped, depending on the
+   * [[verificationRatio]].
+   *
+   * @param keyBytes Index range of the key bytes in the underlying array.
+   * @param valueBytes Optional index range of the value bytes in the underlying array.
+   * @param expectedChecksum The expected checksum value.
+   */
+  override def verify(
+      keyBytes: ArrayIndexRange[Byte],
+      valueBytes: Option[ArrayIndexRange[Byte]],
+      expectedChecksum: Int): Unit = {
+    if (shouldVerify()) {
+      verifyChecksum(expectedChecksum, KeyValueChecksum.create(keyBytes, valueBytes))
+    }
+  }
+
+  /** Logs and throws [[StateStoreRowChecksumVerificationFailed]] when the checksums differ. */
+  private def verifyChecksum(expected: Int, computed: Int): Unit = {
+    logDebug(s"Verifying row checksum, expected: $expected, computed: $computed")
+    if (expected != computed) {
+      // NOTE(review): both checksums use the CHECKSUM log key in the same entry, so the
+      // structured log context may only retain one of them; consider distinct log keys.
+      logError(log"Row checksum verification failed for store ${MDC(STATE_STORE_ID, storeId)}, " +
+        log"Expected checksum: ${MDC(CHECKSUM, expected)}, " +
+        log"Computed checksum: ${MDC(CHECKSUM, computed)}")
+      throw StateStoreErrors.rowChecksumVerificationFailed(storeId, expected, computed)
+    }
+  }
+}
+
+/** Factory for [[KeyValueChecksumVerifier]]. */
+object KeyValueChecksumVerifier {
+  def apply(storeId: String, verificationRatio: Long): KeyValueChecksumVerifier =
+    new KeyValueChecksumVerifier(storeId, verificationRatio)
+}
+
+/** Creates the checksum used for state store key-value pairs (currently CRC32C). */
+object KeyValueChecksum {
+  /**
+   * Creates a checksum from the bytes of the key row and, when given, the value row.
+   * Without a value row, only the key row bytes are used.
+   */
+  def create(keyRow: UnsafeRow, valueRow: Option[UnsafeRow]): Int = {
+    val keyBytes = keyRow.getBytes
+    val valueBytes = valueRow.map(row => row.getBytes)
+    create(keyBytes, valueBytes)
+  }
+
+  /**
+   * Creates a checksum from the whole key byte array and, when given, the whole value byte
+   * array. Without value bytes, only the key bytes are used.
+   */
+  def create(keyBytes: Array[Byte], valueBytes: Option[Array[Byte]]): Int = {
+    def wholeArray(bytes: Array[Byte]): ArrayIndexRange[Byte] =
+      ArrayIndexRange(bytes, 0, bytes.length)
+    create(wholeArray(keyBytes), valueBytes.map(wholeArray))
+  }
+
+  /**
+   * Creates a checksum over the key bytes range followed by the value bytes range, if given.
+   * Without a value range, only the key bytes range is used.
+   *
+   * @param keyBytes Index range of the key bytes in the underlying array.
+   * @param valueBytes Optional index range of the value bytes in the underlying array.
+   */
+  def create(keyBytes: ArrayIndexRange[Byte], valueBytes: Option[ArrayIndexRange[Byte]]): Int = {
+    // CRC32C for now; the algorithm could later be made configurable
+    val crc = new CRC32C()
+    (Iterator.single(keyBytes) ++ valueBytes.iterator).foreach { range =>
+      crc.update(range.array, range.fromIndex, range.length)
+    }
+    crc.getValue.toInt
+  }
+}
+
+/**
+ * Used to encode and decode a checksum value with/from the row bytes.
+ *
+ * Decoding validates the encoded layout (headers and length prefixes) before reading through
+ * [[Platform]]. A truncated or corrupt encoding therefore fails with an
+ * [[IllegalStateException]] instead of reading (or writing) outside the byte arrays.
+ */
+object KeyValueChecksumEncoder {
+  /** Size of the header written by [[encodeValueRowWithChecksum]]: checksum + row length. */
+  private val ValueHeaderSize = java.lang.Integer.BYTES * 2
+
+  /**
+   * Encodes the value row bytes with a checksum value. This encodes the bytes in a way that
+   * supports additional values being merged into it later. If the value would only ever have a
+   * single value (no merge), then you should use [[encodeSingleValueRowWithChecksum]] instead.
+   *
+   * It is encoded as: checksum (4 bytes) + rowBytes.length (4 bytes) + rowBytes
+   *
+   * @param rowBytes Value row bytes.
+   * @param checksum Checksum value to encode with the value row bytes.
+   * @return The encoded value row bytes that includes the checksum.
+   */
+  def encodeValueRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    val result = new Array[Byte](ValueHeaderSize + rowBytes.length)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET, checksum)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES, rowBytes.length)
+
+    // Write the actual data
+    Platform.copyMemory(
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      result, Platform.BYTE_ARRAY_OFFSET + ValueHeaderSize,
+      rowBytes.length
+    )
+    result
+  }
+
+  /**
+   * Decodes and verifies value row bytes encoded with checksum via
+   * [[encodeValueRowWithChecksum]], returning the original value row bytes. Supports both a
+   * single value and multiple merged values. Each individual value is copied without its
+   * encoded header, and merged values stay separated by their one-byte delimiter. Because it
+   * copies, it is more expensive than [[decodeAndVerifyMultiValueRowWithChecksum]], which
+   * returns index ranges instead (preferred for multi-value bytes). This method is mainly used
+   * to support calling store.get for a key that has merged values.
+   *
+   * @param verifier Used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode.
+   * @return The original value row bytes, without the checksum. Empty input yields an empty
+   *         array.
+   * @throws IllegalStateException if the encoded layout is truncated or inconsistent.
+   */
+  def decodeAndVerifyValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Array[Byte] = {
+    // First pass: validate the layout and compute the total size of the original values.
+    // This also supports decoding merged values (via merge) e.g. val1,val2,val3
+    val valuesEnd = Platform.BYTE_ARRAY_OFFSET + valueBytes.length
+    var currentPosition = Platform.BYTE_ARRAY_OFFSET
+    var resultSize = 0
+    var numValues = 0
+    while (currentPosition < valuesEnd) {
+      checkReadable(valueBytes, currentPosition, ValueHeaderSize)
+      // the row size follows the checksum (first 4 bytes)
+      val valueRowSize = Platform.getInt(valueBytes, currentPosition + java.lang.Integer.BYTES)
+      checkReadable(valueBytes, currentPosition + ValueHeaderSize, valueRowSize)
+      // move to the next value and skip the delimiter character used for rocksdb merge
+      currentPosition += ValueHeaderSize + valueRowSize + 1
+      resultSize += valueRowSize
+      numValues += 1
+    }
+
+    // include the delimiters between merged values (none for zero or one value)
+    resultSize += math.max(numValues - 1, 0)
+
+    // Second pass: verify each value and copy it (plus its trailing delimiter) to the result
+    val result = new Array[Byte](resultSize)
+    var resultPosition = Platform.BYTE_ARRAY_OFFSET
+    val keyRowIndex = ArrayIndexRange(keyBytes, 0, keyBytes.length)
+    currentPosition = Platform.BYTE_ARRAY_OFFSET // reset to beginning of values
+    var currentValueCount = 0 // count of values iterated
+
+    // Iterate exactly the values counted (and validated) above, so every copy fits `result`
+    while (currentValueCount < numValues) {
+      currentValueCount += 1
+      val checksum = Platform.getInt(valueBytes, currentPosition)
+      currentPosition += java.lang.Integer.BYTES
+      val valueRowSize = Platform.getInt(valueBytes, currentPosition)
+      currentPosition += java.lang.Integer.BYTES
+
+      // verify the current value using the index range to avoid copying
+      // convert position to array index
+      val from = byteOffsetToIndex(currentPosition)
+      val until = from + valueRowSize
+      val valueRowIndex = ArrayIndexRange(valueBytes, from, until)
+      verifier.foreach(_.verify(keyRowIndex, Some(valueRowIndex), checksum))
+
+      // No delimiter is needed if single value or the last value in multi-value
+      val copyLength = if (currentValueCount < numValues) {
+        valueRowSize + 1 // copy the delimiter
+      } else {
+        valueRowSize
+      }
+      Platform.copyMemory(
+        valueBytes, currentPosition,
+        result, resultPosition,
+        copyLength
+      )
+
+      // move to the next value
+      currentPosition += copyLength
+      resultPosition += copyLength
+    }
+
+    result
+  }
+
+  /**
+   * Decodes and verifies value row bytes that might contain one or more (merged) values, each
+   * encoded with checksum via [[encodeValueRowWithChecksum]]. Returns an iterator of index
+   * ranges of the original individual value row bytes, verifying each value's checksum as it
+   * is returned. Cheaper than [[decodeAndVerifyValueRowWithChecksum]] since it does not copy
+   * the value bytes.
+   *
+   * @param verifier Used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode; null yields an empty iterator.
+   * @return Iterator of index range representing the original value row bytes, without checksum.
+   */
+  def decodeAndVerifyMultiValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Iterator[ArrayIndexRange[Byte]] = {
+    if (valueBytes == null) {
+      Iterator.empty
+    } else {
+      new Iterator[ArrayIndexRange[Byte]] {
+        private val keyRowIndex = ArrayIndexRange(keyBytes, 0, keyBytes.length)
+        private var position: Int = Platform.BYTE_ARRAY_OFFSET
+        private val valuesEnd = Platform.BYTE_ARRAY_OFFSET + valueBytes.length
+
+        override def hasNext: Boolean = position < valuesEnd
+
+        override def next(): ArrayIndexRange[Byte] = {
+          // Honour the Iterator contract instead of reading past the end of valueBytes
+          if (!hasNext) {
+            throw new NoSuchElementException("No more values to decode")
+          }
+          val (valueRowIndex, checksum) =
+            getValueRowIndexAndChecksum(valueBytes, startingPosition = position)
+          verifier.foreach(_.verify(keyRowIndex, Some(valueRowIndex), checksum))
+          // move to the next value and skip the delimiter character used for rocksdb merge
+          position = byteIndexToOffset(valueRowIndex.untilIndex) + 1
+          valueRowIndex
+        }
+      }
+    }
+  }
+
+  /**
+   * Decodes value row bytes encoded with checksum via [[encodeValueRowWithChecksum]] that
+   * currently hold only one value. Returns the index range of the original value row bytes and
+   * the encoded checksum value. Cheaper since it does not copy the value bytes.
+   *
+   * @param bytes The value bytes to decode.
+   * @return A tuple containing the index range of the original value row bytes and the checksum.
+   */
+  def decodeOneValueRowIndexWithChecksum(bytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    getValueRowIndexAndChecksum(bytes, startingPosition = Platform.BYTE_ARRAY_OFFSET)
+  }
+
+  /**
+   * Gets the original value row index range and checksum for a row encoded via
+   * [[encodeValueRowWithChecksum]] whose header starts at byte array address offset
+   * `startingPosition`.
+   */
+  private def getValueRowIndexAndChecksum(
+      bytes: Array[Byte],
+      startingPosition: Int): (ArrayIndexRange[Byte], Int) = {
+    checkReadable(bytes, startingPosition, ValueHeaderSize)
+    var position = startingPosition
+    val checksum = Platform.getInt(bytes, position)
+    position += java.lang.Integer.BYTES
+    val rowSize = Platform.getInt(bytes, position)
+    position += java.lang.Integer.BYTES
+    checkReadable(bytes, position, rowSize)
+
+    // convert position to array index
+    val fromIndex = byteOffsetToIndex(position)
+    (ArrayIndexRange(bytes, fromIndex, fromIndex + rowSize), checksum)
+  }
+
+  /**
+   * Encodes the key row bytes with a checksum value.
+   * It is encoded as: checksum (4 bytes) + rowBytes
+   *
+   * @param rowBytes Key row bytes.
+   * @param checksum Checksum value to encode with the key row bytes.
+   * @return The encoded key row bytes that includes the checksum.
+   */
+  def encodeKeyRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    encodeSingleValueRowWithChecksum(rowBytes, checksum)
+  }
+
+  /**
+   * Decodes the key row encoded with [[encodeKeyRowWithChecksum]] and
+   * returns the original key bytes and the checksum value.
+   *
+   * @param bytes The encoded key bytes with checksum.
+   * @return Tuple of the original key bytes and the checksum value.
+   */
+  def decodeKeyRowWithChecksum(bytes: Array[Byte]): (Array[Byte], Int) = {
+    decodeSingleValueRowWithChecksum(bytes)
+  }
+
+  /**
+   * Decodes and verifies the key row encoded with [[encodeKeyRowWithChecksum]] and
+   * returns the original key bytes.
+   *
+   * @param verifier Used for checksum verification.
+   * @param bytes The encoded key bytes with checksum.
+   * @return The original key bytes.
+   */
+  def decodeAndVerifyKeyRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      bytes: Array[Byte]): Array[Byte] = {
+    val (originalBytes, checksum) = decodeKeyRowWithChecksum(bytes)
+    verifier.foreach(_.verify(originalBytes, None, checksum))
+    originalBytes
+  }
+
+  /**
+   * Decodes a key row encoded with [[encodeKeyRowWithChecksum]] and returns the index range of
+   * the original key bytes and the checksum value. This is cheaper than
+   * [[decodeKeyRowWithChecksum]], since it doesn't copy the key bytes.
+   *
+   * @param keyBytes The encoded key bytes with checksum.
+   * @return A tuple containing the index range of the original key row bytes and the checksum.
+   */
+  def decodeKeyRowIndexWithChecksum(keyBytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    decodeSingleValueRowIndexWithChecksum(keyBytes)
+  }
+
+  /**
+   * Encodes value row bytes that will only ever have a single value (no multi-value)
+   * with a checksum value.
+   * It is encoded as: checksum (4 bytes) + rowBytes.
+   * Since it will only ever have a single value, there is no need to encode the rowBytes length.
+   *
+   * @param rowBytes Value row bytes.
+   * @param checksum Checksum value to encode with the value row bytes.
+   * @return The encoded value row bytes that includes the checksum.
+   */
+  def encodeSingleValueRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    val result = new Array[Byte](java.lang.Integer.BYTES + rowBytes.length)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET, checksum)
+
+    // Write the actual data
+    Platform.copyMemory(
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      result, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES,
+      rowBytes.length
+    )
+    result
+  }
+
+  /**
+   * Decodes a single value row encoded with [[encodeSingleValueRowWithChecksum]] and
+   * returns the original value bytes and the checksum value.
+   *
+   * @param bytes The encoded value bytes with checksum.
+   * @return Tuple of the original value bytes and the checksum value.
+   * @throws IllegalStateException if `bytes` is too short to hold the checksum.
+   */
+  def decodeSingleValueRowWithChecksum(bytes: Array[Byte]): (Array[Byte], Int) = {
+    checkReadable(bytes, Platform.BYTE_ARRAY_OFFSET, java.lang.Integer.BYTES)
+    val checksum = Platform.getInt(bytes, Platform.BYTE_ARRAY_OFFSET)
+    val row = new Array[Byte](bytes.length - java.lang.Integer.BYTES)
+    Platform.copyMemory(
+      bytes, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES,
+      row, Platform.BYTE_ARRAY_OFFSET,
+      row.length
+    )
+
+    (row, checksum)
+  }
+
+  /**
+   * Decodes and verifies a single value row encoded with [[encodeSingleValueRowWithChecksum]]
+   * and returns the original value bytes.
+   *
+   * @param verifier Used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode.
+   * @return The original value row bytes, without the checksum.
+   */
+  def decodeAndVerifySingleValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Array[Byte] = {
+    val (originalValueBytes, checksum) = decodeSingleValueRowWithChecksum(valueBytes)
+    verifier.foreach(_.verify(keyBytes, Some(originalValueBytes), checksum))
+    originalValueBytes
+  }
+
+  /**
+   * Decodes a single value row encoded with [[encodeSingleValueRowWithChecksum]] and returns
+   * the index range of the original value bytes and the checksum value. This is cheaper than
+   * [[decodeSingleValueRowWithChecksum]], since it doesn't copy the value bytes.
+   *
+   * @param bytes The encoded value bytes with checksum.
+   * @return A tuple containing the index range of the original value row bytes and the checksum.
+   * @throws IllegalStateException if `bytes` is too short to hold the checksum.
+   */
+  def decodeSingleValueRowIndexWithChecksum(bytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    var position = Platform.BYTE_ARRAY_OFFSET
+    checkReadable(bytes, position, java.lang.Integer.BYTES)
+    val checksum = Platform.getInt(bytes, position)
+    position += java.lang.Integer.BYTES
+
+    // convert position to array index
+    val fromIndex = byteOffsetToIndex(position)
+    (ArrayIndexRange(bytes, fromIndex, bytes.length), checksum)
+  }
+
+  /**
+   * Ensures that `numBytes` bytes starting at byte array address offset `position` lie within
+   * `bytes`. It is called before encoded bytes are read via [[Platform]], so that a truncated
+   * or corrupt encoding (e.g. a bad length prefix) is reported instead of being used to access
+   * memory outside the array.
+   */
+  private def checkReadable(bytes: Array[Byte], position: Int, numBytes: Int): Unit = {
+    val index = byteOffsetToIndex(position)
+    if (numBytes < 0 || index < 0 || index > bytes.length - numBytes) {
+      throw new IllegalStateException(
+        s"Malformed checksum encoded row: cannot read $numBytes bytes at index $index " +
+          s"of a ${bytes.length} byte array")
+    }
+  }
+
+  /** Convert byte array address offset to array index */
+  private def byteOffsetToIndex(offset: Int): Int = {
+    offset - Platform.BYTE_ARRAY_OFFSET
+  }
+
+  /** Convert byte array index to array address offset */
+  private def byteIndexToOffset(index: Int): Int = {
+    index + Platform.BYTE_ARRAY_OFFSET
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index eaf32c40e8cd..d7f28b79acff 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, RocksDBFileManager, RocksDBStateStoreProvider, 
TestClass}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, EnableStateStoreRowChecksum, RocksDBFileManager, 
RocksDBStateStoreProvider, TestClass}
 import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, 
MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, 
OutputMode, RunningCountStatefulProcessor, 
RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, 
StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, 
TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
@@ -1181,3 +1181,10 @@ class StateDataSourceTransformWithStateSuiteCheckpointV2 
extends
     spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
   }
 }
+
+/**
+ * Test suite that runs all StateDataSourceTransformWithStateSuite tests with row checksum
+ * enabled, via the [[EnableStateStoreRowChecksum]] mixin. It is a separate suite rather than an
+ * extra test dimension of the base suite, so the base suite's runtime is unaffected.
+ */
+@SlowSQLTest
+class StateDataSourceTransformWithStateSuiteWithRowChecksum
+  extends StateDataSourceTransformWithStateSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
index 9c4a7b1879f6..e4312fd16d1f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
@@ -197,4 +197,26 @@ class OffsetSeqLogSuite extends SharedSparkSession {
         "unsaferow")
     }
   }
+
+  test("Row checksum disabled by default") {
+    val rowChecksumConf = SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key
+    val metadata = OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, spark.conf)
+    // Newly written offset log metadata records the row checksum setting, false by default
+    assert(metadata.conf.get(rowChecksumConf) === Some(false.toString))
+  }
+
+  test("Row checksum disabled for existing checkpoint even if conf is enabled") {
+    val rowChecksumConf = SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key
+    withSQLConf(rowChecksumConf -> true.toString) {
+      val (_, offsetSeq) = readFromResource("offset-log-version-2.1.0")
+      val checkpointMetadata = offsetSeq.metadata.get
+      // The existing checkpoint predates row checksum, so the conf is not recorded in it
+      assert(checkpointMetadata.conf.get(rowChecksumConf) === None)
+
+      // Restoring session conf from that checkpoint must keep row checksum disabled,
+      // even though it is enabled in the current session
+      val restoredConf = spark.sessionState.conf.clone()
+      OffsetSeqMetadata.setSessionConf(checkpointMetadata, restoredConf)
+      assert(!restoredConf.stateStoreRowChecksumEnabled)
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index 40caade2acb7..ee5374c02435 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -425,3 +425,8 @@ class ListStateSuite extends StateVariableSuiteBase {
     }
   }
 }
+
+/**
+ * Test suite that runs all ListStateSuite tests with row checksum enabled, via the
+ * [[EnableStateStoreRowChecksum]] mixin.
+ */
+class ListStateSuiteWithRowChecksum extends ListStateSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index 00855ba15f8d..dbbd0ce8388a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -290,3 +290,8 @@ class MapStateSuite extends StateVariableSuiteBase {
     }
   }
 }
+
+/**
+ * Test suite that runs all MapStateSuite tests with row checksum enabled, via the
+ * [[EnableStateStoreRowChecksum]] mixin.
+ */
+class MapStateSuiteWithRowChecksum extends MapStateSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index acb4795c5501..f7d7d0bc921f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -1318,3 +1318,10 @@ class RocksDBStateStoreCheckpointFormatV2Suite extends 
StreamTest
     }
   }
 }
+
+/**
+ * Test suite that runs all RocksDBStateStoreCheckpointFormatV2Suite tests
+ * with row checksum enabled, via the [[EnableStateStoreRowChecksum]] mixin.
+ */
+class RocksDBStateStoreCheckpointFormatV2SuiteWithRowChecksum
+  extends RocksDBStateStoreCheckpointFormatV2Suite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 0bf95ce92797..92a0dc322e82 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -449,3 +449,9 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 }
+
+/**
+ * Test suite that runs all RocksDBStateStoreIntegrationSuite tests with row checksum enabled,
+ * via the [[EnableStateStoreRowChecksum]] mixin.
+ */
+class RocksDBStateStoreIntegrationSuiteWithRowChecksum
+  extends RocksDBStateStoreIntegrationSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreLockHardeningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreLockHardeningSuite.scala
index 03e5bf692ef8..430789946c61 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreLockHardeningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreLockHardeningSuite.scala
@@ -803,3 +803,9 @@ class RocksDBStateStoreLockHardeningSuite extends 
SparkFunSuite
         s"after load but was $threadId")
   }
 }
+
+/**
+ * Variant of [[RocksDBStateStoreLockHardeningSuite]] whose inherited tests all run with
+ * state store row checksum turned on (see [[EnableStateStoreRowChecksum]]).
+ */
+class RocksDBStateStoreLockHardeningSuiteWithRowChecksum
+  extends RocksDBStateStoreLockHardeningSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 3472d6bc34a8..9b86f2364885 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -2754,3 +2754,10 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid
   }
 }
 
+/**
+ * Re-runs every [[RocksDBStateStoreSuite]] test with state store row checksum enabled.
+ */
+@ExtendedSQLTest
+class RocksDBStateStoreSuiteWithRowChecksum
+  extends RocksDBStateStoreSuite with EnableStateStoreRowChecksum
+
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index de16aa38fe5d..61e551b851c2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -4045,3 +4045,10 @@ object RocksDBSuite {
     }
   }
 }
+
+/**
+ * Re-runs every [[RocksDBSuite]] test with state store row checksum enabled, to check that
+ * row checksum works together with all RocksDB features covered by that suite.
+ */
+@SlowSQLTest
+class RocksDBSuiteWithRowChecksum
+  extends RocksDBSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala
index df4d19226b9d..5c1c9d64c009 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreInstanceMetricSuite.scala
@@ -476,3 +476,9 @@ class StateStoreInstanceMetricSuite extends StreamTest with 
AlsoTestWithRocksDBF
       }
   }
 }
+
+/**
+ * [[StateStoreInstanceMetricSuite]] re-run with state store row checksum enabled via the
+ * [[EnableStateStoreRowChecksum]] mix-in.
+ */
+class StateStoreInstanceMetricSuiteWithRowChecksum
+  extends StateStoreInstanceMetricSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
new file mode 100644
index 000000000000..b2597ceb406d
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
@@ -0,0 +1,445 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.rocksdb.{ReadOptions, RocksDB => NativeRocksDB, WriteOptions}
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import 
org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.runtime.StreamExecution
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+/**
+ * Main test for the state store row checksum feature, run against each provider by its
+ * subclasses [[HDFSStateStoreRowChecksumSuite]] and [[RocksDBStateStoreRowChecksumSuite]].
+ * Row checksum is also exercised by other suites (state store, RocksDB, TransformWithState,
+ * etc.) through their `*WithRowChecksum` variants, which mix in the
+ * [[EnableStateStoreRowChecksum]] trait.
+ */
+abstract class StateStoreRowChecksumSuite extends SharedSparkSession
+  with BeforeAndAfter {
+
+  import StateStoreTestsHelper._
+
+  before {
+    StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
+  }
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf.set(SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key, true.toString)
+      // Disable file checksum verification, since this suite injects corruption into files
+      .set(SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key, false.toString)
+  }
+
+  /** Runs `f` with `provider` and always closes the provider afterwards. */
+  protected def withStateStoreProvider[T <: StateStoreProvider](
+      provider: T)(f: T => Unit): Unit = {
+    try {
+      f(provider)
+    } finally {
+      provider.close()
+    }
+  }
+
+  /** Returns a new, not yet initialized, provider of the implementation under test. */
+  protected def createProvider: StateStoreProvider
+
+  /**
+   * Creates a provider with [[createProvider]] and initializes it for the given checkpoint
+   * directory, operator, partition and schemas, without column families. `runId` is stored in
+   * `hadoopConf` under `StreamExecution.RUN_ID_KEY` before initialization.
+   */
+  protected def createInitializedProvider(
+      dir: String = Utils.createTempDir().toString,
+      opId: Long = 0,
+      partition: Int = 0,
+      runId: UUID = UUID.randomUUID(),
+      keyStateEncoderSpec: KeyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(
+        keySchema),
+      keySchema: StructType = keySchema,
+      valueSchema: StructType = valueSchema,
+      sqlConf: SQLConf = SQLConf.get,
+      hadoopConf: Configuration = new Configuration): StateStoreProvider = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, runId.toString)
+    val provider = createProvider
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      keyStateEncoderSpec,
+      useColumnFamilies = false,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
+  }
+
+  /**
+   * Corrupts the locally stored value of `key` in `store`, without updating its checksum, so
+   * that verifying that row fails.
+   */
+  protected def corruptRow(
+      provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit
+
+  /**
+   * Corrupts the last row in the checkpoint file written for `version`: the snapshot file if
+   * `isSnapshot` is true, otherwise the delta/changelog file.
+   */
+  protected def corruptRowInFile(
+      provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit
+
+  test("Detect local corrupt row") {
+    withStateStoreProvider(createInitializedProvider()) { provider =>
+      val store = provider.getStore(0)
+      put(store, "1", 11, 100)
+      put(store, "2", 22, 200)
+
+      // corrupt a row
+      corruptRow(provider, store, dataToKeyRow("1", 11))
+
+      assert(get(store, "2", 22).contains(200),
+        "failed to get the correct value for the uncorrupted row")
+
+      checkChecksumError(intercept[SparkException] {
+        // should throw row checksum exception
+        get(store, "1", 11)
+      })
+
+      checkChecksumError(intercept[SparkException] {
+        // should throw row checksum exception
+        store.iterator().foreach(kv => assert(kv.key != null && kv.value != null))
+      })
+
+      store.abort()
+    }
+  }
+
+  /**
+   * Checkpoint file kinds that the corrupt-file test runs with: `true` for a snapshot file and
+   * `false` for a delta/changelog file. This is read while this class's constructor registers
+   * tests, so subclasses must implement it as a `def` (a `val` would not be initialized yet).
+   */
+  protected def corruptRowInFileTestMode: Seq[Boolean]
+
+  corruptRowInFileTestMode.foreach { isSnapshot =>
+    test(s"Detect corrupt row in checkpoint file - isSnapshot: $isSnapshot") {
+      // We want to generate snapshot per change file
+      withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "0") {
+        var version = 0L
+        val checkpointDir = Utils.createTempDir().toString
+
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+          val store = provider.getStore(version)
+          put(store, "1", 11, 100)
+          put(store, "2", 22, 200)
+          remove(store, k => k == ("1", 11))
+          put(store, "3", 33, 300)
+          version = store.commit() // writes change file
+
+          if (isSnapshot) {
+            provider.doMaintenance() // writes snapshot
+          }
+        }
+
+        // We should be able to successfully load the store from checkpoint
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+          val store = provider.getStore(version)
+          assert(get(store, "1", 11).isEmpty) // the removed row
+          assert(get(store, "2", 22).contains(200))
+          assert(get(store, "3", 33).contains(300))
+
+          // Now corrupt a row in the checkpoint file
+          corruptRowInFile(provider, version, isSnapshot)
+          store.abort()
+        }
+
+        // Reload the store from checkpoint, should detect the corrupt row
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+          val loadError = intercept[SparkException] {
+            // should throw row checksum exception
+            provider.getStore(version)
+          }
+          // The row checksum error is expected as the cause of the load failure. Match on it
+          // instead of casting, so an unexpected cause fails the test with a clear message.
+          loadError.getCause match {
+            case checksumError: SparkException => checkChecksumError(checksumError)
+            case other =>
+              fail(s"Expected the store load failure to be caused by a SparkException, " +
+                s"but the cause was: $other", loadError)
+          }
+        }
+      }
+    }
+  }
+
+  test("Read verification ratio") {
+    var version = 0L
+    val checkpointDir = Utils.createTempDir().toString
+    val numRows = 10
+
+    withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+      val store = provider.getStore(0)
+      // add 10 rows
+      (1 to numRows).foreach(v => put(store, v.toString, v, v * 100))
+
+      // read the rows
+      (1 to numRows).foreach(v => assert(get(store, v.toString, v).nonEmpty))
+      val verifier = checksumReadVerifier(provider, store)
+      // Default is to verify every row read
+      assertVerifierStats(verifier, expectedRequests = numRows, expectedVerifies = numRows)
+
+      store.iterator().foreach(kv => assert(kv.key != null && kv.value != null))
+      assertVerifierStats(
+        verifier, expectedRequests = numRows * 2, expectedVerifies = numRows * 2)
+
+      version = store.commit()
+    }
+
+    // Setting to 0 means we will not verify during store read requests
+    withSQLConf(SQLConf.STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO.key -> "0") {
+      withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+        val store = provider.getStore(version)
+        (1 to numRows).foreach(v => assert(get(store, v.toString, v).nonEmpty))
+
+        assert(getReadVerifier(provider, store).isEmpty, "Expected no verifier")
+        store.abort()
+      }
+    }
+
+    // Verify every 2 store read requests
+    withSQLConf(SQLConf.STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO.key -> "2") {
+      withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+        val store = provider.getStore(version)
+
+        (1 to numRows).foreach(v => assert(get(store, v.toString, v).nonEmpty))
+        val verifier = checksumReadVerifier(provider, store)
+        assertVerifierStats(verifier, expectedRequests = numRows, expectedVerifies = numRows / 2)
+
+        store.iterator().foreach(kv => assert(kv.key != null && kv.value != null))
+        assertVerifierStats(verifier, expectedRequests = numRows * 2, expectedVerifies = numRows)
+
+        store.abort()
+      }
+    }
+  }
+
+  /**
+   * Asserts that `checksumError` is a STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED error for
+   * the store of operator `opId` and partition `partId`, with integer checksum parameters.
+   */
+  protected def checkChecksumError(
+      checksumError: SparkException,
+      opId: Int = 0,
+      partId: Int = 0): Unit = {
+    checkError(
+      exception = checksumError,
+      condition = "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED",
+      parameters = Map(
+        "stateStoreId" -> (s".*StateStoreId[\\[\\(].*(operatorId|opId)=($opId)" +
+          s".*(partitionId|partId)=($partId).*[)\\]]"),
+        "expectedChecksum" -> "^-?\\d+$", // integer
+        "computedChecksum" -> "^-?\\d+$"), // integer
+      matchPVals = true)
+  }
+
+  /**
+   * Returns the verifier used to check row checksums on reads from `store`, or `None` when
+   * read verification is disabled (e.g. a verification ratio of 0).
+   */
+  protected def getReadVerifier(
+      provider: StateStoreProvider, store: StateStore): Option[KeyValueIntegrityVerifier]
+
+  /** Returns the checksum read verifier of `store`, failing the test if there is none. */
+  private def checksumReadVerifier(
+      provider: StateStoreProvider, store: StateStore): KeyValueChecksumVerifier = {
+    getReadVerifier(provider, store) match {
+      case Some(verifier: KeyValueChecksumVerifier) => verifier
+      case other => fail(s"Expected a KeyValueChecksumVerifier read verifier, but got: $other")
+    }
+  }
+
+  /** Asserts how many reads requested verification and how many were actually verified. */
+  private def assertVerifierStats(
+      verifier: KeyValueChecksumVerifier,
+      expectedRequests: Long,
+      expectedVerifies: Long): Unit = {
+    assert(verifier.getNumRequests == expectedRequests,
+      "unexpected number of read verification requests")
+    assert(verifier.getNumVerified == expectedVerifies,
+      "unexpected number of verified reads")
+  }
+}
+
+/**
+ * Runs [[StateStoreRowChecksumSuite]] against [[HDFSBackedStateStoreProvider]], corrupting
+ * rows in the store's in-memory map and in both snapshot and delta files.
+ */
+class HDFSStateStoreRowChecksumSuite extends StateStoreRowChecksumSuite
+  with PrivateMethodTester {
+  import StateStoreTestsHelper._
+
+  override protected def createProvider: StateStoreProvider = new HDFSBackedStateStoreProvider
+
+  /** Returns the private `mapToUpdate` map that `store` writes its updates to. */
+  private def storeMapToUpdate(
+      provider: StateStoreProvider, store: StateStore): HDFSBackedStateStoreMap = {
+    val hdfsProvider = provider.asInstanceOf[HDFSBackedStateStoreProvider]
+    val hdfsStore = store.asInstanceOf[hdfsProvider.HDFSBackedStateStore]
+    val mapToUpdateField = PrivateMethod[HDFSBackedStateStoreMap](Symbol("mapToUpdate"))
+    hdfsStore invokePrivate mapToUpdateField()
+  }
+
+  override protected def corruptRow(
+      provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit = {
+    // Access the private map backing the hdfs store map
+    val storeMap = storeMapToUpdate(provider, store)
+    val mapField = PrivateMethod[HDFSBackedStateStoreMap.MapType](Symbol("map"))
+    val map = storeMap invokePrivate mapField()
+
+    val currentValue = map.get(key)
+    assert(currentValue != null, s"No value found for key $key, nothing to corrupt")
+    val currentValueRow = currentValue.asInstanceOf[StateStoreRowWithChecksum]
+    // corrupt the existing value by flipping the last bit
+    val corruptValue = valueRowToData(currentValueRow.unsafeRow()) ^ 1
+
+    // update with the corrupt value, but keeping the previous checksum
+    map.put(key,
+      StateStoreRowWithChecksum(dataToValueRow(corruptValue), currentValueRow.checksum))
+  }
+
+  override protected def corruptRowInFile(
+      provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit = {
+    val hdfsProvider = provider.asInstanceOf[HDFSBackedStateStoreProvider]
+
+    val baseDirField = PrivateMethod[Path](Symbol("baseDir"))
+    val baseDir = hdfsProvider invokePrivate baseDirField()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new Path(baseDir.toString, fileName)
+
+    val fileManagerMethod = PrivateMethod[CheckpointFileManager](Symbol("fm"))
+    val fm = hdfsProvider invokePrivate fileManagerMethod()
+
+    // Read the decompressed file content, closing the stream even if the read fails
+    val currentData = Utils.tryWithResource {
+      hdfsProvider.decompressStream(fm.open(filePath))
+    } { input =>
+      input.readAllBytes()
+    }
+
+    // Corrupt the current data by flipping the last bit of the last row we wrote.
+    // We typically write an EOF marker (-1) so we are skipping that.
+    val byteOffsetToCorrupt = currentData.length - java.lang.Integer.BYTES - 1
+    assert(byteOffsetToCorrupt >= 0, s"$filePath is too small to contain a row to corrupt")
+    currentData(byteOffsetToCorrupt) = (currentData(byteOffsetToCorrupt) ^ 0x01).toByte
+
+    // Now delete the current file and write a new file with corrupt row
+    fm.delete(filePath)
+    Utils.tryWithResource {
+      hdfsProvider.compressStream(fm.createAtomic(filePath, overwriteIfPossible = true))
+    } { output =>
+      output.write(currentData)
+    }
+  }
+
+  // test both snapshot and delta file
+  override protected def corruptRowInFileTestMode: Seq[Boolean] = Seq(true, false)
+
+  override protected def getReadVerifier(
+      provider: StateStoreProvider, store: StateStore): Option[KeyValueIntegrityVerifier] = {
+    val readVerifierField = PrivateMethod[Option[KeyValueIntegrityVerifier]](
+      Symbol("readVerifier"))
+    storeMapToUpdate(provider, store) invokePrivate readVerifierField()
+  }
+
+  test("Snapshot upload should fail if corrupt row is detected") {
+    // We want to generate snapshot per change file
+    withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "0") {
+      withStateStoreProvider(createInitializedProvider()) { provider =>
+        val store = provider.getStore(0L)
+        put(store, "1", 11, 100)
+        put(store, "2", 22, 200)
+        put(store, "3", 33, 300)
+
+        // Corrupt the local row, won't affect changelog file since row is already written to it
+        corruptRow(provider, store, dataToKeyRow("1", 11))
+        store.commit() // writes change file
+
+        checkChecksumError(intercept[SparkException] {
+          // Should throw row checksum exception while trying to write snapshot
+          // provider.doMaintenance() calls doSnapshot and then swallows exception.
+          // Hence, calling doSnapshot directly to avoid that.
+          val doSnapshotMethod = PrivateMethod[Unit](Symbol("doSnapshot"))
+          provider invokePrivate doSnapshotMethod("maintenance", true)
+        })
+      }
+    }
+  }
+}
+
+/**
+ * Runs [[StateStoreRowChecksumSuite]] against [[RocksDBStateStoreProvider]] with changelog
+ * checkpointing enabled, corrupting rows in the local RocksDB instance and in changelog files.
+ */
+class RocksDBStateStoreRowChecksumSuite extends StateStoreRowChecksumSuite
+  with PrivateMethodTester {
+  import StateStoreTestsHelper._
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      // Also generate changelog files for RocksDB
+      .set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", true.toString)
+  }
+
+  override protected def createProvider: StateStoreProvider = new RocksDBStateStoreProvider
+
+  override protected def corruptRow(
+      provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit = {
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val dbField = PrivateMethod[NativeRocksDB](Symbol("db"))
+    val db = rocksDBProvider.rocksDB invokePrivate dbField()
+
+    val readOptionsField = PrivateMethod[ReadOptions](Symbol("readOptions"))
+    val readOptions = rocksDBProvider.rocksDB invokePrivate readOptionsField()
+    val writeOptionsField = PrivateMethod[WriteOptions](Symbol("writeOptions"))
+    val writeOptions = rocksDBProvider.rocksDB invokePrivate writeOptionsField()
+
+    val encoder = new UnsafeRowDataEncoder(NoPrefixKeyStateEncoderSpec(keySchema), valueSchema)
+    val keyBytes = encoder.encodeKey(key)
+    val currentValue = db.get(readOptions, keyBytes)
+    // RocksDB returns null when the key is absent
+    assert(currentValue != null && currentValue.nonEmpty,
+      s"No value found in RocksDB for key $key, nothing to corrupt")
+
+    // corrupt the current value by flipping the last bit
+    val lastByte = currentValue.length - 1
+    currentValue(lastByte) = (currentValue(lastByte) ^ 0x01).toByte
+    db.put(writeOptions, keyBytes, currentValue)
+  }
+
+  override protected def corruptRowInFile(
+      provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit = {
+    assert(!isSnapshot, "Doesn't support corrupting a row in snapshot file")
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val rocksDBFileManager = rocksDBProvider.rocksDB.fileManager
+
+    val dfsChangelogFileMethod = PrivateMethod[Path](Symbol("dfsChangelogFile"))
+    val changelogFilePath = rocksDBFileManager invokePrivate dfsChangelogFileMethod(version, None)
+
+    val fileManagerMethod = PrivateMethod[CheckpointFileManager](Symbol("fm"))
+    val fm = rocksDBFileManager invokePrivate fileManagerMethod()
+
+    val codecMethod = PrivateMethod[CompressionCodec](Symbol("codec"))
+    val codec = rocksDBFileManager invokePrivate codecMethod()
+
+    // Read the decompressed file content, closing the stream even if the read fails
+    val currentData = Utils.tryWithResource {
+      new DataInputStream(codec.compressedInputStream(fm.open(changelogFilePath)))
+    } { input =>
+      input.readAllBytes()
+    }
+
+    // Corrupt the current data by flipping the last bit of the last row we wrote.
+    // We typically write an EOF marker (-1) so we are skipping that.
+    val byteOffsetToCorrupt = currentData.length - java.lang.Integer.BYTES - 1
+    assert(byteOffsetToCorrupt >= 0,
+      s"$changelogFilePath is too small to contain a row to corrupt")
+    currentData(byteOffsetToCorrupt) = (currentData(byteOffsetToCorrupt) ^ 0x01).toByte
+
+    // Now delete the current file and write a new file with corrupt row
+    fm.delete(changelogFilePath)
+    Utils.tryWithResource {
+      val atomicOutput = fm.createAtomic(changelogFilePath, overwriteIfPossible = true)
+      new DataOutputStream(codec.compressedOutputStream(atomicOutput))
+    } { output =>
+      output.write(currentData)
+    }
+  }
+
+  // test only changelog file since zip file doesn't contain rows
+  override protected def corruptRowInFileTestMode: Seq[Boolean] = Seq(false)
+
+  override protected def getReadVerifier(
+      provider: StateStoreProvider, store: StateStore): Option[KeyValueIntegrityVerifier] = {
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val readVerifierField = PrivateMethod[Option[KeyValueIntegrityVerifier]](
+      Symbol("readVerifier"))
+    rocksDBProvider.rocksDB invokePrivate readVerifierField()
+  }
+}
+
+/**
+ * Mix-in that turns on state store row checksum (`SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED`)
+ * in the test `sparkConf`. Mix it into a subclass of an existing suite to re-run all of that
+ * suite's tests with row checksum enabled.
+ *
+ * Example:
+ * {{{
+ * class MyTestSuite extends MyBaseTestSuite {
+ *   // tests without row checksum
+ * }
+ *
+ * class MyTestSuiteWithRowChecksum extends MyTestSuite with EnableStateStoreRowChecksum {
+ *   // inherits all tests from MyTestSuite, but with row checksum enabled
+ * }
+ * }}}
+ */
+trait EnableStateStoreRowChecksum extends SharedSparkSession {
+  override protected def sparkConf: SparkConf =
+    super.sparkConf.set(SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key, "true")
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 807397d96918..5032020cf5e5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -48,6 +48,7 @@ import 
org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExe
 import 
org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef
 import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.tags.ExtendedSQLTest
 import org.apache.spark.unsafe.types.UTF8String
@@ -255,6 +256,7 @@ private object FakeStateStoreProviderWithMaintenanceError {
 
 @ExtendedSQLTest
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with SharedSparkSession
   with BeforeAndAfter {
   import StateStoreTestsHelper._
   import StateStoreCoordinatorSuite._
@@ -262,6 +264,7 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   before {
     StateStore.stop()
     require(!StateStore.isMaintenanceRunning)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
   }
 
   after {
@@ -1585,6 +1588,8 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     sqlConf.setConf(SQLConf.STATE_STORE_COMPRESSION_CODEC, 
SQLConf.get.stateStoreCompressionCodec)
     sqlConf.setConf(
       SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED, 
SQLConf.get.checkpointFileChecksumEnabled)
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED, 
SQLConf.get.stateStoreRowChecksumEnabled)
     sqlConf
   }
 
@@ -1664,6 +1669,12 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     filePath.createNewFile()
   }
+
+  // StateStoreSuiteBase and SQLTestUtils (via SharedSparkSession) both provide testQuietly;
+  // resolve the conflict explicitly by delegating to the StateStoreSuiteBase version.
+  override protected def testQuietly(name: String)(f: => Unit): Unit =
+    super[StateStoreSuiteBase].testQuietly(name)(f)
 }
 
 abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
@@ -2930,3 +2941,16 @@ class RenameReturnsFalseFileSystem extends 
RawLocalFileSystem {
 object RenameReturnsFalseFileSystem {
   val scheme = s"StateStoreSuite${math.abs(Random.nextInt())}fs"
 }
+
+/**
+ * Test suite that runs all StateStoreSuite tests with row checksum enabled.
+ */
+@ExtendedSQLTest
+class StateStoreSuiteWithRowChecksum
+  extends StateStoreSuite with EnableStateStoreRowChecksum {
+  override protected def testQuietly(name: String)(f: => Unit): Unit = {
+    // Delegate to StateStoreSuite's override, which itself picks the StateStoreSuiteBase
+    // implementation over the one in SQLTestUtils, to avoid a conflict error.
+    super[StateStoreSuite].testQuietly(name)(f)
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index a68965dd7845..86cf77d8a57b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -287,3 +287,9 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
     }
   }
 }
+
+/**
+ * Re-runs every [[StatefulProcessorHandleSuite]] test with state store row checksum enabled
+ * (see [[EnableStateStoreRowChecksum]]).
+ */
+class StatefulProcessorHandleSuiteWithRowChecksum
+  extends StatefulProcessorHandleSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index c74599a6aa74..b8ad09cb0d95 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -187,3 +187,8 @@ class TimerSuite extends StateVariableSuiteBase {
     }
   }
 }
+
+/**
+ * [[TimerSuite]] re-run with state store row checksum enabled.
+ */
+class TimerSuiteWithRowChecksum
+  extends TimerSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index fbe33ddc32db..328bc38bf7d5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -493,3 +493,7 @@ abstract class StateVariableSuiteBase extends 
SharedSparkSession
   }
 }
 
+/**
+ * [[ValueStateSuite]] re-run with state store row checksum enabled.
+ */
+class ValueStateSuiteWithRowChecksum
+  extends ValueStateSuite with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index b0a76e170ff5..f84a4a4af3ec 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, EnableStateStoreRowChecksum, 
RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.tags.SlowSQLTest
 
@@ -495,3 +495,10 @@ class TransformWithListStateSuite extends StreamTest
     }
   }
 }
+
+/**
+ * Re-runs every [[TransformWithListStateSuite]] test with state store row checksum enabled.
+ */
+@SlowSQLTest
+class TransformWithListStateSuiteWithRowChecksum extends TransformWithListStateSuite
+  with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index fd73517d8181..320542d6cdc4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, EnableStateStoreRowChecksum, 
RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.tags.SlowSQLTest
 
@@ -353,3 +353,10 @@ class TransformWithMapStateSuite extends StreamTest
     }
   }
 }
+
+/**
+ * Re-runs every [[TransformWithMapStateSuite]] test with state store row checksum enabled.
+ */
+@SlowSQLTest
+class TransformWithMapStateSuiteWithRowChecksum extends TransformWithMapStateSuite
+  with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
index 7d065d561feb..4d4bd39f1520 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateChainingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.{SparkRuntimeException, 
SparkThrowable}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamExecution}
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, EnableStateStoreRowChecksum, 
RocksDBStateStoreProvider}
 import org.apache.spark.sql.functions.window
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.tags.SlowSQLTest
@@ -433,3 +433,10 @@ class TransformWithStateChainingSuite extends StreamTest
     }
   }
 }
+
+/**
+ * Re-runs every [[TransformWithStateChainingSuite]] test with state store row checksum
+ * enabled.
+ */
+@SlowSQLTest
+class TransformWithStateChainingSuiteWithRowChecksum extends TransformWithStateChainingSuite
+  with EnableStateStoreRowChecksum
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 1c8c567b73fc..9685f70b86a7 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.sql.{DataFrame, Dataset, Encoders, 
KeyValueGroupedDataset}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithEncodingTypes, 
AlsoTestWithRocksDBFeatures, EnableStateStoreRowChecksum, 
RocksDBStateStoreProvider}
 import org.apache.spark.sql.functions.{col, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -763,3 +763,10 @@ class TransformWithStateInitialStateSuiteCheckpointV2
     spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
   }
 }
+
+/**
+ * Re-runs every [[TransformWithStateInitialStateSuite]] test with state store row checksum
+ * enabled.
+ */
+@SlowSQLTest
+class TransformWithStateInitialStateSuiteWithRowChecksum
+  extends TransformWithStateInitialStateSuite
+  with EnableStateStoreRowChecksum


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to