This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 331d0bf30092 [SPARK-50017][SS] Support Avro encoding for TransformWithState operator
331d0bf30092 is described below

commit 331d0bf30092be62191476e4a679b403e1a369b9
Author: Eric Marnadi <[email protected]>
AuthorDate: Tue Nov 26 13:33:04 2024 +0900

    [SPARK-50017][SS] Support Avro encoding for TransformWithState operator
    
    ### What changes were proposed in this pull request?
    
    Currently, we use the internal byte representation (UnsafeRow) to store
    state for stateful streaming operators in the StateStore. This PR introduces
    Avro serialization and deserialization in the RocksDBStateEncoder so that
    state can instead be stored using Avro encoding. Avro encoding is currently
    enabled only for the TransformWithState operator. It is selected via the
    SQLConf `spark.sql.streaming.stateStore.encodingFormat` (values `unsaferow`,
    the default, or `avro`) and supports all functionality supported by TWS.
    
    ### Why are the changes needed?
    
    UnsafeRow is an inherently unstable format that makes no guarantees of
    backward compatibility. If the format changes between Spark releases, state
    already written to the StateStore could be misread or corrupted. Avro is a
    more stable format and inherently enables schema evolution.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Amended existing unit tests and added new ones.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48401 from ericm-db/avro.
    
    Lead-authored-by: Eric Marnadi <[email protected]>
    Co-authored-by: Eric Marnadi <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 +
 .../StateStoreColumnFamilySchemaUtils.scala        |  41 +-
 .../sql/execution/streaming/StreamExecution.scala  |   1 +
 .../streaming/state/RocksDBStateEncoder.scala      | 538 +++++++++++++++++++--
 .../state/RocksDBStateStoreProvider.scala          | 120 ++++-
 .../state/StateSchemaCompatibilityChecker.scala    |  25 +
 .../sql/execution/streaming/state/StateStore.scala |  15 +-
 .../execution/streaming/state/StateStoreConf.scala |   3 +
 .../StateDataSourceTransformWithStateSuite.scala   |   4 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |   1 +
 .../execution/streaming/state/RocksDBSuite.scala   |  13 +
 .../streaming/state/ValueStateSuite.scala          |   2 +-
 .../streaming/TransformWithListStateSuite.scala    |   5 +-
 .../sql/streaming/TransformWithMapStateSuite.scala |   5 +-
 .../sql/streaming/TransformWithStateSuite.scala    |  17 +-
 .../sql/streaming/TransformWithStateTTLTest.scala  |   5 +-
 .../TransformWithValueStateTTLSuite.scala          |   3 +-
 17 files changed, 744 insertions(+), 67 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ba0a37541e49..378eca09097f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2221,6 +2221,17 @@ object SQLConf {
       .intConf
       .createWithDefault(1)
 
+  val STREAMING_STATE_STORE_ENCODING_FORMAT =
+    buildConf("spark.sql.streaming.stateStore.encodingFormat")
+      .doc("The encoding format used for stateful operators to store 
information " +
+        "in the state store")
+      .version("4.0.0")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .checkValue(v => Set("unsaferow", "avro").contains(v),
+        "Valid values are 'unsaferow' and 'avro'")
+      .createWithDefault("unsaferow")
+
   val STATE_STORE_COMPRESSION_CODEC =
     buildConf("spark.sql.streaming.stateStore.compression.codec")
       .internal()
@@ -5596,6 +5607,8 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def stateStoreCheckpointFormatVersion: Int = 
getConf(STATE_STORE_CHECKPOINT_FORMAT_VERSION)
 
+  def stateStoreEncodingFormat: String = 
getConf(STREAMING_STATE_STORE_ENCODING_FORMAT)
+
   def checkpointRenamedFileCheck: Boolean = 
getConf(CHECKPOINT_RENAMEDFILE_CHECK_ENABLED)
 
   def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
index 7da8408f98b0..585298fa4c99 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
@@ -20,10 +20,49 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 
 object StateStoreColumnFamilySchemaUtils {
 
+  /**
+   * Avro uses zig-zag encoding for some fixed-length types, like Longs and 
Ints. For range scans
+   * we want to use big-endian encoding, so we need to convert the source 
schema to replace these
+   * types with BinaryType.
+   *
+   * @param schema The schema to convert
+   * @param ordinals If non-empty, only convert fields at these ordinals.
+   *                 If empty, convert all fields.
+   */
+  def convertForRangeScan(schema: StructType, ordinals: Seq[Int] = Seq.empty): 
StructType = {
+    val ordinalSet = ordinals.toSet
+
+    StructType(schema.fields.zipWithIndex.flatMap { case (field, idx) =>
+      if ((ordinals.isEmpty || ordinalSet.contains(idx)) && 
isFixedSize(field.dataType)) {
+        // For each numeric field, create two fields:
+        // 1. Byte marker for null, positive, or negative values
+        // 2. The original numeric value in big-endian format
+        // Byte type is converted to Int in Avro, which doesn't work for us as 
Avro
+        // uses zig-zag encoding as opposed to big-endian for Ints
+        Seq(
+          StructField(s"${field.name}_marker", BinaryType, nullable = false),
+          field.copy(name = s"${field.name}_value", BinaryType)
+        )
+      } else {
+        Seq(field)
+      }
+    })
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: 
LongType |
+         _: FloatType | _: DoubleType => true
+    case _ => false
+  }
+
+  def getTtlColFamilyName(stateName: String): String = {
+    "$ttl_" + stateName
+  }
+
   def getValueStateSchema[T](
       stateName: String,
       keyEncoder: ExpressionEncoder[Any],
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bd501c935723..44202bb0d294 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -715,6 +715,7 @@ abstract class StreamExecution(
 
 object StreamExecution {
   val QUERY_ID_KEY = "sql.streaming.queryId"
+  val RUN_ID_KEY = "sql.streaming.runId"
   val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing"
   val IO_EXCEPTION_NAMES = Seq(
     classOf[InterruptedException].getName,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 4c7a226e0973..f39022c1f53a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -17,13 +17,21 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.io.ByteArrayOutputStream
 import java.lang.Double.{doubleToRawLongBits, longBitsToDouble}
 import java.lang.Float.{floatToRawIntBits, intBitsToFloat}
 import java.nio.{ByteBuffer, ByteOrder}
 
+import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericData, GenericDatumReader, 
GenericDatumWriter, GenericRecord}
+import org.apache.avro.io.{DecoderFactory, EncoderFactory}
+
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, 
SchemaConverters}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, 
UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import 
org.apache.spark.sql.execution.streaming.StateStoreColumnFamilySchemaUtils
 import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES,
 STATE_ENCODING_VERSION, VIRTUAL_COL_FAMILY_PREFIX_BYTES}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -49,6 +57,7 @@ abstract class RocksDBKeyStateEncoderBase(
   def offsetForColFamilyPrefix: Int =
     if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
 
+  val out = new ByteArrayOutputStream
   /**
    * Get Byte Array for the virtual column family id that is used as prefix for
    * key state rows.
@@ -89,23 +98,24 @@ abstract class RocksDBKeyStateEncoderBase(
   }
 }
 
-object RocksDBStateEncoder {
+object RocksDBStateEncoder extends Logging {
   def getKeyEncoder(
       keyStateEncoderSpec: KeyStateEncoderSpec,
       useColumnFamilies: Boolean,
-      virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = {
+      virtualColFamilyId: Option[Short] = None,
+      avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
     // Return the key state encoder based on the requested type
     keyStateEncoderSpec match {
       case NoPrefixKeyStateEncoderSpec(keySchema) =>
-        new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, 
virtualColFamilyId)
+        new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, 
virtualColFamilyId, avroEnc)
 
       case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
         new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey,
-          useColumnFamilies, virtualColFamilyId)
+          useColumnFamilies, virtualColFamilyId, avroEnc)
 
       case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
         new RangeKeyScanStateEncoder(keySchema, orderingOrdinals,
-          useColumnFamilies, virtualColFamilyId)
+          useColumnFamilies, virtualColFamilyId, avroEnc)
 
       case _ =>
         throw new IllegalArgumentException(s"Unsupported key state encoder 
spec: " +
@@ -115,11 +125,12 @@ object RocksDBStateEncoder {
 
   def getValueEncoder(
       valueSchema: StructType,
-      useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = {
+      useMultipleValuesPerKey: Boolean,
+      avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
     if (useMultipleValuesPerKey) {
-      new MultiValuedStateEncoder(valueSchema)
+      new MultiValuedStateEncoder(valueSchema, avroEnc)
     } else {
-      new SingleValueStateEncoder(valueSchema)
+      new SingleValueStateEncoder(valueSchema, avroEnc)
     }
   }
 
@@ -145,6 +156,26 @@ object RocksDBStateEncoder {
     encodedBytes
   }
 
+  /**
+   * This method takes an UnsafeRow, and serializes to a byte array using Avro 
encoding.
+   */
+  def encodeUnsafeRowToAvro(
+     row: UnsafeRow,
+     avroSerializer: AvroSerializer,
+     valueAvroType: Schema,
+     out: ByteArrayOutputStream): Array[Byte] = {
+    // InternalRow -> Avro.GenericDataRecord
+    val avroData =
+      avroSerializer.serialize(row)
+    out.reset()
+    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
+    val writer = new GenericDatumWriter[Any](
+      valueAvroType) // Defining Avro writer for this struct type
+    writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array
+    encoder.flush()
+    out.toByteArray
+  }
+
   def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
     if (bytes != null) {
       val row = new UnsafeRow(numFields)
@@ -154,6 +185,26 @@ object RocksDBStateEncoder {
     }
   }
 
+  /**
+   * This method takes a byte array written using Avro encoding, and
+   * deserializes to an UnsafeRow using the Avro deserializer
+   */
+  def decodeFromAvroToUnsafeRow(
+      valueBytes: Array[Byte],
+      avroDeserializer: AvroDeserializer,
+      valueAvroType: Schema,
+      valueProj: UnsafeProjection): UnsafeRow = {
+    val reader = new GenericDatumReader[Any](valueAvroType)
+    val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, 
valueBytes.length, null)
+    // bytes -> Avro.GenericDataRecord
+    val genericData = reader.read(null, decoder)
+    // Avro.GenericDataRecord -> InternalRow
+    val internalRow = avroDeserializer.deserialize(
+      genericData).orNull.asInstanceOf[InternalRow]
+    // InternalRow -> UnsafeRow
+    valueProj.apply(internalRow)
+  }
+
   def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = 
{
     if (bytes != null) {
       // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
@@ -174,16 +225,20 @@ object RocksDBStateEncoder {
  * @param keySchema - schema of the key to be encoded
  * @param numColsPrefixKey - number of columns to be used for prefix key
  * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this 
encoder will
+ *                be defined
  */
 class PrefixKeyScanStateEncoder(
     keySchema: StructType,
     numColsPrefixKey: Int,
     useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None)
+    virtualColFamilyId: Option[Short] = None,
+    avroEnc: Option[AvroEncoder] = None)
   extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
 
   import RocksDBStateEncoder._
 
+  private val usingAvroEncoding = avroEnc.isDefined
   private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
     keySchema.zipWithIndex.take(numColsPrefixKey)
   }
@@ -203,6 +258,18 @@ class PrefixKeyScanStateEncoder(
     UnsafeProjection.create(refs)
   }
 
+  // Prefix Key schema and projection definitions used by the Avro Serializers
+  // and Deserializers
+  private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
+  private lazy val prefixKeyAvroType = 
SchemaConverters.toAvroType(prefixKeySchema)
+  private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+
+  // Remaining Key schema and projection definitions used by the Avro 
Serializers
+  // and Deserializers
+  private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
+  private lazy val remainingKeyAvroType = 
SchemaConverters.toAvroType(remainingKeySchema)
+  private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+
   // This is quite simple to do - just bind sequentially, as we don't change 
the order.
   private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
 
@@ -210,9 +277,24 @@ class PrefixKeyScanStateEncoder(
   private val joinedRowOnKey = new JoinedRow()
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
-    val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
-    val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
-
+    val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) {
+      (
+        encodeUnsafeRowToAvro(
+          extractPrefixKey(row),
+          avroEnc.get.keySerializer,
+          prefixKeyAvroType,
+          out
+        ),
+        encodeUnsafeRowToAvro(
+          remainingKeyProjection(row),
+          avroEnc.get.suffixKeySerializer.get,
+          remainingKeyAvroType,
+          out
+        )
+      )
+    } else {
+      (encodeUnsafeRow(extractPrefixKey(row)), 
encodeUnsafeRow(remainingKeyProjection(row)))
+    }
     val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
       prefixKeyEncoded.length + remainingEncoded.length + 4
     )
@@ -243,9 +325,25 @@ class PrefixKeyScanStateEncoder(
     Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + 
prefixKeyEncodedLen,
       remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen)
 
-    val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = 
numColsPrefixKey)
-    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
-      numFields = keySchema.length - numColsPrefixKey)
+    val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) {
+      (
+        decodeFromAvroToUnsafeRow(
+          prefixKeyEncoded,
+          avroEnc.get.keyDeserializer,
+          prefixKeyAvroType,
+          prefixKeyProj
+        ),
+        decodeFromAvroToUnsafeRow(
+          remainingKeyEncoded,
+          avroEnc.get.suffixKeyDeserializer.get,
+          remainingKeyAvroType,
+          remainingKeyProj
+        )
+      )
+    } else {
+      (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey),
+        decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length - 
numColsPrefixKey))
+    }
 
     
restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
   }
@@ -255,7 +353,11 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
-    val prefixKeyEncoded = encodeUnsafeRow(prefixKey)
+    val prefixKeyEncoded = if (usingAvroEncoding) {
+      encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, 
prefixKeyAvroType, out)
+    } else {
+      encodeUnsafeRow(prefixKey)
+    }
     val (prefix, startingOffset) = encodeColumnFamilyPrefix(
       prefixKeyEncoded.length + 4
     )
@@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder(
  * @param keySchema - schema of the key to be encoded
  * @param orderingOrdinals - the ordinals for which the range scan is 
constructed
  * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this 
encoder will
+ *                be defined
  */
 class RangeKeyScanStateEncoder(
     keySchema: StructType,
     orderingOrdinals: Seq[Int],
     useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+    virtualColFamilyId: Option[Short] = None,
+    avroEnc: Option[AvroEncoder] = None)
+  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) 
with Logging {
 
   import RocksDBStateEncoder._
 
@@ -374,6 +479,22 @@ class RangeKeyScanStateEncoder(
     UnsafeProjection.create(refs)
   }
 
+  private val rangeScanAvroSchema = 
StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+    StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
+
+  private lazy val rangeScanAvroType = 
SchemaConverters.toAvroType(rangeScanAvroSchema)
+
+  private val rangeScanAvroProjection = 
UnsafeProjection.create(rangeScanAvroSchema)
+
+  // Existing remainder key schema stuff
+  private val remainingKeySchema = StructType(
+    0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_))
+  )
+
+  private lazy val remainingKeyAvroType = 
SchemaConverters.toAvroType(remainingKeySchema)
+
+  private val remainingKeyAvroProjection = 
UnsafeProjection.create(remainingKeySchema)
+
   // Reusable objects
   private val joinedRowOnKey = new JoinedRow()
 
@@ -563,13 +684,272 @@ class RangeKeyScanStateEncoder(
     writer.getRow()
   }
 
+  /**
+   * Encodes an UnsafeRow into an Avro-compatible byte array format for range 
scan operations.
+   *
+   * This method transforms row data into a binary format that preserves 
ordering when
+   * used in range scans.
+   * For each field in the row:
+   * - A marker byte is written to indicate null status or sign (for numeric 
types)
+   * - The value is written in big-endian format
+   *
+   * Special handling is implemented for:
+   * - Null values: marked with nullValMarker followed by zero bytes
+   * - Negative numbers: marked with negativeValMarker
+   * - Floating point numbers: bit manipulation to handle sign and NaN values 
correctly
+   *
+   * @param row The UnsafeRow to encode
+   * @param avroType The Avro schema defining the structure for encoding
+   * @return Array[Byte] containing the Avro-encoded data that preserves 
ordering for range scans
+   * @throws UnsupportedOperationException if a field's data type is not 
supported for range
+   *                                       scan encoding
+   */
+  def encodePrefixKeyForRangeScan(
+      row: UnsafeRow,
+      avroType: Schema): Array[Byte] = {
+    val record = new GenericData.Record(avroType)
+    var fieldIdx = 0
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case 
(fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+      val value = row.get(idx, field.dataType)
+
+      // Create marker byte buffer
+      val markerBuffer = ByteBuffer.allocate(1)
+      markerBuffer.order(ByteOrder.BIG_ENDIAN)
+
+      if (value == null) {
+        markerBuffer.put(nullValMarker)
+        record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+        record.put(fieldIdx + 1, ByteBuffer.wrap(new 
Array[Byte](field.dataType.defaultSize)))
+      } else {
+        field.dataType match {
+          case BooleanType =>
+            markerBuffer.put(positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 
0.toByte)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ByteType =>
+            val byteVal = value.asInstanceOf[Byte]
+            markerBuffer.put(if (byteVal < 0) negativeValMarker else 
positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.put(byteVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ShortType =>
+            val shortVal = value.asInstanceOf[Short]
+            markerBuffer.put(if (shortVal < 0) negativeValMarker else 
positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(2)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putShort(shortVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case IntegerType =>
+            val intVal = value.asInstanceOf[Int]
+            markerBuffer.put(if (intVal < 0) negativeValMarker else 
positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putInt(intVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case LongType =>
+            val longVal = value.asInstanceOf[Long]
+            markerBuffer.put(if (longVal < 0) negativeValMarker else 
positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putLong(longVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case FloatType =>
+            val floatVal = value.asInstanceOf[Float]
+            val rawBits = floatToRawIntBits(floatVal)
+            markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & floatSignBitMask) != 0) {
+              val updatedVal = rawBits ^ floatFlipBitMask
+              valueBuffer.putFloat(intBitsToFloat(updatedVal))
+            } else {
+              valueBuffer.putFloat(floatVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case DoubleType =>
+            val doubleVal = value.asInstanceOf[Double]
+            val rawBits = doubleToRawLongBits(doubleVal)
+            markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & doubleSignBitMask) != 0) {
+              val updatedVal = rawBits ^ doubleFlipBitMask
+              valueBuffer.putDouble(longBitsToDouble(updatedVal))
+            } else {
+              valueBuffer.putDouble(doubleVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case _ => throw new UnsupportedOperationException(
+            s"Range scan encoding not supported for data type: 
${field.dataType}")
+        }
+      }
+      fieldIdx += 2
+    }
+
+    out.reset()
+    val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+    val encoder = EncoderFactory.get().binaryEncoder(out, null)
+    writer.write(record, encoder)
+    encoder.flush()
+    out.toByteArray
+  }
+
+  /**
+   * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan 
operations.
+   *
+   * This method reverses the encoding process performed by 
encodePrefixKeyForRangeScan:
+   * - Reads the marker byte to determine null status or sign
+   * - Reconstructs the original values from big-endian format
+   * - Handles special cases for floating point numbers by reversing bit 
manipulations
+   *
+   * The decoding process preserves the original data types and values, 
including:
+   * - Null values marked by nullValMarker
+   * - Sign information for numeric types
+   * - Proper restoration of negative floating point values
+   *
+   * @param bytes The Avro-encoded byte array to decode
+   * @param avroType The Avro schema defining the structure for decoding
+   * @return UnsafeRow containing the decoded data
+   * @throws UnsupportedOperationException if a field's data type is not 
supported for range
+   *                                       scan decoding
+   */
+  def decodePrefixKeyForRangeScan(
+      bytes: Array[Byte],
+      avroType: Schema): UnsafeRow = {
+
+    val reader = new GenericDatumReader[GenericRecord](avroType)
+    val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, 
null)
+    val record = reader.read(null, decoder)
+
+    val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+    rowWriter.resetRowWriter()
+
+    var fieldIdx = 0
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case 
(fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+
+      val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+      val markerBuf = ByteBuffer.wrap(markerBytes)
+      markerBuf.order(ByteOrder.BIG_ENDIAN)
+      val marker = markerBuf.get()
+
+      if (marker == nullValMarker) {
+        rowWriter.setNullAt(idx)
+      } else {
+        field.dataType match {
+          case BooleanType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            rowWriter.write(idx, bytes(0) == 1)
+
+          case ByteType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.get())
+
+          case ShortType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getShort())
+
+          case IntegerType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getInt())
+
+          case LongType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getLong())
+
+          case FloatType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            if (marker == negativeValMarker) {
+              val floatVal = valueBuf.getFloat
+              val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+              rowWriter.write(idx, intBitsToFloat(updatedVal))
+            } else {
+              rowWriter.write(idx, valueBuf.getFloat())
+            }
+
+          case DoubleType =>
+            val bytes = record.get(fieldIdx + 
1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            if (marker == negativeValMarker) {
+              val doubleVal = valueBuf.getDouble
+              val updatedVal = doubleToRawLongBits(doubleVal) ^ 
doubleFlipBitMask
+              rowWriter.write(idx, longBitsToDouble(updatedVal))
+            } else {
+              rowWriter.write(idx, valueBuf.getDouble())
+            }
+
+          case _ => throw new UnsupportedOperationException(
+            s"Range scan decoding not supported for data type: 
${field.dataType}")
+        }
+      }
+      fieldIdx += 2
+    }
+
+    rowWriter.getRow()
+  }
+
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
     // This prefix key has the columns specified by orderingOrdinals
     val prefixKey = extractPrefixKey(row)
-    val rangeScanKeyEncoded = 
encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
+    val rangeScanKeyEncoded = if (avroEnc.isDefined) {
+      encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType)
+    } else {
+      encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
+    }
 
     val result = if (orderingOrdinals.length < keySchema.length) {
-      val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+      val remainingEncoded = if (avroEnc.isDefined) {
+        encodeUnsafeRowToAvro(
+          remainingKeyProjection(row),
+          avroEnc.get.keySerializer,
+          remainingKeyAvroType,
+          out
+        )
+      } else {
+        encodeUnsafeRow(remainingKeyProjection(row))
+      }
       val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
         rangeScanKeyEncoded.length + remainingEncoded.length + 4
       )
@@ -606,9 +986,12 @@ class RangeKeyScanStateEncoder(
     Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
       prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
 
-    val prefixKeyDecodedForRangeScan = decodeToUnsafeRow(prefixKeyEncoded,
-      numFields = orderingOrdinals.length)
-    val prefixKeyDecoded = 
decodePrefixKeyForRangeScan(prefixKeyDecodedForRangeScan)
+    val prefixKeyDecoded = if (avroEnc.isDefined) {
+      decodePrefixKeyForRangeScan(prefixKeyEncoded, rangeScanAvroType)
+    } else {
+      decodePrefixKeyForRangeScan(decodeToUnsafeRow(prefixKeyEncoded,
+        numFields = orderingOrdinals.length))
+    }
 
     if (orderingOrdinals.length < keySchema.length) {
       // Here we calculate the remainingKeyEncodedLen leveraging the length of 
keyBytes
@@ -620,8 +1003,14 @@ class RangeKeyScanStateEncoder(
         remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
         remainingKeyEncodedLen)
 
-      val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
-        numFields = keySchema.length - orderingOrdinals.length)
+      val remainingKeyDecoded = if (avroEnc.isDefined) {
+        decodeFromAvroToUnsafeRow(remainingKeyEncoded,
+          avroEnc.get.keyDeserializer,
+          remainingKeyAvroType, remainingKeyAvroProjection)
+      } else {
+        decodeToUnsafeRow(remainingKeyEncoded,
+          numFields = keySchema.length - orderingOrdinals.length)
+      }
 
       val joined = 
joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)
       val restored = restoreKeyProjection(joined)
@@ -634,7 +1023,11 @@ class RangeKeyScanStateEncoder(
   }
 
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
-    val rangeScanKeyEncoded = 
encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
+    val rangeScanKeyEncoded = if (avroEnc.isDefined) {
+      encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType)
+    } else {
+      encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
+    }
     val (prefix, startingOffset) = 
encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4)
 
     Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length)
@@ -653,6 +1046,7 @@ class RangeKeyScanStateEncoder(
  * It uses the first byte of the generated byte array to store the version the 
describes how the
  * row is encoded in the rest of the byte array. Currently, the default 
version is 0,
  *
+ * If the avroEnc is specified, we are using Avro encoding for this column 
family's keys
  * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
  *    The bytes of a UnsafeRow is written unmodified to starting from offset 1
  *    (offset 0 is the version byte of value 0). That is, if the unsafe row 
has N bytes,
@@ -661,19 +1055,27 @@ class RangeKeyScanStateEncoder(
 class NoPrefixKeyStateEncoder(
     keySchema: StructType,
     useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+    virtualColFamilyId: Option[Short] = None,
+    avroEnc: Option[AvroEncoder] = None)
+  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) 
with Logging {
 
   import RocksDBStateEncoder._
 
   // Reusable objects
+  private val usingAvroEncoding = avroEnc.isDefined
   private val keyRow = new UnsafeRow(keySchema.size)
+  private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema)
+  private val keyProj = UnsafeProjection.create(keySchema)
 
   override def encodeKey(row: UnsafeRow): Array[Byte] = {
     if (!useColumnFamilies) {
       encodeUnsafeRow(row)
     } else {
-      val bytesToEncode = row.getBytes
+      // If avroEnc is defined, we know that we need to use Avro to
+      // encode this UnsafeRow to Avro bytes
+      val bytesToEncode = if (usingAvroEncoding) {
+        encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out)
+      } else row.getBytes
       val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
         bytesToEncode.length +
           STATE_ENCODING_NUM_VERSION_BYTES
@@ -697,11 +1099,21 @@ class NoPrefixKeyStateEncoder(
     if (useColumnFamilies) {
       if (keyBytes != null) {
         // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
-        keyRow.pointTo(
-          keyBytes,
-          decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
-          keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - 
VIRTUAL_COL_FAMILY_PREFIX_BYTES)
-        keyRow
+        if (usingAvroEncoding) {
+          val dataLength = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES -
+            VIRTUAL_COL_FAMILY_PREFIX_BYTES
+          val avroBytes = new Array[Byte](dataLength)
+          Platform.copyMemory(
+            keyBytes, decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
+            avroBytes, Platform.BYTE_ARRAY_OFFSET, dataLength)
+          decodeFromAvroToUnsafeRow(avroBytes, avroEnc.get.keyDeserializer, 
keyAvroType, keyProj)
+        } else {
+          keyRow.pointTo(
+            keyBytes,
+            decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
+            keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES - 
VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+          keyRow
+        }
       } else {
         null
       }
@@ -727,17 +1139,28 @@ class NoPrefixKeyStateEncoder(
  * This encoder supports RocksDB StringAppendOperator merge operator. Values 
encoded can be
  * merged in RocksDB using merge operation, and all merged values can be read 
using decodeValues
  * operation.
+ * If the avroEnc is specified, we are using Avro encoding for this column 
family's values
  */
-class MultiValuedStateEncoder(valueSchema: StructType)
+class MultiValuedStateEncoder(
+    valueSchema: StructType,
+    avroEnc: Option[AvroEncoder] = None)
   extends RocksDBValueStateEncoder with Logging {
 
   import RocksDBStateEncoder._
 
+  private val usingAvroEncoding = avroEnc.isDefined
   // Reusable objects
+  private val out = new ByteArrayOutputStream
   private val valueRow = new UnsafeRow(valueSchema.size)
+  private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema)
+  private val valueProj = UnsafeProjection.create(valueSchema)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = {
-    val bytes = encodeUnsafeRow(row)
+    val bytes = if (usingAvroEncoding) {
+      encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, 
out)
+    } else {
+      encodeUnsafeRow(row)
+    }
     val numBytes = bytes.length
 
     val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length)
@@ -756,7 +1179,12 @@ class MultiValuedStateEncoder(valueSchema: StructType)
       val encodedValue = new Array[Byte](numBytes)
       Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + 
Platform.BYTE_ARRAY_OFFSET,
         encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes)
-      decodeToUnsafeRow(encodedValue, valueRow)
+      if (usingAvroEncoding) {
+        decodeFromAvroToUnsafeRow(
+          encodedValue, avroEnc.get.valueDeserializer, valueAvroType, 
valueProj)
+      } else {
+        decodeToUnsafeRow(encodedValue, valueRow)
+      }
     }
   }
 
@@ -782,7 +1210,12 @@ class MultiValuedStateEncoder(valueSchema: StructType)
 
           pos += numBytes
           pos += 1 // eat the delimiter character
-          decodeToUnsafeRow(encodedValue, valueRow)
+          if (usingAvroEncoding) {
+            decodeFromAvroToUnsafeRow(
+              encodedValue, avroEnc.get.valueDeserializer, valueAvroType, 
valueProj)
+          } else {
+            decodeToUnsafeRow(encodedValue, valueRow)
+          }
         }
       }
     }
@@ -802,16 +1235,29 @@ class MultiValuedStateEncoder(valueSchema: StructType)
  *    The bytes of a UnsafeRow is written unmodified to starting from offset 1
  *    (offset 0 is the version byte of value 0). That is, if the unsafe row 
has N bytes,
  *    then the generated array byte will be N+1 bytes.
+ * If the avroEnc is specified, we are using Avro encoding for this column 
family's values
  */
-class SingleValueStateEncoder(valueSchema: StructType)
-  extends RocksDBValueStateEncoder {
+class SingleValueStateEncoder(
+    valueSchema: StructType,
+    avroEnc: Option[AvroEncoder] = None)
+  extends RocksDBValueStateEncoder with Logging {
 
   import RocksDBStateEncoder._
 
+  private val usingAvroEncoding = avroEnc.isDefined
   // Reusable objects
+  private val out = new ByteArrayOutputStream
   private val valueRow = new UnsafeRow(valueSchema.size)
+  private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema)
+  private val valueProj = UnsafeProjection.create(valueSchema)
 
-  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+  override def encodeValue(row: UnsafeRow): Array[Byte] = {
+    if (usingAvroEncoding) {
+      encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType, 
out)
+    } else {
+      encodeUnsafeRow(row)
+    }
+  }
 
   /**
    * Decode byte array for a value to a UnsafeRow.
@@ -820,7 +1266,15 @@ class SingleValueStateEncoder(valueSchema: StructType)
    *       the given byte array.
    */
   override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-    decodeToUnsafeRow(valueBytes, valueRow)
+    if (valueBytes == null) {
+      return null
+    }
+    if (usingAvroEncoding) {
+      decodeFromAvroToUnsafeRow(
+        valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj)
+    } else {
+      decodeToUnsafeRow(valueBytes, valueRow)
+    }
   }
 
   override def supportsMultipleValuesPerKey: Boolean = false
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 1fc6ab5910c6..e5a4175aeec1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -18,10 +18,12 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.io._
-import java.util.concurrent.ConcurrentHashMap
+import java.util.UUID
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 
 import scala.util.control.NonFatal
 
+import com.google.common.cache.CacheBuilder
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
@@ -29,11 +31,12 @@ import org.apache.spark.{SparkConf, SparkEnv, 
SparkException}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, 
AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.streaming.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, 
StreamExecution}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{NonFateSharingCache, Utils}
 
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable
@@ -74,10 +77,17 @@ private[sql] class RocksDBStateStoreProvider
         isInternal: Boolean = false): Unit = {
       verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, 
isInternal)
       val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
+      // Cache key combines run ID, operator ID, partition ID and column family name to avoid collisions
+      val avroEncCacheKey = 
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
+        s"${stateStoreId.partitionId}_$colFamilyName"
+
+      val avroEnc = getAvroEnc(
+        stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+
       keyValueEncoderMap.putIfAbsent(colFamilyName,
         (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, 
useColumnFamilies,
-          Some(newColFamilyId)), 
RocksDBStateEncoder.getValueEncoder(valueSchema,
-          useMultipleValuesPerKey)))
+          Some(newColFamilyId), avroEnc), 
RocksDBStateEncoder.getValueEncoder(valueSchema,
+          useMultipleValuesPerKey, avroEnc)))
     }
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
@@ -364,6 +374,7 @@ private[sql] class RocksDBStateStoreProvider
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
     this.useColumnFamilies = useColumnFamilies
+    this.stateStoreEncoding = storeConf.stateStoreEncodingFormat
 
     if (useMultipleValuesPerKey) {
       require(useColumnFamilies, "Multiple values per key support requires 
column families to be" +
@@ -377,10 +388,17 @@ private[sql] class RocksDBStateStoreProvider
       defaultColFamilyId = 
Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
     }
 
+    val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME
+    // Cache key combines run ID, operator ID, partition ID and column family name to avoid collisions
+    val avroEncCacheKey = 
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
+      s"${stateStoreId.partitionId}_$colFamilyName"
+    val avroEnc = getAvroEnc(
+      stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+
     keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
       (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
-        useColumnFamilies, defaultColFamilyId),
-        RocksDBStateEncoder.getValueEncoder(valueSchema, 
useMultipleValuesPerKey)))
+        useColumnFamilies, defaultColFamilyId, avroEnc),
+        RocksDBStateEncoder.getValueEncoder(valueSchema, 
useMultipleValuesPerKey, avroEnc)))
   }
 
   override def stateStoreId: StateStoreId = stateStoreId_
@@ -458,6 +476,7 @@ private[sql] class RocksDBStateStoreProvider
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var useColumnFamilies: Boolean = _
+  @volatile private var stateStoreEncoding: String = _
 
   private[sql] lazy val rocksDB = {
     val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -593,6 +612,93 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_VERSION: Byte = 0
   val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2
 
+  private val MAX_AVRO_ENCODERS_IN_CACHE = 1000
+  // The cache lives in the companion object so cached encoders persist across provider instances
+  private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] = {
+    val guavaCache = CacheBuilder.newBuilder()
+      .maximumSize(MAX_AVRO_ENCODERS_IN_CACHE)  // Adjust size based on your 
needs
+      .expireAfterAccess(1, TimeUnit.HOURS)  // Optional: Add expiration if 
needed
+      .build[String, AvroEncoder]()
+
+    new NonFateSharingCache(guavaCache)
+  }
+
+  def getAvroEnc(
+      stateStoreEncoding: String,
+      avroEncCacheKey: String,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      valueSchema: StructType): Option[AvroEncoder] = {
+
+    stateStoreEncoding match {
+      case "avro" => Some(
+        RocksDBStateStoreProvider.avroEncoderMap.get(
+          avroEncCacheKey,
+          new java.util.concurrent.Callable[AvroEncoder] {
+            override def call(): AvroEncoder = 
createAvroEnc(keyStateEncoderSpec, valueSchema)
+          }
+        )
+      )
+      case "unsaferow" => None
+    }
+  }
+
+  private def getRunId(hadoopConf: Configuration): String = {
+    val runId = hadoopConf.get(StreamExecution.RUN_ID_KEY)
+    if (runId != null) {
+      runId
+    } else {
+      assert(Utils.isTesting, "Failed to find query id/batch Id in task 
context")
+      UUID.randomUUID().toString
+    }
+  }
+
+  private def getAvroSerializer(schema: StructType): AvroSerializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    new AvroSerializer(schema, avroType, nullable = false)
+  }
+
+  private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    val avroOptions = AvroOptions(Map.empty)
+    new AvroDeserializer(avroType, schema,
+      avroOptions.datetimeRebaseModeInRead, 
avroOptions.useStableIdForUnionType,
+      avroOptions.stableIdPrefixForUnionType, 
avroOptions.recursiveFieldMaxDepth)
+  }
+
+  private def createAvroEnc(
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      valueSchema: StructType
+  ): AvroEncoder = {
+    val valueSerializer = getAvroSerializer(valueSchema)
+    val valueDeserializer = getAvroDeserializer(valueSchema)
+    val keySchema = keyStateEncoderSpec match {
+      case NoPrefixKeyStateEncoderSpec(schema) =>
+        schema
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        StructType(schema.take(numColsPrefixKey))
+      case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
+        val remainingSchema = {
+          0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
+            schema(ordinal)
+          }
+        }
+        StructType(remainingSchema)
+    }
+    val suffixKeySchema = keyStateEncoderSpec match {
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        Some(StructType(schema.drop(numColsPrefixKey)))
+      case _ => None
+    }
+    AvroEncoder(
+      getAvroSerializer(keySchema),
+      getAvroDeserializer(keySchema),
+      valueSerializer,
+      valueDeserializer,
+      suffixKeySchema.map(getAvroSerializer),
+      suffixKeySchema.map(getAvroDeserializer)
+    )
+  }
+
   // Native operation latencies report as latency in microseconds
   // as SQLMetrics support millis. Convert the value to millis
   val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 721d72b6a099..48b15ac04f40 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer}
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, 
StatefulOperatorStateInfo}
 import 
org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, 
SchemaWriter}
@@ -37,6 +38,30 @@ case class StateSchemaValidationResult(
     schemaPath: String
 )
 
+/**
+ * An Avro-based encoder used for serializing between UnsafeRow and Avro
+ *  byte arrays in RocksDB state stores.
+ *
+ * This encoder is primarily utilized by [[RocksDBStateStoreProvider]] and 
[[RocksDBStateEncoder]]
+ * to handle serialization and deserialization of state store data.
+ *
+ * @param keySerializer Serializer for converting state store keys to Avro 
format
+ * @param keyDeserializer Deserializer for converting Avro-encoded keys back 
to UnsafeRow
+ * @param valueSerializer Serializer for converting state store values to Avro 
format
+ * @param valueDeserializer Deserializer for converting Avro-encoded values 
back to UnsafeRow
+ * @param suffixKeySerializer Optional serializer for handling suffix keys in 
Avro format
+ * @param suffixKeyDeserializer Optional deserializer for converting 
Avro-encoded suffix
+ *                              keys back to UnsafeRow
+ */
+case class AvroEncoder(
+  keySerializer: AvroSerializer,
+  keyDeserializer: AvroDeserializer,
+  valueSerializer: AvroSerializer,
+  valueDeserializer: AvroDeserializer,
+  suffixKeySerializer: Option[AvroSerializer] = None,
+  suffixKeyDeserializer: Option[AvroDeserializer] = None
+) extends Serializable
+
 // Used to represent the schema of a column family in the state store
 case class StateStoreColFamilySchema(
     colFamilyName: String,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 72bc3ca33054..e2b93c147891 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -37,10 +37,22 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, 
StreamExecution}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
+sealed trait StateStoreEncoding {
+  override def toString: String = this match {
+    case StateStoreEncoding.UnsafeRow => "unsaferow"
+    case StateStoreEncoding.Avro => "avro"
+  }
+}
+
+object StateStoreEncoding {
+  case object UnsafeRow extends StateStoreEncoding
+  case object Avro extends StateStoreEncoding
+}
+
 /**
  * Base trait for a versioned key-value store which provides read operations. 
Each instance of a
  * `ReadStateStore` represents a specific version of state data, and such 
instances are created
@@ -769,6 +781,7 @@ object StateStore extends Logging {
     if (version < 0) {
       throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
     }
+    hadoopConf.set(StreamExecution.RUN_ID_KEY, 
storeProviderId.queryRunId.toString)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, 
valueSchema,
       keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, 
useMultipleValuesPerKey)
     storeProvider.getStore(version, stateStoreCkptId)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index c8af395e996d..9d26bf8fdf2e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -83,6 +83,9 @@ class StateStoreConf(
   /** The interval of maintenance tasks. */
   val maintenanceInterval = sqlConf.streamingMaintenanceInterval
 
+  /** The encoding format used for state store keys and values ("unsaferow" or "avro"). */
+  val stateStoreEncodingFormat = sqlConf.stateStoreEncodingFormat
+
   /**
    * When creating new state store checkpoint, which format version to use.
    */
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index baab6327b35c..af64f563cf7b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 AlsoTestWithEncodingTypes, RocksDBFileManager, RocksDBStateStoreProvider, 
TestClass}
 import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, 
MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, 
OutputMode, RunningCountStatefulProcessor, 
RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, 
StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, 
TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
@@ -126,7 +126,7 @@ class SessionGroupsStatefulProcessorWithTTL extends
  * Test suite to verify integration of state data source reader with the 
transformWithState operator
  */
 class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled with 
AlsoTestWithEncodingTypes {
 
   import testImplicits._
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index e1bd9dd38066..0abdcadefbd5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -43,6 +43,7 @@ import org.apache.spark.util.Utils
 @ExtendedSQLTest
 class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvider]
   with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes
   with SharedSparkSession
   with BeforeAndAfter {
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 637eb4913030..61ca8e7c32f6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -86,6 +86,19 @@ trait RocksDBStateStoreChangelogCheckpointingTestUtil {
   }
 }
 
+trait AlsoTestWithEncodingTypes extends SQLTestUtils {
+  override protected def test(testName: String, testTags: Tag*)(testBody: => 
Any)
+                             (implicit pos: Position): Unit = {
+    Seq("unsaferow", "avro").foreach { encoding =>
+      super.test(s"$testName (encoding = $encoding)", testTags: _*) {
+        withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> 
encoding) {
+          testBody
+        }
+      }
+    }
+  }
+}
+
 trait AlsoTestWithChangelogCheckpointingEnabled
   extends SQLTestUtils with RocksDBStateStoreChangelogCheckpointingTestUtil {
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 55d08cd8f12a..8984d9b0845b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -423,7 +423,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
  * types (ValueState, ListState, MapState) used in arbitrary stateful 
operators.
  */
 abstract class StateVariableSuiteBase extends SharedSparkSession
-  with BeforeAndAfter {
+  with BeforeAndAfter with AlsoTestWithEncodingTypes {
 
   before {
     StateStore.stop()
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 88862e2ad079..5d88db0d01ba 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 
 case class InputRow(key: String, action: String, value: String)
@@ -127,7 +127,8 @@ class ToggleSaveAndEmitProcessor
 }
 
 class TransformWithListStateSuite extends StreamTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes {
   import testImplicits._
 
   test("test appending null value in list state throw exception") {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index 76c5cbeee424..ec6ff4fcceb6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 
 case class InputMapRow(key: String, action: String, value: (String, String))
@@ -81,7 +81,8 @@ class TestMapStateProcessor
  * operators such as transformWithState.
  */
 class TransformWithMapStateSuite extends StreamTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes  {
   import testImplicits._
 
   private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 505775d4f6a9..91a47645f417 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -429,11 +429,12 @@ class SleepingTimerProcessor extends 
StatefulProcessor[String, String, String] {
  * Class that adds tests for transformWithState stateful streaming operator
  */
 class TransformWithStateSuite extends StateStoreMetricsTest
-  with AlsoTestWithChangelogCheckpointingEnabled {
+  with AlsoTestWithChangelogCheckpointingEnabled with 
AlsoTestWithEncodingTypes {
 
   import testImplicits._
 
-  test("transformWithState - streaming with rocksdb and invalid processor 
should fail") {
+  test("transformWithState - streaming with rocksdb and" +
+    " invalid processor should fail") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -688,7 +689,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - streaming with rocksdb and event time based 
timer") {
+  test("transformWithState - streaming with rocksdb and event " +
+  "time based timer") {
     val inputData = MemoryStream[(String, Int)]
     val result =
       inputData.toDS()
@@ -778,7 +780,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     )
   }
 
-  test("Use statefulProcessor without transformWithState - handle should be absent") {
+  test("Use statefulProcessor without transformWithState -" +
+    " handle should be absent") {
     val processor = new RunningCountStatefulProcessor()
     val ex = intercept[Exception] {
       processor.getHandle
@@ -1034,7 +1037,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - verify StateSchemaV3 writes correct SQL schema of key/value") {
+  test("transformWithState - verify StateSchemaV3 writes " +
+    "correct SQL schema of key/value") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
@@ -1605,7 +1609,8 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  test("transformWithState - verify that schema file is kept after metadata is purged") {
+  test("transformWithState - verify that schema file " +
+    "is kept after metadata is purged") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
index 2ddf69aa49e0..75fda9630779 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala
@@ -21,7 +21,7 @@ import java.sql.Timestamp
 import java.time.Duration
 
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, AlsoTestWithEncodingTypes, RocksDBStateStoreProvider}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 
@@ -41,7 +41,8 @@ case class OutputEvent(
  * Test suite base for TransformWithState with TTL support.
  */
 abstract class TransformWithStateTTLTest
-  extends StreamTest {
+  extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled
+  with AlsoTestWithEncodingTypes {
   import testImplicits._
 
   def getProcessor(ttlConfig: TTLConfig): StatefulProcessor[String, InputEvent, OutputEvent]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 21c3beb79314..b19c126c7386 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -262,7 +262,8 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
     }
   }
 
-  test("verify StateSchemaV3 writes correct SQL schema of key/value and with TTL") {
+  test("verify StateSchemaV3 writes correct SQL " +
+    "schema of key/value and with TTL") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,
       SQLConf.SHUFFLE_PARTITIONS.key ->


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to