anishshri-db commented on code in PR #48148:
URL: https://github.com/apache/spark/pull/48148#discussion_r1775947838


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -676,27 +681,83 @@ class RocksDBStateStoreChangeDataReader(
     endVersion: Long,
     compressionCodec: CompressionCodec,
     keyValueEncoderMap:
-      ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder)])
+      ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder)],
+    colFamilyNameOpt: Option[String] = None)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec) {
+    fm, stateLocation, startVersion, endVersion, compressionCodec, 
colFamilyNameOpt) {
 
   override protected var changelogSuffix: String = "changelog"
 
-  override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
-    val reader = currentChangelogReader()
-    if (reader == null) {
-      return null
+  private def getColFamilyIdBytes: Option[Array[Byte]] = {
+    if (colFamilyNameOpt.isDefined) {
+      val colFamilyName = colFamilyNameOpt.get
+      if (!keyValueEncoderMap.containsKey(colFamilyName)) {
+        throw new IllegalStateException(
+          s"Column family $colFamilyName not found in the key value encoder 
map")
+      }
+      Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes())
+    } else {
+      None
     }
-    val (recordType, keyArray, valueArray) = reader.next()
-    // Todo: does not support multiple virtual column families
-    val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) =
-      keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME)
-    val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray)
-    if (valueArray == null) {
-      (recordType, keyRow, null, currentChangelogVersion - 1)
+  }
+
+  private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes
+
+  override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
+    if (colFamilyIdBytesOpt.isDefined) {
+      // If we are reading records for a particular column family, the 
corresponding vcf id
+      // will be encoded in the key byte array. We need to extract that and 
compare for the
+      // expected column family id. If it matches, we return the record. If 
not, we move to
+      // the next record. Note that this has to be handled across multiple 
changelog files and we
+      // rely on the currentChangelogReader to move to the next changelog file 
when needed.
+      var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null
+      var currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) = 
null
+
+      breakable {
+        while (true) {
+          val reader = currentChangelogReader()
+          if (reader == null) {
+            return null
+          }
+
+          currRecord = reader.next()
+          currEncoder = keyValueEncoderMap.get(colFamilyNameOpt
+            .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME))
+
+          val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get
+
+          val keyArrayWithColFamilyId = new Array[Byte](colFamilyIdBytes.size)
+          Array.copy(currRecord._2, 0, keyArrayWithColFamilyId, 0, 
colFamilyIdBytes.size)
+
+          if (java.util.Arrays.equals(keyArrayWithColFamilyId, 
colFamilyIdBytes)) {
+            break()
+          }
+        }
+      }
+
+      val keyRow = currEncoder._1.decodeKey(currRecord._2)
+      if (currRecord._3 == null) {
+        (currRecord._1, keyRow, null, currentChangelogVersion - 1)
+      } else {
+        val valueRow = currEncoder._2.decodeValue(currRecord._3)
+        (currRecord._1, keyRow, valueRow, currentChangelogVersion - 1)
+      }
     } else {
-      val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray)
-      (recordType, keyRow, valueRow, currentChangelogVersion - 1)
+      val reader = currentChangelogReader()
+      if (reader == null) {
+        return null
+      }
+      val (recordType, keyArray, valueArray) = reader.next()

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to