This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ac717dd7aec6 [SPARK-54106][SS] State store row checksum implementation
ac717dd7aec6 is described below

commit ac717dd7aec665de578d7c6b0070e8fcdde3cea9
Author: micheal-o <[email protected]>
AuthorDate: Fri Oct 31 17:17:27 2025 -0700

    [SPARK-54106][SS] State store row checksum implementation
    
    ### What changes were proposed in this pull request?
    
    This introduces row checksum creation and verification for the state store, 
covering both the HDFS-backed and RocksDB state store providers. This helps 
detect corruption at the row level and helps prevent corrupt rows from being 
written to the remote checkpoint. The checksum is also verified when loading 
from a checkpoint, to detect whether any row in the checkpoint is corrupt.
    
    Since this adds overhead, **it is disabled by default and can only be 
enabled for a new checkpoint**. This is enforced by also writing the conf to 
the offset log, so its value cannot change when a query restarts from an 
existing checkpoint location.
    
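    This also introduces a readVerificationRatio conf, which enables checksum 
verification when a row is read by the client. This makes corruption detection 
more immediate: the query task fails when a corrupt row is read, rather than 
only after it is too late. Because verifying frequently on reads can be 
expensive, the frequency is configurable (a ratio of N verifies every Nth row 
read; 0 disables read-time verification, although rows are still verified 
when loading from the checkpoint location).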
    
    ### Why are the changes needed?
    
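    Integrity verification: to detect row-level corruption in the state store 
before it propagates into the remote checkpoint.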
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New tests added (including the new StateStoreRowChecksumSuite), and 
existing state store suites updated.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52809 from micheal-o/row_checksum.
    
    Authored-by: micheal-o <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   7 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  30 ++
 .../streaming/checkpointing/OffsetSeq.scala        |   6 +-
 .../streaming/state/HDFSBackedStateStoreMap.scala  |  89 +++-
 .../state/HDFSBackedStateStoreProvider.scala       | 173 +++++--
 .../sql/execution/streaming/state/RocksDB.scala    | 185 +++++--
 .../streaming/state/RocksDBStateEncoder.scala      |  40 ++
 .../state/RocksDBStateStoreProvider.scala          |  54 ++-
 .../sql/execution/streaming/state/StateStore.scala |   1 -
 .../streaming/state/StateStoreChangelog.scala      |   8 +
 .../execution/streaming/state/StateStoreConf.scala |   6 +
 .../streaming/state/StateStoreErrors.scala         |  19 +
 .../execution/streaming/state/StateStoreRow.scala  |  56 +++
 .../streaming/state/StateStoreRowChecksum.scala    | 536 +++++++++++++++++++++
 .../execution/streaming/OffsetSeqLogSuite.scala    |  22 +
 .../execution/streaming/state/RocksDBSuite.scala   |   5 +-
 .../state/StateStoreRowChecksumSuite.scala         | 464 ++++++++++++++++++
 .../streaming/state/StateStoreSuite.scala          |  22 +-
 .../streaming/state/ValueStateSuite.scala          |   4 +-
 .../sql/streaming/StateStoreMetricsTest.scala      |   4 +-
 20 files changed, 1625 insertions(+), 106 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 34b72975cc07..cfbe571bca32 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5365,6 +5365,13 @@
     ],
     "sqlState" : "42K06"
   },
+  "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED" : {
+    "message" : [
+      "Row checksum verification failed for stateStore=<stateStoreId>. The row 
may be corrupted.",
+      "Expected checksum: <expectedChecksum>, Computed checksum: 
<computedChecksum>."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STATE_STORE_STATE_SCHEMA_FILES_THRESHOLD_EXCEEDED" : {
     "message" : [
       "The number of state schema files <numStateSchemaFiles> exceeds the 
maximum number of state schema files for this query: <maxStateSchemaFiles>.",
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d88cbe326cfb..09110d67b077 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2633,6 +2633,31 @@ object SQLConf {
       .checkValue(k => k >= 0, "Must be greater than or equal to 0")
       .createWithDefault(5)
 
+  val STATE_STORE_ROW_CHECKSUM_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.rowChecksum.enabled")
+      .internal()
+      .doc("When true, checksum would be generated and verified for each state 
store row. " +
+        "This is used to detect row level corruption. " +
+        "Note: This configuration cannot be changed between query restarts " +
+        "from the same checkpoint location.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO =
+    
buildConf("spark.sql.streaming.stateStore.rowChecksum.readVerificationRatio")
+      .internal()
+      .doc("When specified, Spark will do row checksum verification for every 
specified " +
+        "number of rows read from state store. The check is to ensure the row 
read from " +
+        "state store is not corrupt. Default is 0, which means no verification 
during read " +
+        "but we will still do verification when loading from checkpoint 
location." +
+        "Example, if you set to 1, it will do the check for every row read 
from the state store." +
+        "If set to 10, it will do the check for every 10th row read from the 
state store.")
+      .version("4.1.0")
+      .longConf
+      .checkValue(k => k >= 0, "Must be greater than or equal to 0")
+      .createWithDefault(if (Utils.isTesting) 1 else 0)
+
   val STATEFUL_SHUFFLE_PARTITIONS_INTERNAL =
     buildConf("spark.sql.streaming.internal.stateStore.partitions")
       .doc("WARN: This config is used internally and is not intended to be 
user-facing. This " +
@@ -6735,6 +6760,11 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
   def stateStoreCoordinatorMaxLaggingStoresToReport: Int =
     getConf(STATE_STORE_COORDINATOR_MAX_LAGGING_STORES_TO_REPORT)
 
+  def stateStoreRowChecksumEnabled: Boolean = 
getConf(STATE_STORE_ROW_CHECKSUM_ENABLED)
+
+  def stateStoreRowChecksumReadVerificationRatio: Long =
+    getConf(STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO)
+
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
 
   def checkpointFileChecksumEnabled: Boolean = 
getConf(STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index 62c903cb689a..888dc0cdb912 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -113,7 +113,8 @@ object OffsetSeqMetadata extends Logging {
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, 
STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
     STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC,
     STATE_STORE_ROCKSDB_FORMAT_VERSION, 
STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
-    PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, 
STREAMING_STATE_STORE_ENCODING_FORMAT
+    PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, 
STREAMING_STATE_STORE_ENCODING_FORMAT,
+    STATE_STORE_ROW_CHECKSUM_ENABLED
   )
 
   /**
@@ -159,7 +160,8 @@ object OffsetSeqMetadata extends Logging {
     STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4,
     STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false",
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true",
-    STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow"
+    STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow",
+    STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false"
   )
 
   def readValue[T](metadataLog: OffsetSeqMetadata, confKey: ConfigEntry[T]): 
String = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
index fe59703a1f45..f2290f8569b8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.Map.Entry
+
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
@@ -27,10 +29,13 @@ import org.apache.spark.sql.types.{StructField, StructType}
 trait HDFSBackedStateStoreMap {
   def size(): Int
   def get(key: UnsafeRow): UnsafeRow
-  def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow
+  def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper
   def putAll(map: HDFSBackedStateStoreMap): Unit
-  def remove(key: UnsafeRow): UnsafeRow
+  def remove(key: UnsafeRow): UnsafeRowWrapper
   def iterator(): Iterator[UnsafeRowPair]
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]]
   def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
 }
 
@@ -40,42 +45,69 @@ object HDFSBackedStateStoreMap {
   //   the map when the iterator was created
   // - Any updates to the map while iterating through the filtered iterator 
does not throw
   //   java.util.ConcurrentModificationException
-  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
+  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, 
UnsafeRowWrapper]
 
-  def create(keySchema: StructType, numColsPrefixKey: Int): 
HDFSBackedStateStoreMap = {
+  def create(
+      keySchema: StructType,
+      numColsPrefixKey: Int,
+      readVerifier: Option[KeyValueIntegrityVerifier]): 
HDFSBackedStateStoreMap = {
     if (numColsPrefixKey > 0) {
-      new PrefixScannableHDFSBackedStateStoreMap(keySchema, numColsPrefixKey)
+      new PrefixScannableHDFSBackedStateStoreMap(keySchema, numColsPrefixKey, 
readVerifier)
     } else {
-      new NoPrefixHDFSBackedStateStoreMap()
+      new NoPrefixHDFSBackedStateStoreMap(readVerifier)
+    }
+  }
+
+  /** Get the value row from the value wrapper and verify it */
+  def getAndVerifyValueRow(
+      key: UnsafeRow,
+      valueWrapper: UnsafeRowWrapper,
+      readVerifier: Option[KeyValueIntegrityVerifier]): UnsafeRow = {
+    Option(valueWrapper) match {
+      case Some(value) =>
+        readVerifier.foreach(_.verify(key, value))
+        value.unsafeRow()
+      case None => null
     }
   }
 }
 
-class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap {
+class NoPrefixHDFSBackedStateStoreMap(private val readVerifier: 
Option[KeyValueIntegrityVerifier])
+    extends HDFSBackedStateStoreMap {
   private val map = new HDFSBackedStateStoreMap.MapType()
 
   override def size(): Int = map.size()
 
-  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+  override def get(key: UnsafeRow): UnsafeRow = {
+    HDFSBackedStateStoreMap.getAndVerifyValueRow(key, map.get(key), 
readVerifier)
+  }
 
-  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = map.put(key, 
value)
+  override def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper 
= map.put(key, value)
 
   def putAll(other: HDFSBackedStateStoreMap): Unit = {
     other match {
       case o: NoPrefixHDFSBackedStateStoreMap => map.putAll(o.map)
-      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+      case _ => other.entryIterator().foreach { pair => put(pair.getKey, 
pair.getValue) }
     }
   }
 
-  override def remove(key: UnsafeRow): UnsafeRow = map.remove(key)
+  override def remove(key: UnsafeRow): UnsafeRowWrapper = map.remove(key)
 
   override def iterator(): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
-    map.entrySet.asScala.iterator.map { entry =>
-      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    entryIterator().map { entry =>
+      val valueRow = HDFSBackedStateStoreMap
+        .getAndVerifyValueRow(entry.getKey, entry.getValue, readVerifier)
+      unsafeRowPair.withRows(entry.getKey, valueRow)
     }
   }
 
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  override def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]] = 
{
+    map.entrySet.asScala.iterator
+  }
+
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw SparkUnsupportedOperationException()
   }
@@ -83,7 +115,8 @@ class NoPrefixHDFSBackedStateStoreMap extends 
HDFSBackedStateStoreMap {
 
 class PrefixScannableHDFSBackedStateStoreMap(
     keySchema: StructType,
-    numColsPrefixKey: Int) extends HDFSBackedStateStoreMap {
+    numColsPrefixKey: Int,
+    private val readVerifier: Option[KeyValueIntegrityVerifier]) extends 
HDFSBackedStateStoreMap {
 
   private val map = new HDFSBackedStateStoreMap.MapType()
 
@@ -103,9 +136,11 @@ class PrefixScannableHDFSBackedStateStoreMap(
 
   override def size(): Int = map.size()
 
-  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+  override def get(key: UnsafeRow): UnsafeRow = {
+    HDFSBackedStateStoreMap.getAndVerifyValueRow(key, map.get(key), 
readVerifier)
+  }
 
-  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
+  override def put(key: UnsafeRow, value: UnsafeRowWrapper): UnsafeRowWrapper 
= {
     val ret = map.put(key, value)
 
     val prefixKey = prefixKeyProjection(key).copy()
@@ -136,11 +171,11 @@ class PrefixScannableHDFSBackedStateStoreMap(
           prefixKeyToKeysMap.put(prefixKey, newSet)
         }
 
-      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+      case _ => other.entryIterator().foreach { pair => put(pair.getKey, 
pair.getValue) }
     }
   }
 
-  override def remove(key: UnsafeRow): UnsafeRow = {
+  override def remove(key: UnsafeRow): UnsafeRowWrapper = {
     val ret = map.remove(key)
 
     if (ret != null) {
@@ -156,15 +191,27 @@ class PrefixScannableHDFSBackedStateStoreMap(
 
   override def iterator(): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
-    map.entrySet.asScala.iterator.map { entry =>
-      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    entryIterator().map { entry =>
+      val valueRow = HDFSBackedStateStoreMap
+        .getAndVerifyValueRow(entry.getKey, entry.getValue, readVerifier)
+      unsafeRowPair.withRows(entry.getKey, valueRow)
     }
   }
 
+  /** Returns entries in the underlying map and skips additional checks done 
by [[iterator]].
+   * [[iterator]] should be preferred over this. */
+  override def entryIterator(): Iterator[Entry[UnsafeRow, UnsafeRowWrapper]] = 
{
+    map.entrySet.asScala.iterator
+  }
+
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     val unsafeRowPair = new UnsafeRowPair()
     prefixKeyToKeysMap.getOrDefault(prefixKey, mutable.Set.empty[UnsafeRow])
       .iterator
-      .map { key => unsafeRowPair.withRows(key, map.get(key)) }
+      .map { keyRow =>
+        val valueRow = HDFSBackedStateStoreMap
+          .getAndVerifyValueRow(keyRow, map.get(keyRow), readVerifier)
+        unsafeRowPair.withRows(keyRow, valueRow)
+      }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index f9c89bce1c02..5b00feb310c9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -107,7 +107,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   }
 
   /** Implementation of [[StateStore]] API which is backed by an 
HDFS-compatible file system */
-  class HDFSBackedStateStore(val version: Long, mapToUpdate: 
HDFSBackedStateStoreMap)
+  class HDFSBackedStateStore(
+      val version: Long,
+      private val mapToUpdate: HDFSBackedStateStoreMap)
     extends StateStore {
 
     /** Trait and classes representing the internal state of the store */
@@ -163,8 +165,16 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       verify(state == UPDATING, "Cannot put after already committed or 
aborted")
       val keyCopy = key.copy()
       val valueCopy = value.copy()
-      mapToUpdate.put(keyCopy, valueCopy)
-      writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy)
+
+      val valueWrapper = if (storeConf.rowChecksumEnabled) {
+        // Add the key-value checksum to the value row
+        StateStoreRowWithChecksum(valueCopy, KeyValueChecksum.create(keyCopy, 
Some(valueCopy)))
+      } else {
+        StateStoreRow(valueCopy)
+      }
+
+      mapToUpdate.put(keyCopy, valueWrapper)
+      writeUpdateToDeltaFile(compressedStream, keyCopy, valueWrapper)
     }
 
     override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
@@ -172,7 +182,13 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
       val prevValue = mapToUpdate.remove(key)
       if (prevValue != null) {
-        writeRemoveToDeltaFile(compressedStream, key)
+        val keyWrapper = if (storeConf.rowChecksumEnabled) {
+          // Add checksum for only the removed key
+          StateStoreRowWithChecksum(key, KeyValueChecksum.create(key, None))
+        } else {
+          StateStoreRow(key)
+        }
+        writeRemoveToDeltaFile(compressedStream, keyWrapper)
       }
     }
 
@@ -315,7 +331,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       if (version < 0) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
       }
-      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val newMap = createHDFSBackedStateStoreMap()
       if (version > 0) {
         newMap.putAll(loadMap(version))
       }
@@ -603,7 +619,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
         if (lastAvailableVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey))
+          lastAvailableMap = Some(createHDFSBackedStateStoreMap())
         } else {
           lastAvailableMap =
             synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
@@ -613,7 +629,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
       // Load all the deltas from the version after the last available one up 
to the target version.
       // The last available version is the one with a full snapshot, so it 
doesn't need deltas.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
+      val resultMap = createHDFSBackedStateStoreMap()
       resultMap.putAll(lastAvailableMap.get)
       for (deltaVersion <- lastAvailableVersion + 1 to version) {
         updateFromDeltaFile(deltaVersion, resultMap)
@@ -635,17 +651,30 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   private def writeUpdateToDeltaFile(
       output: DataOutputStream,
       key: UnsafeRow,
-      value: UnsafeRow): Unit = {
+      value: UnsafeRowWrapper): Unit = {
     val keyBytes = key.getBytes()
-    val valueBytes = value.getBytes()
+    val valueBytes = value match {
+      case v: StateStoreRowWithChecksum =>
+        // If it has checksum, encode the value bytes with the checksum.
+        KeyValueChecksumEncoder.encodeSingleValueRowWithChecksum(
+          v.unsafeRow().getBytes(), v.checksum)
+      case _ => value.unsafeRow().getBytes()
+    }
+
     output.writeInt(keyBytes.size)
     output.write(keyBytes)
     output.writeInt(valueBytes.size)
     output.write(valueBytes)
   }
 
-  private def writeRemoveToDeltaFile(output: DataOutputStream, key: 
UnsafeRow): Unit = {
-    val keyBytes = key.getBytes()
+  private def writeRemoveToDeltaFile(output: DataOutputStream, key: 
UnsafeRowWrapper): Unit = {
+    val keyBytes = key match {
+      case k: StateStoreRowWithChecksum =>
+        // If it has checksum, encode the key bytes with the checksum.
+        KeyValueChecksumEncoder.encodeKeyRowWithChecksum(
+          k.unsafeRow().getBytes(), k.checksum)
+      case _ => key.unsafeRow().getBytes()
+    }
     output.writeInt(keyBytes.size)
     output.write(keyBytes)
     output.writeInt(-1)
@@ -669,6 +698,10 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       input = decompressStream(sourceStream)
       var eof = false
 
+      // If row checksum is enabled, verify every record in the file to detect 
corrupt rows.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (!eof) {
         val keySize = input.readInt()
         if (keySize == -1) {
@@ -681,26 +714,46 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
           Utils.readFully(input, keyRowBuffer, 0, keySize)
 
           val keyRow = new UnsafeRow(keySchema.fields.length)
-          keyRow.pointTo(keyRowBuffer, keySize)
 
           val valueSize = input.readInt()
           if (valueSize < 0) {
+            val originalKeyBytes = if (storeConf.rowChecksumEnabled) {
+              // For deleted row, we added checksum to the key side.
+              // Decode the original key and remove the checksum part
+              
KeyValueChecksumEncoder.decodeAndVerifyKeyRowWithChecksum(verifier, 
keyRowBuffer)
+            } else {
+              keyRowBuffer
+            }
+            keyRow.pointTo(originalKeyBytes, originalKeyBytes.length)
             map.remove(keyRow)
           } else {
+            keyRow.pointTo(keyRowBuffer, keySize)
+
             val valueRowBuffer = new Array[Byte](valueSize)
             Utils.readFully(input, valueRowBuffer, 0, valueSize)
             val valueRow = new UnsafeRow(valueSchema.fields.length)
+
+            val (originalValueBytes, valueWrapper) = if 
(storeConf.rowChecksumEnabled) {
+              // checksum is on the value side
+              val (valueBytes, checksum) = KeyValueChecksumEncoder
+                .decodeSingleValueRowWithChecksum(valueRowBuffer)
+              verifier.foreach(_.verify(keyRowBuffer, Some(valueBytes), 
checksum))
+              (valueBytes, StateStoreRowWithChecksum(valueRow, checksum))
+            } else {
+              (valueRowBuffer, StateStoreRow(valueRow))
+            }
+
             // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
             // This is a workaround for the following:
             // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
             // `RowBasedKeyValueBatch`, which gets persisted into the 
checkpoint data
-            valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+            valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 
8) * 8)
             if (!isValidated) {
               StateStoreProvider.validateStateRowFormat(
                 keyRow, keySchema, valueRow, valueSchema, stateStoreId, 
storeConf)
               isValidated = true
             }
-            map.put(keyRow, valueRow)
+            map.put(keyRow, valueWrapper)
           }
         }
       }
@@ -721,11 +774,28 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     try {
       rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true)
       output = compressStream(rawOutput)
-      val iter = map.iterator()
+      // Entry iterator doesn't do verification, we will do it ourselves
+      // Using this instead of iterator() since we want the UnsafeRowWrapper
+      val iter = map.entryIterator()
+
+      // If row checksum is enabled, we will verify every entry in the map 
before writing snapshot,
+      // to prevent writing corrupt rows since the map might have been created 
a while ago.
+      // This is fine since write snapshot is typically done in the background.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (iter.hasNext) {
         val entry = iter.next()
-        val keyBytes = entry.key.getBytes()
-        val valueBytes = entry.value.getBytes()
+        verifier.foreach(_.verify(entry.getKey, entry.getValue))
+
+        val keyBytes = entry.getKey.getBytes()
+        val valueBytes = entry.getValue match {
+          case v: StateStoreRowWithChecksum =>
+            // If it has checksum, encode it with the checksum.
+            KeyValueChecksumEncoder.encodeSingleValueRowWithChecksum(
+              v.unsafeRow().getBytes(), v.checksum)
+          case o => o.unsafeRow().getBytes()
+        }
         output.writeInt(keyBytes.size)
         output.write(keyBytes)
         output.writeInt(valueBytes.size)
@@ -782,13 +852,17 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
   */
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] 
= {
     val fileToRead = snapshotFile(version)
-    val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val map = createHDFSBackedStateStoreMap()
     var input: DataInputStream = null
 
     try {
       input = decompressStream(fm.open(fileToRead))
       var eof = false
 
+      // If row checksum is enabled, verify every record in the file to detect 
corrupt rows.
+      val verifier = KeyValueIntegrityVerifier
+        .create(stateStoreId_.toString, storeConf.rowChecksumEnabled, 
verificationRatio = 1)
+
       while (!eof) {
         val keySize = input.readInt()
         if (keySize == -1) {
@@ -811,17 +885,28 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
             val valueRowBuffer = new Array[Byte](valueSize)
             Utils.readFully(input, valueRowBuffer, 0, valueSize)
             val valueRow = new UnsafeRow(valueSchema.fields.length)
+
+            val (originalValueBytes, valueWrapper) = if 
(storeConf.rowChecksumEnabled) {
+              // Checksum is on the value side
+              val (valueBytes, checksum) = KeyValueChecksumEncoder
+                .decodeSingleValueRowWithChecksum(valueRowBuffer)
+              verifier.foreach(_.verify(keyRowBuffer, Some(valueBytes), 
checksum))
+              (valueBytes, StateStoreRowWithChecksum(valueRow, checksum))
+            } else {
+              (valueRowBuffer, StateStoreRow(valueRow))
+            }
+
             // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
             // This is a workaround for the following:
             // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
             // `RowBasedKeyValueBatch`, which gets persisted into the 
checkpoint data
-            valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8)
+            valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 
8) * 8)
             if (!isValidated) {
               StateStoreProvider.validateStateRowFormat(
                 keyRow, keySchema, valueRow, valueSchema, stateStoreId, 
storeConf)
               isValidated = true
             }
-            map.put(keyRow, valueRow)
+            map.put(keyRow, valueWrapper)
           }
         }
       }
@@ -838,7 +923,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
 
   /** Perform a snapshot of the store to allow delta files to be consolidated 
*/
-  private def doSnapshot(opType: String): Unit = {
+  private def doSnapshot(opType: String, throwEx: Boolean = false): Unit = {
     try {
       val ((files, _), e1) = Utils.timeTakenMs(fetchFiles())
       logDebug(s"fetchFiles() took $e1 ms.")
@@ -860,6 +945,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     } catch {
       case NonFatal(e) =>
         logWarning(log"Error doing snapshots", e)
+        if (throwEx) {
+          throw e
+        }
     }
   }
 
@@ -1079,7 +1167,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }
 
-      val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val newMap = createHDFSBackedStateStoreMap()
       newMap.putAll(constructMapFromSnapshot(snapshotVersion, endVersion))
 
       newMap
@@ -1107,7 +1195,7 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       }
 
       // Load all the deltas from the version after the start version up to 
the end version.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, 
numColsPrefixKey)
+      val resultMap = createHDFSBackedStateStoreMap()
       resultMap.putAll(startVersionMap.get)
       for (deltaVersion <- snapshotVersion + 1 to endVersion) {
         updateFromDeltaFile(deltaVersion, resultMap)
@@ -1122,6 +1210,15 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     result
   }
 
+  private def createHDFSBackedStateStoreMap(): HDFSBackedStateStoreMap = {
+    val readVerifier = KeyValueIntegrityVerifier.create(
+      stateStoreId_.toString,
+      storeConf.rowChecksumEnabled,
+      storeConf.rowChecksumReadVerificationRatio)
+
+    HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey, readVerifier)
+  }
+
   override def getStateStoreChangeDataReader(
       startVersion: Long,
       endVersion: Long,
@@ -1140,9 +1237,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
     }
 
-    new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, 
endVersion,
+    new HDFSBackedStateStoreChangeDataReader(stateStoreId_, fm, baseDir, 
startVersion, endVersion,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
-      keySchema, valueSchema)
+      keySchema, valueSchema, storeConf)
   }
 
   /** Reports to the coordinator the store's latest snapshot version */
@@ -1158,15 +1255,17 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[HDFSBackedStateStoreProvider]] */
 class HDFSBackedStateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
     compressionCodec: CompressionCodec,
     keySchema: StructType,
-    valueSchema: StructType)
+    valueSchema: StructType,
+    storeConf: StateStoreConf)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec) {
+    storeId, fm, stateLocation, startVersion, endVersion, compressionCodec, 
storeConf) {
 
   override protected val changelogSuffix: String = "delta"
 
@@ -1177,16 +1276,32 @@ class HDFSBackedStateStoreChangeDataReader(
     }
     val (recordType, keyArray, valueArray) = reader.next()
     val keyRow = new UnsafeRow(keySchema.fields.length)
-    keyRow.pointTo(keyArray, keyArray.length)
     if (valueArray == null) {
+      val originalKeyBytes = if (storeConf.rowChecksumEnabled) {
+        // Decode the original key and remove the checksum part
+        
KeyValueChecksumEncoder.decodeAndVerifyKeyRowWithChecksum(readVerifier, 
keyArray)
+      } else {
+        keyArray
+      }
+      keyRow.pointTo(originalKeyBytes, originalKeyBytes.length)
       (recordType, keyRow, null, currentChangelogVersion - 1)
     } else {
+      keyRow.pointTo(keyArray, keyArray.length)
+
       val valueRow = new UnsafeRow(valueSchema.fields.length)
+      val originalValueBytes = if (storeConf.rowChecksumEnabled) {
+        // Checksum is on the value side
+        KeyValueChecksumEncoder.decodeAndVerifySingleValueRowWithChecksum(
+          readVerifier, keyArray, valueArray)
+      } else {
+        valueArray
+      }
+
       // If valueSize in existing file is not multiple of 8, floor it to 
multiple of 8.
       // This is a workaround for the following:
       // Prior to Spark 2.3 mistakenly append 4 bytes to the value row in
       // `RowBasedKeyValueBatch`, which gets persisted into the checkpoint data
-      valueRow.pointTo(valueArray, (valueArray.length / 8) * 8)
+      valueRow.pointTo(originalValueBytes, (originalValueBytes.length / 8) * 8)
       (recordType, keyRow, valueRow, currentChangelogVersion - 1)
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 00c5ab7c0cd1..6dbf6a777c3b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -232,6 +232,12 @@ class RocksDB(
 
   private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)
 
+  // Row integrity verifier that is only used when clients read from the db (e.g. db.get()).
+  // It is built from the row checksum confs; the read verification ratio controls how often
+  // reads are actually verified.
+  private val readVerifier: Option[KeyValueIntegrityVerifier] =
+    KeyValueIntegrityVerifier.create(
+      loggingId,
+      conf.rowChecksumEnabled,
+      conf.rowChecksumReadVerificationRatio)
+
   private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
     colFamilyNameToInfoMap.get(cfName)
   }
@@ -831,31 +837,26 @@ class RocksDB(
       try {
         changelogReader = fileManager.getChangelogReader(v, uniqueId)
 
-        if (useColumnFamilies) {
-          changelogReader.foreach { case (recordType, key, value) =>
-            recordType match {
-              case RecordType.PUT_RECORD =>
-                put(key, value, includesPrefix = true, deriveCfName = true)
-
-              case RecordType.DELETE_RECORD =>
-                remove(key, includesPrefix = true, deriveCfName = true)
-
-              case RecordType.MERGE_RECORD =>
-                merge(key, value, includesPrefix = true, deriveCfName = true)
-            }
-          }
-        } else {
-          changelogReader.foreach { case (recordType, key, value) =>
-            recordType match {
-              case RecordType.PUT_RECORD =>
-                put(key, value)
-
-              case RecordType.DELETE_RECORD =>
-                remove(key)
-
-              case RecordType.MERGE_RECORD =>
-                merge(key, value)
-            }
+        // If row checksum is enabled, verify every record in the changelog 
file
+        val kvVerifier = KeyValueIntegrityVerifier
+          .create(loggingId, conf.rowChecksumEnabled, verificationRatio = 1)
+
+        changelogReader.foreach { case (recordType, key, value) =>
+          recordType match {
+            case RecordType.PUT_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, Some(value))
+              put(key, value, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
+
+            case RecordType.DELETE_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, None)
+              remove(key, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
+
+            case RecordType.MERGE_RECORD =>
+              verifyChangelogRecord(kvVerifier, key, Some(value))
+              merge(key, value, includesPrefix = useColumnFamilies,
+                deriveCfName = useColumnFamilies, includesChecksum = 
conf.rowChecksumEnabled)
           }
         }
       } finally {
@@ -870,6 +871,28 @@ class RocksDB(
     )
   }
 
+  /**
+   * Verifies the row checksum of a single changelog record being replayed.
+   *
+   * Verification only happens when `verifier` is a [[KeyValueChecksumVerifier]]; otherwise
+   * (e.g. no verifier because row checksum is disabled) this is a no-op.
+   *
+   * @param verifier optional integrity verifier
+   * @param keyBytes the key bytes exactly as read from the changelog file
+   * @param valueBytes the value bytes for PUT/MERGE records, or None for DELETE records
+   */
+  private def verifyChangelogRecord(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Option[Array[Byte]]): Unit = {
+    verifier match {
+      case Some(checksumVerifier: KeyValueChecksumVerifier) =>
+        // Do checksum verification inline using array index ranges, without copying bytes
+        valueBytes match {
+          case Some(value) =>
+            // Checksum is on the value side for PUT/MERGE record
+            val (valueIndex, checksum) =
+              KeyValueChecksumEncoder.decodeOneValueRowIndexWithChecksum(value)
+            checksumVerifier.verify(
+              ArrayIndexRange(keyBytes, 0, keyBytes.length), Some(valueIndex), checksum)
+          case None =>
+            // For DELETE there is no value, so the checksum is stored with the key
+            val (keyIndex, checksum) =
+              KeyValueChecksumEncoder.decodeKeyRowIndexWithChecksum(keyBytes)
+            checksumVerifier.verify(keyIndex, None, checksum)
+        }
+      case _ => // No checksum verifier: nothing to verify
+    }
+  }
+
   /**
    * Function to encode state row with virtual col family id prefix
    * @param data - passed byte array to be stored in state store
@@ -904,13 +927,41 @@ class RocksDB(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = {
     updateMemoryUsageIfNeeded()
+    val (finalKey, value) = getValue(key, cfName)
+    if (conf.rowChecksumEnabled && value != null) {
+      KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+        readVerifier, finalKey, value)
+    } else {
+      value
+    }
+  }
+
+  /**
+   * Gets the values that were merged (via merge) under the given key, if present.
+   *
+   * The values are returned as an iterator of [[ArrayIndexRange]]s over the stored bytes.
+   * This lets each value's bytes be accessed inline without copying, for better performance.
+   *
+   * Note: This method is currently only supported when row checksum is enabled.
+   */
+  def multiGet(
+      key: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ArrayIndexRange[Byte]] = {
+    assert(conf.rowChecksumEnabled, "multiGet is only allowed when row checksum is enabled")
+    updateMemoryUsageIfNeeded()
+
+    getValue(key, cfName) match {
+      case (storedKey, storedValue) =>
+        KeyValueChecksumEncoder.decodeAndVerifyMultiValueRowWithChecksum(
+          readVerifier, storedKey, storedValue)
+    }
+  }
+
+  /**
+   * Looks up the raw stored value for the given key.
+   *
+   * @param key the key as given by the caller; it is prefixed with the column family id
+   *            when column families are in use
+   * @param cfName the column family name, only used when column families are in use
+   * @return a tuple of (the final key used to store the value in the db, the raw value as
+   *         returned by db.get). The raw value may be null when the key is absent, and it
+   *         still includes the row checksum when row checksum is enabled.
+   */
+  private def getValue(key: Array[Byte], cfName: String): (Array[Byte], 
Array[Byte]) = {
     val keyWithPrefix = if (useColumnFamilies) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
       key
     }
 
-    db.get(readOptions, keyWithPrefix)
+    (keyWithPrefix, db.get(readOptions, keyWithPrefix))
   }
 
   /**
@@ -971,7 +1022,8 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
@@ -986,9 +1038,16 @@ class RocksDB(
       cfName
     }
 
+    val valueWithChecksum = if (conf.rowChecksumEnabled && !includesChecksum) {
+      KeyValueChecksumEncoder.encodeValueRowWithChecksum(value,
+        KeyValueChecksum.create(keyWithPrefix, Some(value)))
+    } else {
+      value
+    }
+
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
-    db.put(writeOptions, keyWithPrefix, value)
-    changelogWriter.foreach(_.put(keyWithPrefix, value))
+    db.put(writeOptions, keyWithPrefix, valueWithChecksum)
+    changelogWriter.foreach(_.put(keyWithPrefix, valueWithChecksum))
   }
 
   /**
@@ -1007,7 +1066,8 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
@@ -1022,9 +1082,16 @@ class RocksDB(
       cfName
     }
 
+    val valueWithChecksum = if (conf.rowChecksumEnabled && !includesChecksum) {
+      KeyValueChecksumEncoder.encodeValueRowWithChecksum(value,
+        KeyValueChecksum.create(keyWithPrefix, Some(value)))
+    } else {
+      value
+    }
+
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = true)
-    db.merge(writeOptions, keyWithPrefix, value)
-    changelogWriter.foreach(_.merge(keyWithPrefix, value))
+    db.merge(writeOptions, keyWithPrefix, valueWithChecksum)
+    changelogWriter.foreach(_.merge(keyWithPrefix, valueWithChecksum))
   }
 
   /**
@@ -1035,14 +1102,23 @@ class RocksDB(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false,
-      deriveCfName: Boolean = false): Unit = {
+      deriveCfName: Boolean = false,
+      includesChecksum: Boolean = false): Unit = {
     updateMemoryUsageIfNeeded()
-    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
-      encodeStateRowWithPrefix(key, cfName)
+    val originalKey = if (conf.rowChecksumEnabled && includesChecksum) {
+      // When we are replaying changelog record, the delete key in the file 
includes checksum.
+      // Remove the checksum, so we use the original key for db.delete.
+      KeyValueChecksumEncoder.decodeKeyRowWithChecksum(key)._1
     } else {
       key
     }
 
+    val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
+      encodeStateRowWithPrefix(originalKey, cfName)
+    } else {
+      originalKey
+    }
+
     val columnFamilyName = if (deriveCfName && useColumnFamilies) {
       val (_, cfName) = decodeStateRowWithPrefix(keyWithPrefix)
       cfName
@@ -1052,7 +1128,18 @@ class RocksDB(
 
     handleMetricsUpdate(keyWithPrefix, columnFamilyName, isPutOrMerge = false)
     db.delete(writeOptions, keyWithPrefix)
-    changelogWriter.foreach(_.delete(keyWithPrefix))
+    changelogWriter match {
+      case Some(writer) =>
+        val keyWithChecksum = if (conf.rowChecksumEnabled) {
+          // For delete, we will write a checksum with the key row only to the 
changelog file.
+          KeyValueChecksumEncoder.encodeKeyRowWithChecksum(keyWithPrefix,
+            KeyValueChecksum.create(keyWithPrefix, None))
+        } else {
+          keyWithPrefix
+        }
+        writer.delete(keyWithChecksum)
+      case None => // During changelog replay, there is no changelog writer.
+    }
   }
 
   /**
@@ -1079,7 +1166,14 @@ class RocksDB(
             iter.key
           }
 
-          byteArrayPair.set(key, iter.value)
+          val value = if (conf.rowChecksumEnabled) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, iter.key, iter.value)
+          } else {
+            iter.value
+          }
+
+          byteArrayPair.set(key, value)
           iter.next()
           byteArrayPair
         } else {
@@ -1170,7 +1264,14 @@ class RocksDB(
             iter.key
           }
 
-          byteArrayPair.set(key, iter.value)
+          val value = if (conf.rowChecksumEnabled) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, iter.key, iter.value)
+          } else {
+            iter.value
+          }
+
+          byteArrayPair.set(key, value)
           iter.next()
           byteArrayPair
         } else {
@@ -1969,8 +2070,10 @@ case class RocksDBConf(
     allowFAllocate: Boolean,
     compression: String,
     reportSnapshotUploadLag: Boolean,
+    maxVersionsToDeletePerMaintenance: Int,
     fileChecksumEnabled: Boolean,
-    maxVersionsToDeletePerMaintenance: Int)
+    rowChecksumEnabled: Boolean,
+    rowChecksumReadVerificationRatio: Long)
 
 object RocksDBConf {
   /** Common prefix of all confs in SQLConf that affects RocksDB */
@@ -2169,8 +2272,10 @@ object RocksDBConf {
       getBooleanConf(ALLOW_FALLOCATE_CONF),
       getStringConf(COMPRESSION_CONF),
       storeConf.reportSnapshotUploadLag,
+      storeConf.maxVersionsToDeletePerMaintenance,
       storeConf.checkpointFileChecksumEnabled,
-      storeConf.maxVersionsToDeletePerMaintenance)
+      storeConf.rowChecksumEnabled,
+      storeConf.rowChecksumReadVerificationRatio)
   }
 
   def apply(): RocksDBConf = apply(new StateStoreConf())
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index f49c79f96b9c..b4410362c4d3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -54,6 +54,7 @@ sealed trait RocksDBValueStateEncoder {
   def encodeValue(row: UnsafeRow): Array[Byte]
   def decodeValue(valueBytes: Array[Byte]): UnsafeRow
   def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
+  /** Decodes multiple values, each given as an [[ArrayIndexRange]] into an underlying array. */
+  def decodeValues(valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): Iterator[UnsafeRow]
 }
 
 trait StateSchemaProvider extends Serializable {
@@ -1785,6 +1786,40 @@ class MultiValuedStateEncoder(
     }
   }
 
+  /**
+   * Decodes multiple values. Each value is given as an [[ArrayIndexRange]] marking where
+   * the encoded value starts inside an underlying array. Values are read in place from that
+   * array, without copying it as a whole.
+   *
+   * Each encoded value starts at the range's `fromIndex` with a 4-byte length, followed by
+   * that many bytes of the encoded row.
+   *
+   * @param valueBytesIterator index ranges of the encoded values; null means "no values"
+   * @return a lazy iterator with one decoded row per index range
+   * @throws IllegalStateException if an encoded value does not fit inside the underlying
+   *                               array (e.g. because its length prefix is corrupt)
+   */
+  override def decodeValues(
+      valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): Iterator[UnsafeRow] = {
+    if (valueBytesIterator == null) {
+      Iterator.empty
+    } else {
+      valueBytesIterator.map { valueBytesIndex =>
+        val allValuesBytes = valueBytesIndex.array
+        val lengthIndex = valueBytesIndex.fromIndex
+        // Array index of the first byte of the encoded row, right after its 4-byte length.
+        // Long arithmetic keeps the bounds checks below free of Int overflow.
+        val valueIndex = lengthIndex.toLong + java.lang.Integer.BYTES
+        // Platform (Unsafe) reads are not bounds checked, so validate offsets first: a corrupt
+        // length must fail here instead of reading outside the array.
+        if (valueIndex > allValuesBytes.length) {
+          throw encodedValueOutOfBoundsError(lengthIndex, allValuesBytes.length)
+        }
+        val numBytes =
+          Platform.getInt(allValuesBytes, Platform.BYTE_ARRAY_OFFSET + lengthIndex)
+        if (numBytes < 0 || valueIndex + numBytes > allValuesBytes.length) {
+          throw encodedValueOutOfBoundsError(lengthIndex, allValuesBytes.length)
+        }
+
+        // Extract the bytes for this value
+        val encodedValue = new Array[Byte](numBytes)
+        Platform.copyMemory(
+          allValuesBytes, Platform.BYTE_ARRAY_OFFSET + valueIndex,
+          encodedValue, Platform.BYTE_ARRAY_OFFSET,
+          numBytes)
+
+        dataEncoder.decodeValue(encodedValue)
+      }
+    }
+  }
+
+  /** Builds the error for an encoded value whose bytes do not fit in the underlying array. */
+  private def encodedValueOutOfBoundsError(
+      lengthIndex: Int,
+      arrayLength: Int): IllegalStateException = {
+    new IllegalStateException(
+      s"Encoded value at index $lengthIndex does not fit in the value array of length " +
+        s"$arrayLength")
+  }
+
   override def supportsMultipleValuesPerKey: Boolean = true
 }
 
@@ -1824,4 +1859,9 @@ class SingleValueStateEncoder(
   override def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] = {
     throw new IllegalStateException("This encoder doesn't support multiple 
values!")
   }
+
+  /**
+   * Not supported: this encoder stores a single value per key.
+   * Always throws [[IllegalStateException]].
+   */
+  override def decodeValues(
+      valueBytesIterator: Iterator[ArrayIndexRange[Byte]]): Iterator[UnsafeRow] =
+    throw new IllegalStateException("This encoder doesn't support multiple values!")
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index edcebca466a5..f13c3a690b11 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -273,8 +273,14 @@ private[sql] class RocksDBStateStoreProvider
       verify(valueEncoder.supportsMultipleValuesPerKey, "valuesIterator 
requires a encoder " +
       "that supports multiple values for a single key.")
 
-      val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), colFamilyName)
-      valueEncoder.decodeValues(encodedValues)
+      if (storeConf.rowChecksumEnabled) {
+        // multiGet provides better perf for row checksum, since it avoids 
copying values
+        val encodedValuesIterator = 
rocksDB.multiGet(keyEncoder.encodeKey(key), colFamilyName)
+        valueEncoder.decodeValues(encodedValuesIterator)
+      } else {
+        val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), 
colFamilyName)
+        valueEncoder.decodeValues(encodedValues)
+      }
     }
 
     override def merge(key: UnsafeRow, value: UnsafeRow,
@@ -895,6 +901,7 @@ private[sql] class RocksDBStateStoreProvider
     val statePath = stateStoreId.storeCheckpointLocation()
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     new RocksDBStateStoreChangeDataReader(
+      stateStoreId,
       CheckpointFileManager.create(statePath, hadoopConf),
       rocksDB,
       statePath,
@@ -903,6 +910,7 @@ private[sql] class RocksDBStateStoreProvider
       endVersionStateStoreCkptId,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
       keyValueEncoderMap,
+      storeConf,
       colFamilyNameOpt)
   }
 
@@ -1238,6 +1246,7 @@ object RocksDBStateStoreProvider {
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[RocksDBStateStoreProvider]] */
 class RocksDBStateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     rocksDB: RocksDB,
     stateLocation: Path,
@@ -1247,9 +1256,11 @@ class RocksDBStateStoreChangeDataReader(
     compressionCodec: CompressionCodec,
     keyValueEncoderMap:
       ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder, Short)],
+    storeConf: StateStoreConf,
     colFamilyNameOpt: Option[String] = None)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec, 
colFamilyNameOpt) {
+    storeId, fm, stateLocation, startVersion, endVersion, compressionCodec,
+    storeConf, colFamilyNameOpt) {
 
   override protected val versionsAndUniqueIds: Array[(Long, Option[String])] =
     if (endVersionStateStoreCkptId.isDefined) {
@@ -1285,15 +1296,30 @@ class RocksDBStateStoreChangeDataReader(
         }
 
         val nextRecord = reader.next()
+        val keyBytes = if (storeConf.rowChecksumEnabled
+          && nextRecord._1 == RecordType.DELETE_RECORD) {
+          // remove checksum and decode to the original key
+          KeyValueChecksumEncoder
+            .decodeAndVerifyKeyRowWithChecksum(readVerifier, nextRecord._2)
+        } else {
+          nextRecord._2
+        }
         val colFamilyIdBytes: Array[Byte] =
           RocksDBStateStoreProvider.getColumnFamilyIdAsBytes(currEncoder._3)
         val endIndex = colFamilyIdBytes.size
         // Function checks for byte arrays being equal
         // from index 0 to endIndex - 1 (both inclusive)
-        if (java.util.Arrays.equals(nextRecord._2, 0, endIndex,
+        if (java.util.Arrays.equals(keyBytes, 0, endIndex,
           colFamilyIdBytes, 0, endIndex)) {
-          val extractedKey = 
RocksDBStateStoreProvider.decodeStateRowWithPrefix(nextRecord._2)
-          val result = (nextRecord._1, extractedKey, nextRecord._3)
+          val valueBytes = if (storeConf.rowChecksumEnabled &&
+            nextRecord._1 != RecordType.DELETE_RECORD) {
+            KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, keyBytes, nextRecord._3)
+          } else {
+            nextRecord._3
+          }
+          val extractedKey = 
RocksDBStateStoreProvider.decodeStateRowWithPrefix(keyBytes)
+          val result = (nextRecord._1, extractedKey, valueBytes)
           currRecord = result
         }
       }
@@ -1302,7 +1328,21 @@ class RocksDBStateStoreChangeDataReader(
       if (reader == null) {
         return null
       }
-      currRecord = reader.next()
+      val nextRecord = reader.next()
+      currRecord = if (storeConf.rowChecksumEnabled) {
+        nextRecord._1 match {
+          case RecordType.DELETE_RECORD =>
+            val key = KeyValueChecksumEncoder
+              .decodeAndVerifyKeyRowWithChecksum(readVerifier, nextRecord._2)
+            (nextRecord._1, key, nextRecord._3)
+          case _ =>
+            val value = 
KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+              readVerifier, nextRecord._2, nextRecord._3)
+            (nextRecord._1, nextRecord._2, value)
+        }
+      } else {
+        nextRecord
+      }
     }
 
     val keyRow = currEncoder._1.decodeKey(currRecord._2)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index c2b3d0676d36..9a5c78e92ed9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -1614,4 +1614,3 @@ object StateStore extends Logging {
     }
   }
 }
-
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index bd56f72489f9..2029c0988756 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -619,6 +619,7 @@ class StateStoreChangelogReaderV4(
  * store. In each iteration, it will return a tuple of (changeType: 
[[RecordType]],
  * nested key: [[UnsafeRow]], nested value: [[UnsafeRow]], batchId: [[Long]])
  *
+ * @param storeId id of the state store
  * @param fm checkpoint file manager used to manage streaming query checkpoint
  * @param stateLocation location of the state store
  * @param startVersion start version of the changelog file to read
@@ -627,17 +628,24 @@ class StateStoreChangelogReaderV4(
  * @param colFamilyNameOpt optional column family name to read from
  */
 abstract class StateStoreChangeDataReader(
+    storeId: StateStoreId,
     fm: CheckpointFileManager,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
     compressionCodec: CompressionCodec,
+    storeConf: StateStoreConf,
     colFamilyNameOpt: Option[String] = None)
   extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with 
Logging {
 
   assert(startVersion >= 1)
   assert(endVersion >= startVersion)
 
+  // Verifier for the row checksums of the changelog records returned by this reader. It is
+  // built from the store's row checksum confs; the read verification ratio controls how often
+  // records are actually verified.
+  protected val readVerifier: Option[KeyValueIntegrityVerifier] =
+    KeyValueIntegrityVerifier.create(
+      storeId.toString,
+      storeConf.rowChecksumEnabled,
+      storeConf.rowChecksumReadVerificationRatio)
+
   /**
    * Iterator that iterates over the changelog files in the state store.
    */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 74904a37f450..75a1b37ea5fe 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -123,6 +123,12 @@ class StateStoreConf(
   /** Whether to unload the store on task completion. */
   val unloadOnCommit = sqlConf.stateStoreUnloadOnCommit
 
+  /** Whether to enable checksum for state store rows. */
+  val rowChecksumEnabled: Boolean = sqlConf.stateStoreRowChecksumEnabled
+
+  /**
+   * How often row checksum verification is done when rows are read from the state store.
+   * NOTE(review): presumably a 1-in-N sampling ratio (1 = verify every read); confirm
+   * against the SQL conf definition.
+   */
+  val rowChecksumReadVerificationRatio: Long =
+    sqlConf.stateStoreRowChecksumReadVerificationRatio
+
   /** The version of the state store checkpoint format. */
   val stateStoreCheckpointFormatVersion: Int = 
sqlConf.stateStoreCheckpointFormatVersion
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 970499a054b5..7efe752e63a1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -104,6 +104,13 @@ object StateStoreErrors {
     new StateStoreInvalidStamp(providedStamp, currentStamp)
   }
 
+  /**
+   * Error for a state store row whose computed checksum does not match the expected one.
+   *
+   * @param storeId id of the state store the row belongs to
+   * @param expectedChecksum the checksum the row was expected to have
+   * @param computedChecksum the checksum computed from the row's bytes
+   */
+  def rowChecksumVerificationFailed(
+      storeId: String,
+      expectedChecksum: Int,
+      computedChecksum: Int): StateStoreRowChecksumVerificationFailed =
+    new StateStoreRowChecksumVerificationFailed(
+      storeId = storeId,
+      expectedChecksum = expectedChecksum,
+      computedChecksum = computedChecksum)
+
   def incorrectNumOrderingColsForRangeScan(numOrderingCols: String):
     StateStoreIncorrectNumOrderingColsForRangeScan = {
     new StateStoreIncorrectNumOrderingColsForRangeScan(numOrderingCols)
@@ -576,6 +583,18 @@ class StateStoreCommitValidationFailed(
     )
   )
 
+/**
+ * Raised when the checksum computed for a state store row does not match the expected
+ * checksum, i.e. the row is likely corrupt.
+ *
+ * @param storeId id of the state store the row belongs to
+ * @param expectedChecksum the checksum the row was expected to have
+ * @param computedChecksum the checksum computed from the row's bytes
+ */
+class StateStoreRowChecksumVerificationFailed(
+    storeId: String,
+    expectedChecksum: Int,
+    computedChecksum: Int)
+  extends SparkException(
+    errorClass = "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED",
+    messageParameters = Map(
+      "stateStoreId" -> storeId,
+      "expectedChecksum" -> String.valueOf(expectedChecksum),
+      "computedChecksum" -> String.valueOf(computedChecksum)),
+    cause = null)
+
 class StateStoreUnexpectedEmptyFileInRocksDBZip(fileName: String, zipFileName: 
String)
   extends SparkException(
     errorClass = "STATE_STORE_UNEXPECTED_EMPTY_FILE_IN_ROCKSDB_ZIP",
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
new file mode 100644
index 000000000000..3268beff1167
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRow.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Wraps an [[UnsafeRow]] so that additional info (e.g. a checksum) can travel with the row
+ * without copying or modifying the row itself.
+ */
+trait UnsafeRowWrapper {
+  /** The wrapped row. */
+  def unsafeRow(): UnsafeRow
+}
+
+/**
+ * The plain [[UnsafeRowWrapper]] for a state store row, carrying no extra info.
+ */
+class StateStoreRow(row: UnsafeRow) extends UnsafeRowWrapper {
+  override def unsafeRow(): UnsafeRow = row
+}
+
+/** Factory for [[StateStoreRow]]. */
+object StateStoreRow {
+  def apply(row: UnsafeRow): StateStoreRow =
+    new StateStoreRow(row)
+}
+
+/**
+ * A range of indices `[fromIndex, untilIndex)` in an array. Useful when we want to operate on
+ * a subset of an array without copying it.
+ *
+ * Note: like any case class holding an array, `equals`/`hashCode` compare `array` by
+ * reference, not by content.
+ *
+ * @param array The underlying array.
+ * @param fromIndex The starting index (inclusive).
+ * @param untilIndex The end index (exclusive).
+ */
+case class ArrayIndexRange[T](array: Array[T], fromIndex: Int, untilIndex: Int) {
+  // When fromIndex == untilIndex, it is an empty range. The range must also lie within the
+  // array, since consumers may read it with unchecked (Platform/Unsafe) memory accesses.
+  assert(fromIndex >= 0 && fromIndex <= untilIndex && untilIndex <= array.length,
+    s"Invalid range: fromIndex ($fromIndex) should be >= 0 and <= untilIndex ($untilIndex), " +
+      s"and untilIndex should be <= array length (${array.length})")
+
+  /** The number of elements in the range. */
+  def length: Int = untilIndex - fromIndex
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
new file mode 100644
index 000000000000..f02a598600c3
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksum.scala
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.zip.CRC32C
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{CHECKSUM, STATE_STORE_ID}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+
+/**
+ * State store row that carries a checksum alongside the wrapped [[UnsafeRow]].
+ *
+ * @param row The wrapped row.
+ * @param rowChecksum The checksum associated with the row, exposed via [[checksum]]. When
+ *                    checked by `KeyValueChecksumVerifier.verify(key, value)`, it is compared
+ *                    against the checksum computed over the key and value bytes together.
+ */
+class StateStoreRowWithChecksum(row: UnsafeRow, rowChecksum: Int) extends StateStoreRow(row) {
+  /** The checksum stored with the row. */
+  def checksum: Int = {
+    rowChecksum
+  }
+}
+
+/** Factory for [[StateStoreRowWithChecksum]]. */
+object StateStoreRowWithChecksum {
+  def apply(row: UnsafeRow, rowChecksum: Int): StateStoreRowWithChecksum =
+    new StateStoreRowWithChecksum(row, rowChecksum)
+}
+
+/**
+ * Integrity verifier for a state store key and value pair, used to make sure the key and
+ * value in the state store are not corrupt.
+ */
+trait KeyValueIntegrityVerifier {
+  /** Verifies a key row together with its wrapped value row. */
+  def verify(key: UnsafeRow, value: UnsafeRowWrapper): Unit
+
+  /** Verifies key bytes and optional value bytes against `expectedChecksum`. */
+  def verify(keyBytes: Array[Byte], valueBytes: Option[Array[Byte]], expectedChecksum: Int): Unit
+
+  /** Same as the byte array variant, but over index ranges of the underlying arrays. */
+  def verify(
+      keyBytes: ArrayIndexRange[Byte],
+      valueBytes: Option[ArrayIndexRange[Byte]],
+      expectedChecksum: Int): Unit
+}
+
+object KeyValueIntegrityVerifier {
+  /**
+   * Creates the integrity verifier to use, if any.
+   *
+   * @param storeId Id of the state store, used for logging.
+   * @param rowChecksumEnabled Whether row checksum is enabled.
+   * @param verificationRatio Verify once every `verificationRatio` verify calls; 0 means no
+   *                          verification. Must be non-negative.
+   * @return A checksum verifier, or None if row checksum is disabled or the ratio is 0.
+   */
+  def create(
+      storeId: String,
+      rowChecksumEnabled: Boolean,
+      verificationRatio: Long): Option[KeyValueIntegrityVerifier] = {
+    assert(verificationRatio >= 0, "Verification ratio must be non-negative")
+    val verificationEnabled = rowChecksumEnabled && verificationRatio != 0
+    if (verificationEnabled) Some(KeyValueChecksumVerifier(storeId, verificationRatio)) else None
+  }
+}
+
+/**
+ * Checksum based key value verifier. Computes the checksum of the key and value bytes
+ * and compares it to the expected checksum. This also supports rate limiting the number of
+ * verifications performed via the `verificationRatio` parameter: given that the verify methods
+ * can be called many times for different key and value pairs, this can be used to reduce the
+ * number of key-value pairs actually verified. On mismatch, the verify methods log the failure
+ * and throw the error built by `StateStoreErrors.rowChecksumVerificationFailed`.
+ *
+ * NOTE: Not thread safe and expected to be accessed one thread at a time.
+ *
+ * @param storeId The id of the state store, used for logging purpose.
+ * @param verificationRatio How often verification should occur. Setting a value N means
+ *                          verifying once every N verify calls. Must be greater than 0.
+ */
+class KeyValueChecksumVerifier(
+    storeId: String,
+    verificationRatio: Long) extends KeyValueIntegrityVerifier with Logging {
+  assert(verificationRatio > 0, "Verification ratio must be greater than 0")
+
+  // Number of verification requests received via verify method call.
+  // NOTE(review): @volatile only gives visibility for the getters; the increments below are
+  // not atomic, consistent with the single-threaded access documented above.
+  @volatile private var numRequests = 0L
+  // Number of checksum verifications actually performed
+  @volatile private var numVerified = 0L
+  def getNumRequests: Long = numRequests
+  def getNumVerified: Long = numVerified
+
+  /** [[verify]] methods call this to see if they should proceed */
+  private def shouldVerify(): Boolean = {
+    numRequests += 1
+    if (numRequests % verificationRatio == 0) {
+      numVerified += 1
+      true
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Computes the checksum of the key and value row, and compares it with the expected checksum
+   * included in the [[StateStoreRowWithChecksum]]. Might not do any checksum verification,
+   * depending on the verification ratio.
+   *
+   * @param key The key row.
+   * @param value The value row; must be a [[StateStoreRowWithChecksum]], since it carries the
+   *              expected checksum.
+   */
+  override def verify(key: UnsafeRow, value: UnsafeRowWrapper): Unit = value match {
+    case valWithChecksum: StateStoreRowWithChecksum =>
+      if (shouldVerify()) {
+        val expected = valWithChecksum.checksum
+        val computed = KeyValueChecksum.create(key, Some(valWithChecksum.unsafeRow()))
+        verifyChecksum(expected, computed)
+      }
+    case other =>
+      // Checked on every call, regardless of the verification ratio
+      assert(false, s"Expected StateStoreRowWithChecksum, but got ${other.getClass.getName}")
+  }
+
+  /**
+   * Computes the checksum of the key and value bytes (if present), and compares it with the
+   * expected checksum specified. Might not do any checksum verification,
+   * depending on the verification ratio.
+   *
+   * @param keyBytes The key bytes.
+   * @param valueBytes Optional value bytes.
+   * @param expectedChecksum The expected checksum value.
+   */
+  override def verify(
+      keyBytes: Array[Byte],
+      valueBytes: Option[Array[Byte]],
+      expectedChecksum: Int): Unit = {
+    if (shouldVerify()) {
+      val computed = KeyValueChecksum.create(keyBytes, valueBytes)
+      verifyChecksum(expectedChecksum, computed)
+    }
+  }
+
+  /**
+   * Computes the checksum of the key and value bytes (if present), and compares it with the
+   * expected checksum specified. Might not do any checksum verification,
+   * depending on the verification ratio.
+   *
+   * @param keyBytes Specifies the index range of the key bytes in the underlying array.
+   * @param valueBytes Optional, specifies the index range of the value bytes
+   *                   in the underlying array.
+   * @param expectedChecksum The expected checksum value.
+   */
+  override def verify(
+      keyBytes: ArrayIndexRange[Byte],
+      valueBytes: Option[ArrayIndexRange[Byte]],
+      expectedChecksum: Int): Unit = {
+    if (shouldVerify()) {
+      val computed = KeyValueChecksum.create(keyBytes, valueBytes)
+      verifyChecksum(expectedChecksum, computed)
+    }
+  }
+
+  /** Logs and throws if the computed checksum differs from the expected one. */
+  private def verifyChecksum(expected: Int, computed: Int): Unit = {
+    logDebug(s"Verifying row checksum, expected: $expected, computed: $computed")
+    if (expected != computed) {
+      logError(log"Row checksum verification failed for store ${MDC(STATE_STORE_ID, storeId)}, " +
+        log"Expected checksum: ${MDC(CHECKSUM, expected)}, " +
+        log"Computed checksum: ${MDC(CHECKSUM, computed)}")
+      throw StateStoreErrors.rowChecksumVerificationFailed(storeId, expected, computed)
+    }
+  }
+}
+
+object KeyValueChecksumVerifier {
+  def apply(storeId: String, verificationRatio: Long): KeyValueChecksumVerifier = {
+    new KeyValueChecksumVerifier(storeId, verificationRatio)
+  }
+}
+
+/** Creates checksums for state store key and value pairs. */
+object KeyValueChecksum {
+  /**
+   * Creates a checksum over the bytes of the key row followed by the bytes of the value row.
+   * When no value row is given, only the key row bytes are used.
+   */
+  def create(keyRow: UnsafeRow, valueRow: Option[UnsafeRow]): Int = {
+    val keyBytes = keyRow.getBytes
+    val valueBytes = valueRow.map(row => row.getBytes)
+    create(keyBytes, valueBytes)
+  }
+
+  /**
+   * Creates a checksum over the whole key byte array followed by the whole value byte array.
+   * When no value bytes are given, only the key bytes are used.
+   */
+  def create(keyBytes: Array[Byte], valueBytes: Option[Array[Byte]]): Int = {
+    def wholeArray(bytes: Array[Byte]): ArrayIndexRange[Byte] =
+      ArrayIndexRange(bytes, 0, bytes.length)
+
+    create(wholeArray(keyBytes), valueBytes.map(wholeArray))
+  }
+
+  /**
+   * Creates a checksum over the given key bytes range followed by the value bytes range.
+   * When no value range is given, only the key bytes range is used.
+   *
+   * @param keyBytes Index range of the key bytes within its underlying array.
+   * @param valueBytes Optional index range of the value bytes within its underlying array.
+   * @return The CRC32C value of the bytes, truncated to an Int.
+   */
+  def create(keyBytes: ArrayIndexRange[Byte], valueBytes: Option[ArrayIndexRange[Byte]]): Int = {
+    // We can later make the checksum algorithm configurable
+    val checksum = new CRC32C()
+    checksum.update(keyBytes.array, keyBytes.fromIndex, keyBytes.length)
+    for (value <- valueBytes) {
+      checksum.update(value.array, value.fromIndex, value.length)
+    }
+    checksum.getValue.toInt
+  }
+}
+
+/**
+ * Used to encode and decode checksum value with/from the row bytes.
+ *
+ * The encoded integers (checksum, row size) are written and read with [[Platform]], whose
+ * memory accesses are not array bounds checked. The decode methods therefore validate the
+ * encoded framing against the input length before reading it, and throw an
+ * [[IllegalStateException]] for malformed (e.g. truncated) input instead of reading past the
+ * end of the array.
+ * */
+object KeyValueChecksumEncoder {
+  /**
+   * Encodes the value row bytes with a checksum value. This encodes the bytes in a way that
+   * supports additional values to be later merged to it. If the value would only ever have a
+   * single value (no merge), then you should use [[encodeSingleValueRowWithChecksum]] instead.
+   *
+   * It is encoded as: checksum (4 bytes) + rowBytes.length (4 bytes) + rowBytes
+   *
+   * @param rowBytes Value row bytes.
+   * @param checksum Checksum value to encode with the value row bytes.
+   * @return The encoded value row bytes that includes the checksum.
+   * */
+  def encodeValueRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    val result = new Array[Byte](java.lang.Integer.BYTES * 2 + rowBytes.length)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET, checksum)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES, rowBytes.length)
+
+    // Write the actual data
+    Platform.copyMemory(
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      result, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES * 2,
+      rowBytes.length
+    )
+    result
+  }
+
+  /**
+   * Decode and verify a value row bytes encoded with checksum via [[encodeValueRowWithChecksum]]
+   * back to the original value row bytes. Supports decoding both one or more values bytes
+   * (i.e. merged values). This copies each individual value and removes their encoded checksum
+   * to form the original value bytes. Because it does copy, it is more expensive than
+   * [[decodeAndVerifyMultiValueRowWithChecksum]] method which returns the index range instead of
+   * copy (preferred for multi-value bytes). This method is just used to support calling
+   * store.get for a key that has merged values.
+   *
+   * @param verifier used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode. Empty input decodes to an empty array.
+   * @return The original value row bytes, without the checksum. Merged values stay separated
+   *         by the delimiter byte that followed each of them in `valueBytes`.
+   * @throws IllegalStateException if `valueBytes` is malformed.
+   * */
+  def decodeAndVerifyValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Array[Byte] = {
+    val valuesEnd = Platform.BYTE_ARRAY_OFFSET + valueBytes.length
+
+    // First get the total size of the original values, validating the framing as we go.
+    // Doing this to also support decoding merged values (via merge) e.g. val1,val2,val3
+    var currentPosition = Platform.BYTE_ARRAY_OFFSET
+    var resultSize = 0
+    var numValues = 0
+    while (currentPosition < valuesEnd) {
+      val (valueRowIndex, _) = getValueRowIndexAndChecksum(valueBytes, currentPosition)
+      // move to the next value and skip the delimiter character used for rocksdb merge
+      currentPosition = byteIndexToOffset(valueRowIndex.untilIndex) + 1
+      resultSize += valueRowIndex.length
+      numValues += 1
+    }
+
+    if (numValues == 0) {
+      // Empty input: no values and hence no delimiters (avoids a negative result size)
+      Array.emptyByteArray
+    } else {
+      // include the number of delimiters used for merge
+      resultSize += numValues - 1
+
+      // now verify and decode to original merged values
+      val result = new Array[Byte](resultSize)
+      var resultIndex = 0
+      val keyRowIndex = ArrayIndexRange(keyBytes, 0, keyBytes.length)
+      currentPosition = Platform.BYTE_ARRAY_OFFSET // reset to beginning of values
+      var currentValueCount = 0 // count of values iterated
+
+      while (currentPosition < valuesEnd) {
+        currentValueCount += 1
+        val (valueRowIndex, checksum) = getValueRowIndexAndChecksum(valueBytes, currentPosition)
+        // verify the current value using the index range to avoid copying
+        verifier.foreach(_.verify(keyRowIndex, Some(valueRowIndex), checksum))
+
+        // No delimiter is needed if single value or the last value in multi-value.
+        // For non-last values the delimiter is in bounds, since another value follows it.
+        val copyLength = if (currentValueCount < numValues) {
+          valueRowIndex.length + 1 // copy the delimiter
+        } else {
+          valueRowIndex.length
+        }
+        System.arraycopy(valueBytes, valueRowIndex.fromIndex, result, resultIndex, copyLength)
+
+        // move to the next value
+        currentPosition = byteIndexToOffset(valueRowIndex.fromIndex + copyLength)
+        resultIndex += copyLength
+      }
+
+      result
+    }
+  }
+
+  /**
+   * Decode and verify a value row bytes, that might contain one or more values (merged)
+   * encoded with checksum via [[encodeValueRowWithChecksum]] back to the original value
+   * row bytes index.
+   * This returns an iterator of index range of the original individual value row bytes
+   * and verifies their checksum. Cheaper since it does not copy the value bytes unlike
+   * [[decodeAndVerifyValueRowWithChecksum]]. Each value is decoded and verified as the
+   * iterator advances.
+   *
+   * @param verifier Used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode. A null value yields an empty iterator.
+   * @return Iterator of index range representing the original value row bytes, without checksum.
+   *         Its `next()` throws [[IllegalStateException]] for malformed bytes.
+   */
+  def decodeAndVerifyMultiValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Iterator[ArrayIndexRange[Byte]] = {
+    if (valueBytes == null) {
+      Iterator.empty
+    } else {
+      new Iterator[ArrayIndexRange[Byte]] {
+        private val keyRowIndex = ArrayIndexRange(keyBytes, 0, keyBytes.length)
+        private var position: Int = Platform.BYTE_ARRAY_OFFSET
+        private val valuesEnd = Platform.BYTE_ARRAY_OFFSET + valueBytes.length
+
+        override def hasNext: Boolean = position < valuesEnd
+
+        override def next(): ArrayIndexRange[Byte] = {
+          // Iterator contract: fail instead of reading past the end of the value bytes
+          if (!hasNext) {
+            throw new NoSuchElementException("No more encoded values to decode")
+          }
+          val (valueRowIndex, checksum) =
+            getValueRowIndexAndChecksum(valueBytes, startingPosition = position)
+          verifier.foreach(_.verify(keyRowIndex, Some(valueRowIndex), checksum))
+          // move to the next value and skip the delimiter character used for rocksdb merge
+          position = byteIndexToOffset(valueRowIndex.untilIndex) + 1
+          valueRowIndex
+        }
+      }
+    }
+  }
+
+  /**
+   * Decodes one value row bytes encoded with checksum via [[encodeValueRowWithChecksum]]
+   * back to the original value row bytes. This is used for an encoded value row that
+   * currently only have one value. This returns the index range of the original value row bytes
+   * and the encoded checksum value. Cheaper since it does not copy the value bytes.
+   *
+   * @param bytes The value bytes to decode.
+   * @return A tuple containing the index range of the original value row bytes and the checksum.
+   * @throws IllegalStateException if `bytes` is malformed.
+   */
+  def decodeOneValueRowIndexWithChecksum(bytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    getValueRowIndexAndChecksum(bytes, startingPosition = Platform.BYTE_ARRAY_OFFSET)
+  }
+
+  /**
+   * Get the original value row index and checksum for a row encoded via
+   * [[encodeValueRowWithChecksum]], starting at `startingPosition` (a byte array address
+   * offset, i.e. relative to `Platform.BYTE_ARRAY_OFFSET`).
+   *
+   * @throws IllegalStateException if the header or the row bytes extend past the end of `bytes`.
+   * */
+  private def getValueRowIndexAndChecksum(
+      bytes: Array[Byte],
+      startingPosition: Int): (ArrayIndexRange[Byte], Int) = {
+    val headerSize = java.lang.Integer.BYTES * 2
+    val headerIndex = byteOffsetToIndex(startingPosition)
+    // Platform reads are not bounds checked, so make sure the header is there before reading it
+    if (headerIndex.toLong + headerSize > bytes.length) {
+      malformedEncoding(s"expected a $headerSize byte header at index $headerIndex, " +
+        s"but the input is only ${bytes.length} bytes")
+    }
+    val checksum = Platform.getInt(bytes, startingPosition)
+    val rowSize = Platform.getInt(bytes, startingPosition + java.lang.Integer.BYTES)
+
+    // the row bytes start right after the header; Long math avoids Int overflow on bad sizes
+    val fromIndex = headerIndex + headerSize
+    if (rowSize < 0 || fromIndex.toLong + rowSize > bytes.length) {
+      malformedEncoding(s"invalid row size $rowSize at index $fromIndex, " +
+        s"the input is only ${bytes.length} bytes")
+    }
+    (ArrayIndexRange(bytes, fromIndex, fromIndex + rowSize), checksum)
+  }
+
+  /**
+   * Encodes the key row bytes with a checksum value.
+   * It is encoded as: checksum (4 bytes) + rowBytes
+   *
+   * @param rowBytes Key row bytes.
+   * @param checksum Checksum value to encode with the key row bytes.
+   * @return The encoded key row bytes that includes the checksum.
+   * */
+  def encodeKeyRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    encodeSingleValueRowWithChecksum(rowBytes, checksum)
+  }
+
+  /**
+   * Decodes the key row encoded with [[encodeKeyRowWithChecksum]] and
+   * returns the original key bytes and the checksum value.
+   *
+   * @param bytes The encoded key bytes with checksum.
+   * @return Tuple of the original key bytes and the checksum value.
+   * @throws IllegalStateException if `bytes` is shorter than the checksum.
+   * */
+  def decodeKeyRowWithChecksum(bytes: Array[Byte]): (Array[Byte], Int) = {
+    decodeSingleValueRowWithChecksum(bytes)
+  }
+
+  /**
+   * Decode and verify the key row encoded with [[encodeKeyRowWithChecksum]] and
+   * returns the original key bytes.
+   *
+   * @param verifier Used for checksum verification.
+   * @param bytes The encoded key bytes with checksum.
+   * @return The original key bytes.
+   * */
+  def decodeAndVerifyKeyRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      bytes: Array[Byte]): Array[Byte] = {
+    val (originalBytes, checksum) = decodeKeyRowWithChecksum(bytes)
+    verifier.foreach(_.verify(originalBytes, None, checksum))
+    originalBytes
+  }
+
+  /**
+   * Decodes a key row encoded with [[encodeKeyRowWithChecksum]] and
+   * returns the index range of the original key bytes and checksum value. This is cheaper
+   * than [[decodeKeyRowWithChecksum]], since it doesn't copy the key bytes.
+   *
+   * @param keyBytes The encoded key bytes with checksum.
+   * @return A tuple containing the index range of the original key row bytes and the checksum.
+   * */
+  def decodeKeyRowIndexWithChecksum(keyBytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    decodeSingleValueRowIndexWithChecksum(keyBytes)
+  }
+
+  /**
+   * Encodes a value row bytes that will only ever have a single value (no multi-value)
+   * with a checksum value.
+   * It is encoded as: checksum (4 bytes) + rowBytes.
+   * Since it will only ever have a single value, no need to encode the rowBytes length.
+   *
+   * @param rowBytes Value row bytes.
+   * @param checksum Checksum value to encode with the value row bytes.
+   * @return The encoded value row bytes that includes the checksum.
+   * */
+  def encodeSingleValueRowWithChecksum(rowBytes: Array[Byte], checksum: Int): Array[Byte] = {
+    val result = new Array[Byte](java.lang.Integer.BYTES + rowBytes.length)
+    Platform.putInt(result, Platform.BYTE_ARRAY_OFFSET, checksum)
+
+    // Write the actual data
+    Platform.copyMemory(
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      result, Platform.BYTE_ARRAY_OFFSET + java.lang.Integer.BYTES,
+      rowBytes.length
+    )
+    result
+  }
+
+  /**
+   * Decodes a single value row encoded with [[encodeSingleValueRowWithChecksum]] and
+   * returns the original value bytes and checksum value.
+   *
+   * @param bytes The encoded value bytes with checksum.
+   * @return Tuple of the original value bytes (a copy) and the checksum value.
+   * @throws IllegalStateException if `bytes` is shorter than the checksum.
+   * */
+  def decodeSingleValueRowWithChecksum(bytes: Array[Byte]): (Array[Byte], Int) = {
+    val (rowIndex, checksum) = decodeSingleValueRowIndexWithChecksum(bytes)
+    val row = java.util.Arrays.copyOfRange(bytes, rowIndex.fromIndex, rowIndex.untilIndex)
+    (row, checksum)
+  }
+
+  /**
+   * Decode and verify a single value row encoded with [[encodeSingleValueRowWithChecksum]] and
+   * returns the original value bytes.
+   *
+   * @param verifier used for checksum verification.
+   * @param keyBytes Key bytes for the value to decode, only used for checksum verification.
+   * @param valueBytes The value bytes to decode.
+   * @return The original value row bytes, without the checksum.
+   * */
+  def decodeAndVerifySingleValueRowWithChecksum(
+      verifier: Option[KeyValueIntegrityVerifier],
+      keyBytes: Array[Byte],
+      valueBytes: Array[Byte]): Array[Byte] = {
+    val (originalValueBytes, checksum) = decodeSingleValueRowWithChecksum(valueBytes)
+    verifier.foreach(_.verify(keyBytes, Some(originalValueBytes), checksum))
+    originalValueBytes
+  }
+
+  /**
+   * Decodes a single value row encoded with [[encodeSingleValueRowWithChecksum]] and
+   * returns the index range of the original value bytes and checksum value. This is cheaper
+   * than [[decodeSingleValueRowWithChecksum]], since it doesn't copy the value bytes.
+   *
+   * @param bytes The encoded value bytes with checksum.
+   * @return A tuple containing the index range of the original value row bytes and the checksum.
+   * @throws IllegalStateException if `bytes` is shorter than the checksum.
+   * */
+  def decodeSingleValueRowIndexWithChecksum(bytes: Array[Byte]): (ArrayIndexRange[Byte], Int) = {
+    // Platform reads are not bounds checked, so make sure the checksum is there before reading it
+    if (bytes.length < java.lang.Integer.BYTES) {
+      malformedEncoding(s"expected a ${java.lang.Integer.BYTES} byte checksum, " +
+        s"but the input is only ${bytes.length} bytes")
+    }
+    val checksum = Platform.getInt(bytes, Platform.BYTE_ARRAY_OFFSET)
+
+    // the row bytes are everything after the checksum
+    (ArrayIndexRange(bytes, java.lang.Integer.BYTES, bytes.length), checksum)
+  }
+
+  /** Fails decoding of bytes whose encoded framing is inconsistent with their length. */
+  private def malformedEncoding(details: String): Nothing = {
+    throw new IllegalStateException(s"Malformed checksum encoded row bytes: $details")
+  }
+
+  /** Convert byte array address offset to array index */
+  private def byteOffsetToIndex(offset: Int): Int = {
+    offset - Platform.BYTE_ARRAY_OFFSET
+  }
+
+  /** Convert byte array index to array address offset */
+  private def byteIndexToOffset(index: Int): Int = {
+    index + Platform.BYTE_ARRAY_OFFSET
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
index 9c4a7b1879f6..e4312fd16d1f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
@@ -197,4 +197,26 @@ class OffsetSeqLogSuite extends SharedSparkSession {
         "unsaferow")
     }
   }
+
+  test("Row checksum disabled by default") {
+    // Metadata built from the current session conf, where row checksum was never set
+    val metadata = OffsetSeqMetadata.apply(
+      batchWatermarkMs = 0,
+      batchTimestampMs = 0,
+      spark.conf)
+    val rowChecksumEnabled = metadata.conf.get(SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key)
+    assert(rowChecksumEnabled === Some(false.toString))
+  }
+
+  test("Row checksum disabled for existing checkpoint even if conf is enabled") {
+    val checksumConfKey = SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key
+    withSQLConf(checksumConfKey -> true.toString) {
+      val (_, offsetSeq) = readFromResource("offset-log-version-2.1.0")
+      val metadata = offsetSeq.metadata.get
+      // The existing checkpoint does not record the row checksum conf at all
+      assert(metadata.conf.get(checksumConfKey) === None)
+
+      // Restoring the session conf from that checkpoint must leave row checksum disabled,
+      // even though it is enabled in the current session
+      val sqlConf = spark.sessionState.conf.clone()
+      OffsetSeqMetadata.setSessionConf(metadata, sqlConf)
+      assert(!sqlConf.stateStoreRowChecksumEnabled)
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 3182c65a2aee..e1045e158188 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -148,8 +148,9 @@ trait AlsoTestWithEncodingTypes extends SQLTestUtils {
   }
 }
 
-trait AlsoTestWithRocksDBFeatures
-  extends SQLTestUtils with RocksDBStateStoreChangelogCheckpointingTestUtil {
+trait AlsoTestWithRocksDBFeatures extends SQLTestUtils
+  with RocksDBStateStoreChangelogCheckpointingTestUtil
+  with AlsoTestWithStateStoreRowChecksum {
 
   sealed trait TestMode
   case object TestWithChangelogCheckpointingEnabled extends TestMode
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
new file mode 100644
index 000000000000..1268ccbc9e0a
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRowChecksumSuite.scala
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.util.UUID
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.rocksdb.{ReadOptions, RocksDB => NativeRocksDB, WriteOptions}
+import org.scalactic.source.Position
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester, Tag}
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import 
org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.runtime.StreamExecution
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+/**
+ * Main test for the State store row checksum feature. Has both
+ * [[HDFSStateStoreRowChecksumSuite]] and [[RocksDBStateStoreRowChecksumSuite]] as subclasses.
+ * Row checksum is also enabled in other tests (State operators, TransformWithState, State data
+ * source, RTM etc.) by adding the [[AlsoTestWithStateStoreRowChecksum]] trait.
+ * */
+abstract class StateStoreRowChecksumSuite extends SharedSparkSession
+  with BeforeAndAfter {
+
+  import StateStoreTestsHelper._
+
+  before {
+    StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+    require(!StateStore.isMaintenanceRunning)
+  }
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf.set(SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key, 
true.toString)
+      // To avoid file checksum verification since we will be injecting 
corruption in this suite
+      .set(SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key, 
false.toString)
+  }
+
+  protected def withStateStoreProvider[T <: StateStoreProvider](
+      provider: T)(f: T => Unit): Unit = {
+    try {
+      f(provider)
+    } finally {
+      provider.close()
+    }
+  }
+
+  protected def createProvider: StateStoreProvider
+
+  protected def createInitializedProvider(
+      dir: String = Utils.createTempDir().toString,
+      opId: Long = 0,
+      partition: Int = 0,
+      runId: UUID = UUID.randomUUID(),
+      keyStateEncoderSpec: KeyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(
+        keySchema),
+      keySchema: StructType = keySchema,
+      valueSchema: StructType = valueSchema,
+      sqlConf: SQLConf = SQLConf.get,
+      hadoopConf: Configuration = new Configuration): StateStoreProvider = {
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, runId.toString)
+    val provider = createProvider
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      keyStateEncoderSpec,
+      useColumnFamilies = false,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
+  }
+
+  protected def corruptRow(
+    provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit
+
+  protected def corruptRowInFile(
+    provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit
+
+  test("Detect local corrupt row") {
+    withStateStoreProvider(createInitializedProvider()) { provider =>
+      val store = provider.getStore(0)
+      put(store, "1", 11, 100)
+      put(store, "2", 22, 200)
+
+      // corrupt a row
+      corruptRow(provider, store, dataToKeyRow("1", 11))
+
+      assert(get(store, "2", 22).contains(200),
+        "failed to get the correct value for the uncorrupted row")
+
+      checkChecksumError(intercept[SparkException] {
+        // should throw row checksum exception
+        get(store, "1", 11)
+      })
+
+      checkChecksumError(intercept[SparkException] {
+        // should throw row checksum exception
+        store.iterator().foreach(kv => assert(kv.key != null && kv.value != 
null))
+      })
+
+      store.abort()
+    }
+  }
+
+  protected def corruptRowInFileTestMode: Seq[Boolean]
+
+  corruptRowInFileTestMode.foreach { isSnapshot =>
+    test(s"Detect corrupt row in checkpoint file - isSnapshot: $isSnapshot") {
+      // We want to generate snapshot per change file
+      withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "0") {
+        var version = 0L
+        val checkpointDir = Utils.createTempDir().toString
+
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { 
provider =>
+          val store = provider.getStore(version)
+          put(store, "1", 11, 100)
+          put(store, "2", 22, 200)
+          remove(store, k => k == ("1", 11))
+          put(store, "3", 33, 300)
+          version = store.commit() // writes change file
+
+          if (isSnapshot) {
+            provider.doMaintenance() // writes snapshot
+          }
+        }
+
+        // We should be able to successfully load the store from checkpoint
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { 
provider =>
+          val store = provider.getStore(version)
+          assert(get(store, "1", 11).isEmpty) // the removed row
+          assert(get(store, "2", 22).contains(200))
+          assert(get(store, "3", 33).contains(300))
+
+          // Now corrupt a row in the checkpoint file
+          corruptRowInFile(provider, version, isSnapshot)
+          store.abort()
+        }
+
+        // Reload the store from checkpoint, should detect the corrupt row
+        withStateStoreProvider(createInitializedProvider(checkpointDir)) { 
provider =>
+          checkChecksumError(intercept[SparkException] {
+            // should throw row checksum exception
+            provider.getStore(version)
+          }.getCause.asInstanceOf[SparkException])
+        }
+      }
+    }
+  }
+
+  test("Read verification ratio") {
+    val numRows = 10
+    val checkpointDir = Utils.createTempDir().toString
+    var version = 0L
+
+    // Point-reads every key written below; each read must find a value.
+    def readEveryKey(store: StateStore): Unit =
+      (1 to numRows).foreach(v => assert(get(store, v.toString, v).nonEmpty))
+
+    // Full scan of the store; every returned pair must have a non-null key and value.
+    def scanEveryRow(store: StateStore): Unit =
+      store.iterator().foreach(kv => assert(kv.key != null && kv.value != null))
+
+    // Fetches the store's read verifier, which this test expects to be a checksum verifier.
+    def checksumVerifierOf(
+        provider: StateStoreProvider, store: StateStore): KeyValueChecksumVerifier =
+      getReadVerifier(provider, store).get.asInstanceOf[KeyValueChecksumVerifier]
+
+    // Default ratio: every read is verified, for point reads as well as iteration.
+    withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+      val store = provider.getStore(0)
+      // add 10 rows
+      (1 to numRows).foreach(v => put(store, v.toString, v, v * 100))
+
+      readEveryKey(store)
+      val verifier = checksumVerifierOf(provider, store)
+      assertVerifierStats(verifier, expectedRequests = numRows, expectedVerifies = numRows)
+
+      scanEveryRow(store)
+      assertVerifierStats(
+        verifier, expectedRequests = numRows * 2, expectedVerifies = numRows * 2)
+
+      version = store.commit()
+    }
+
+    // Ratio 0: reads are not verified and no read verifier is created.
+    withSQLConf(SQLConf.STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO.key -> "0") {
+      withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+        val store = provider.getStore(version)
+        readEveryKey(store)
+        assert(getReadVerifier(provider, store).isEmpty, "Expected no verifier")
+        store.abort()
+      }
+    }
+
+    // Ratio 2: one out of every two read requests is verified.
+    withSQLConf(SQLConf.STATE_STORE_ROW_CHECKSUM_READ_VERIFICATION_RATIO.key -> "2") {
+      withStateStoreProvider(createInitializedProvider(checkpointDir)) { provider =>
+        val store = provider.getStore(version)
+
+        readEveryKey(store)
+        val verifier = checksumVerifierOf(provider, store)
+        assertVerifierStats(verifier, expectedRequests = numRows, expectedVerifies = numRows / 2)
+
+        scanEveryRow(store)
+        assertVerifierStats(verifier, expectedRequests = numRows * 2, expectedVerifies = numRows)
+
+        store.abort()
+      }
+    }
+  }
+
+  /**
+   * Asserts that `checksumError` is a STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED error
+   * for the given operator and partition, with integer expected/computed checksums.
+   *
+   * @param checksumError the exception to check
+   * @param opId expected operator id inside the reported state store id
+   * @param partId expected partition id inside the reported state store id
+   */
+  protected def checkChecksumError(
+      checksumError: SparkException,
+      opId: Int = 0,
+      partId: Int = 0): Unit = {
+    // The pattern accepts either `[...]` or `(...)` around the id fields and either the
+    // long or the abbreviated field names.
+    val stateStoreIdPattern =
+      s".*StateStoreId[\\[\\(].*(operatorId|opId)=($opId)" +
+        s".*(partitionId|partId)=($partId).*[)\\]]"
+    // Optionally negative integer
+    val integerPattern = "^-?\\d+$"
+
+    checkError(
+      exception = checksumError,
+      condition = "STATE_STORE_ROW_CHECKSUM_VERIFICATION_FAILED",
+      parameters = Map(
+        "stateStoreId" -> stateStoreIdPattern,
+        "expectedChecksum" -> integerPattern,
+        "computedChecksum" -> integerPattern),
+      matchPVals = true)
+  }
+
+  /**
+   * Returns the read integrity verifier used by `store`, or `None` when it has none
+   * (for example when the read verification ratio is set to 0).
+   */
+  protected def getReadVerifier(provider: StateStoreProvider, store: StateStore)
+      : Option[KeyValueIntegrityVerifier]
+
+  /**
+   * Checks the verifier's cumulative counters: how many reads asked for verification
+   * and how many of those were actually verified.
+   */
+  private def assertVerifierStats(
+      verifier: KeyValueChecksumVerifier,
+      expectedRequests: Long,
+      expectedVerifies: Long): Unit = {
+    assert(verifier.getNumRequests == expectedRequests)
+    assert(verifier.getNumVerified == expectedVerifies)
+  }
+}
+
+/** Row checksum tests run against the HDFS-backed state store provider. */
+class HDFSStateStoreRowChecksumSuite extends StateStoreRowChecksumSuite
+  with PrivateMethodTester {
+  import StateStoreTestsHelper._
+
+  override protected def createProvider: StateStoreProvider = new HDFSBackedStateStoreProvider
+
+  /** Reflectively reads the private `mapToUpdate` field of the HDFS-backed `store`. */
+  private def storeMapToUpdate(
+      provider: StateStoreProvider, store: StateStore): HDFSBackedStateStoreMap = {
+    val hdfsProvider = provider.asInstanceOf[HDFSBackedStateStoreProvider]
+    val hdfsStore = store.asInstanceOf[hdfsProvider.HDFSBackedStateStore]
+    val mapToUpdateField = PrivateMethod[HDFSBackedStateStoreMap](Symbol("mapToUpdate"))
+    hdfsStore invokePrivate mapToUpdateField()
+  }
+
+  /** Flips the last bit of `key`'s in-memory value while keeping its old checksum. */
+  override protected def corruptRow(
+      provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit = {
+    // Access the private map backing the hdfs store map
+    val storeMap = storeMapToUpdate(provider, store)
+    val mapField = PrivateMethod[HDFSBackedStateStoreMap.MapType](Symbol("map"))
+    val map = storeMap invokePrivate mapField()
+
+    val currentValueRow = map.get(key).asInstanceOf[StateStoreRowWithChecksum]
+    // Fail with context instead of an NPE when the key is missing.
+    assert(currentValueRow != null, s"Key to corrupt not found in the store map: $key")
+    val currentValue = valueRowToData(currentValueRow.unsafeRow())
+    // corrupt the existing value by flipping the last bit
+    val corruptValue = currentValue ^ 1
+
+    // update with the corrupt value, but keeping the previous checksum
+    map.put(key, StateStoreRowWithChecksum(dataToValueRow(corruptValue), currentValueRow.checksum))
+  }
+
+  /**
+   * Rewrites the `version` snapshot file (or delta file when `isSnapshot` is false) with
+   * one bit flipped in the last row, just before the trailing 4-byte EOF marker.
+   */
+  override protected def corruptRowInFile(
+      provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit = {
+    val hdfsProvider = provider.asInstanceOf[HDFSBackedStateStoreProvider]
+
+    val baseDirField = PrivateMethod[Path](Symbol("baseDir"))
+    val baseDir = hdfsProvider invokePrivate baseDirField()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new Path(baseDir.toString, fileName)
+
+    val fileManagerMethod = PrivateMethod[CheckpointFileManager](Symbol("fm"))
+    val fm = hdfsProvider invokePrivate fileManagerMethod()
+
+    // Read the decompressed file content; close the stream even if reading fails
+    val input = hdfsProvider.decompressStream(fm.open(filePath))
+    val currentData = try input.readAllBytes() finally input.close()
+
+    // Corrupt the current data by flipping the last bit of the last row we wrote.
+    // We typically write an EOF marker (-1) so we are skipping that.
+    val byteOffsetToCorrupt = currentData.length - java.lang.Integer.BYTES - 1
+    assert(byteOffsetToCorrupt >= 0,
+      s"File $filePath is too small (${currentData.length} bytes) to contain a row")
+    currentData(byteOffsetToCorrupt) = (currentData(byteOffsetToCorrupt) ^ 0x01).toByte
+
+    // Now delete the current file and write a new file with corrupt row
+    fm.delete(filePath)
+    val output = hdfsProvider.compressStream(
+      fm.createAtomic(filePath, overwriteIfPossible = true))
+    output.write(currentData)
+    output.close()
+  }
+
+  // test both snapshot and delta file
+  override protected def corruptRowInFileTestMode: Seq[Boolean] = Seq(true, false)
+
+  /** The read verifier is held privately by the store's `mapToUpdate`. */
+  override protected def getReadVerifier(
+      provider: StateStoreProvider, store: StateStore): Option[KeyValueIntegrityVerifier] = {
+    val readVerifierField =
+      PrivateMethod[Option[KeyValueIntegrityVerifier]](Symbol("readVerifier"))
+    storeMapToUpdate(provider, store) invokePrivate readVerifierField()
+  }
+
+  test("Snapshot upload should fail if corrupt row is detected") {
+    // We want to generate snapshot per change file
+    withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "0") {
+      withStateStoreProvider(createInitializedProvider()) { provider =>
+        val store = provider.getStore(0L)
+        put(store, "1", 11, 100)
+        put(store, "2", 22, 200)
+        put(store, "3", 33, 300)
+
+        // Corrupt the local row. This won't affect the changelog file, since the row
+        // has already been written to it.
+        corruptRow(provider, store, dataToKeyRow("1", 11))
+        store.commit() // writes change file
+
+        checkChecksumError(intercept[SparkException] {
+          // Should throw row checksum exception while trying to write snapshot.
+          // provider.doMaintenance() calls doSnapshot and then swallows the exception,
+          // hence calling doSnapshot directly to avoid that.
+          val doSnapshotMethod = PrivateMethod[Unit](Symbol("doSnapshot"))
+          provider invokePrivate doSnapshotMethod("maintenance", true)
+        })
+      }
+    }
+  }
+}
+
+/** Row checksum tests run against the RocksDB state store provider. */
+class RocksDBStateStoreRowChecksumSuite extends StateStoreRowChecksumSuite
+  with PrivateMethodTester {
+  import StateStoreTestsHelper._
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      // Also generate changelog files for RocksDB
+      .set(
+        "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+        true.toString)
+  }
+
+  override protected def createProvider: StateStoreProvider = new RocksDBStateStoreProvider
+
+  /** Flips the last bit of `key`'s encoded value directly in the underlying RocksDB. */
+  override protected def corruptRow(
+      provider: StateStoreProvider, store: StateStore, key: UnsafeRow): Unit = {
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val dbField = PrivateMethod[NativeRocksDB](Symbol("db"))
+    val db = rocksDBProvider.rocksDB invokePrivate dbField()
+
+    val readOptionsField = PrivateMethod[ReadOptions](Symbol("readOptions"))
+    val readOptions = rocksDBProvider.rocksDB invokePrivate readOptionsField()
+    val writeOptionsField = PrivateMethod[WriteOptions](Symbol("writeOptions"))
+    val writeOptions = rocksDBProvider.rocksDB invokePrivate writeOptionsField()
+
+    val encoder = new UnsafeRowDataEncoder(NoPrefixKeyStateEncoderSpec(keySchema), valueSchema)
+    val keyBytes = encoder.encodeKey(key)
+    val currentValue = db.get(readOptions, keyBytes)
+    // RocksDB's get returns null for an absent key; fail with context instead of an NPE.
+    assert(currentValue != null && currentValue.nonEmpty,
+      s"No value found in RocksDB for key to corrupt: $key")
+
+    // corrupt the current value by flipping the last bit
+    val lastIdx = currentValue.length - 1
+    currentValue(lastIdx) = (currentValue(lastIdx) ^ 0x01).toByte
+    db.put(writeOptions, keyBytes, currentValue)
+  }
+
+  /**
+   * Rewrites the changelog file for `version` with one bit flipped in the last row,
+   * just before the trailing 4-byte EOF marker. Snapshot (zip) files are not supported.
+   */
+  override protected def corruptRowInFile(
+      provider: StateStoreProvider, version: Long, isSnapshot: Boolean): Unit = {
+    assert(!isSnapshot, "Doesn't support corrupting a row in snapshot file")
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val rocksDBFileManager = rocksDBProvider.rocksDB.fileManager
+
+    val dfsChangelogFileMethod = PrivateMethod[Path](Symbol("dfsChangelogFile"))
+    val changelogFilePath = rocksDBFileManager invokePrivate dfsChangelogFileMethod(version, None)
+
+    val fileManagerMethod = PrivateMethod[CheckpointFileManager](Symbol("fm"))
+    val fm = rocksDBFileManager invokePrivate fileManagerMethod()
+
+    val codecMethod = PrivateMethod[CompressionCodec](Symbol("codec"))
+    val codec = rocksDBFileManager invokePrivate codecMethod()
+
+    // Read the decompressed file content; close the stream even if reading fails
+    val input = new DataInputStream(codec.compressedInputStream(fm.open(changelogFilePath)))
+    val currentData = try input.readAllBytes() finally input.close()
+
+    // Corrupt the current data by flipping the last bit of the last row we wrote.
+    // We typically write an EOF marker (-1) so we are skipping that.
+    val byteOffsetToCorrupt = currentData.length - java.lang.Integer.BYTES - 1
+    assert(byteOffsetToCorrupt >= 0,
+      s"File $changelogFilePath is too small (${currentData.length} bytes) to contain a row")
+    currentData(byteOffsetToCorrupt) = (currentData(byteOffsetToCorrupt) ^ 0x01).toByte
+
+    // Now delete the current file and write a new file with corrupt row
+    fm.delete(changelogFilePath)
+    val output = new DataOutputStream(
+      codec.compressedOutputStream(fm.createAtomic(changelogFilePath, overwriteIfPossible = true)))
+    output.write(currentData)
+    output.close()
+  }
+
+  // test only changelog file since zip file doesn't contain rows
+  override protected def corruptRowInFileTestMode: Seq[Boolean] = Seq(false)
+
+  /** The read verifier is held privately by the provider's RocksDB instance. */
+  override protected def getReadVerifier(
+      provider: StateStoreProvider, store: StateStore): Option[KeyValueIntegrityVerifier] = {
+    val rocksDBProvider = provider.asInstanceOf[RocksDBStateStoreProvider]
+    val readVerifierField = PrivateMethod[Option[KeyValueIntegrityVerifier]](
+      Symbol("readVerifier"))
+    rocksDBProvider.rocksDB invokePrivate readVerifierField()
+  }
+}
+
+/**
+ * Used to enable testing with and without row checksum enabled in other suites
+ * to make sure row checksum works well with other features.
+ *
+ * Every `test(...)` registered by a suite mixing this in becomes two tests, suffixed
+ * " (with row checksum)" and " (without row checksum)".
+ */
+trait AlsoTestWithStateStoreRowChecksum extends SQLTestUtils {
+  override protected def test(testName: String, testTags: Tag*)(testBody: => Any)
+      (implicit pos: Position): Unit = {
+    testWithRowChecksumEnabled(testName, testTags: _*)(testBody)
+    testWithRowChecksumDisabled(testName, testTags: _*)(testBody)
+  }
+
+  /** Registers `testBody` only with state store row checksum enabled. */
+  def testWithRowChecksumEnabled(testName: String, testTags: Tag*)
+      (testBody: => Any)(implicit pos: Position): Unit = {
+    registerWithRowChecksum(testName, rowChecksumEnabled = true, testTags)(testBody)
+  }
+
+  /** Registers `testBody` only with state store row checksum disabled. */
+  def testWithRowChecksumDisabled(testName: String, testTags: Tag*)
+      (testBody: => Any)(implicit pos: Position): Unit = {
+    registerWithRowChecksum(testName, rowChecksumEnabled = false, testTags)(testBody)
+  }
+
+  /**
+   * Registers a test whose name carries a suffix for the row checksum setting. The body
+   * runs with that SQL conf set. `pos` is forwarded so failures point at the caller's test.
+   */
+  private def registerWithRowChecksum(
+      testName: String,
+      rowChecksumEnabled: Boolean,
+      testTags: Seq[Tag])
+      (testBody: => Any)(implicit pos: Position): Unit = {
+    val suffix = if (rowChecksumEnabled) " (with row checksum)" else " (without row checksum)"
+    super.test(testName + suffix, testTags: _*) {
+      // in case tests have any code that needs to execute before every test
+      super.beforeEach()
+      try {
+        val confValue = rowChecksumEnabled.toString
+        withSQLConf(SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED.key -> confValue) {
+          testBody
+        }
+      } finally {
+        // in case tests have any code that needs to execute after every test;
+        // run it even when the body fails so per-test cleanup is never skipped
+        super.afterEach()
+      }
+    }
+  }
+
+  // The default implementation in SQLTestUtils times out the `withTempDir()` call
+  // after 10 seconds. We don't want that because it causes flakiness in tests.
+  override protected def waitForTasksToFinish(): Unit = {}
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 6bb64315e356..bc3500a4f047 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -48,6 +48,7 @@ import 
org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExe
 import 
org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef
 import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.tags.ExtendedSQLTest
 import org.apache.spark.unsafe.types.UTF8String
@@ -255,13 +256,19 @@ private object FakeStateStoreProviderWithMaintenanceError 
{
 
 @ExtendedSQLTest
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with AlsoTestWithStateStoreRowChecksum
+  with SharedSparkSession
   with BeforeAndAfter {
   import StateStoreTestsHelper._
   import StateStoreCoordinatorSuite._
 
+  override def beforeEach(): Unit = {}
+  override def afterEach(): Unit = {}
+
   before {
     StateStore.stop()
     require(!StateStore.isMaintenanceRunning)
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
   }
 
   after {
@@ -800,7 +807,7 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
   }
 
-  test("SPARK-51291: corrupted file handling") {
+  testWithRowChecksumDisabled("SPARK-51291: corrupted file handling") {
     tryWithProviderResource(newStoreProvider(opId = Random.nextInt(), 
partition = 0,
       minDeltasForSnapshot = 5)) { provider =>
 
@@ -1297,7 +1304,10 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
   }
 
-  test("expose metrics with custom metrics to StateStoreMetrics") {
+  // When row checksum is enabled, we store checksum in the map and JVM does 
some memory
+  // optimization that will cause the memory used to be significantly lower 
for the
+  // reloadedProvider compared to the initial provider. The test expects it to 
be higher.
+  testWithRowChecksumDisabled("expose metrics with custom metrics to 
StateStoreMetrics") {
     def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = {
       val metricPair = metrics.customMetrics.find(_._1.name == name)
       assert(metricPair.isDefined)
@@ -1499,6 +1509,8 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     sqlConf.setConf(SQLConf.STATE_STORE_COMPRESSION_CODEC, 
SQLConf.get.stateStoreCompressionCodec)
     sqlConf.setConf(
       SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED, 
SQLConf.get.checkpointFileChecksumEnabled)
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_ROW_CHECKSUM_ENABLED, 
SQLConf.get.stateStoreRowChecksumEnabled)
     sqlConf
   }
 
@@ -1578,6 +1590,12 @@ class StateStoreSuite extends 
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     filePath.createNewFile()
   }
+
+  override protected def testQuietly(name: String)(f: => Unit): Unit = {
+    // Use the implementation from StateStoreSuiteBase.
+    // There is another in SQLTestUtils. Doing this to avoid conflict error.
+    super[StateStoreSuiteBase].testQuietly(name)(f)
+  }
 }
 
 abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index fbe33ddc32db..06feee657715 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -428,7 +428,9 @@ class ValueStateSuite extends StateVariableSuiteBase {
  * types (ValueState, ListState, MapState) used in arbitrary stateful 
operators.
  */
 abstract class StateVariableSuiteBase extends SharedSparkSession
-  with BeforeAndAfter with AlsoTestWithEncodingTypes {
+  with BeforeAndAfter
+  with AlsoTestWithStateStoreRowChecksum
+  with AlsoTestWithEncodingTypes {
 
   before {
     StateStore.stop()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
index e9bfaf2fc56a..d8b997483582 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.streaming
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.execution.streaming.runtime.StreamExecution
+import 
org.apache.spark.sql.execution.streaming.state.AlsoTestWithStateStoreRowChecksum
 import org.apache.spark.util.ArrayImplicits._
 
-trait StateStoreMetricsTest extends StreamTest {
+trait StateStoreMetricsTest extends StreamTest
+  with AlsoTestWithStateStoreRowChecksum {
 
   private var lastCheckedRecentProgressIndex = -1
   private var lastQuery: StreamExecution = null


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to