HeartSaVioR commented on code in PR #49304:
URL: https://github.com/apache/spark/pull/49304#discussion_r1926516850


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -656,31 +803,75 @@ class RocksDB(
    *
    * @note This update is not committed to disk until commit() is called.
    */
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val oldValue = db.get(readOptions, key)
-      if (oldValue == null) {
-        numKeysOnWritingVersion += 1
+  def merge(
+      key: Array[Byte],
+      value: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          val cfInfo = getColumnFamilyInfo(cfName)
+          if (cfInfo.isInternal) {
+            numInternalKeysOnWritingVersion += 1
+          } else {
+            numKeysOnWritingVersion += 1
+          }
+        }
+      }
+    } else {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue == null) {
+          numKeysOnWritingVersion += 1
+        }
       }
     }
-    db.merge(writeOptions, key, value)
 
-    changelogWriter.foreach(_.merge(key, value))
+    db.merge(writeOptions, keyWithPrefix, value)
+    changelogWriter.foreach(_.merge(keyWithPrefix, value))
   }
 
   /**
    * Remove the key if present.
    * @note This update is not committed to disk until commit() is called.
    */
-  def remove(key: Array[Byte]): Unit = {
-    if (conf.trackTotalNumberOfRows) {
-      val value = db.get(readOptions, key)
-      if (value != null) {
-        numKeysOnWritingVersion -= 1
+  def remove(key: Array[Byte], cfName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    val keyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(key, cfName)
+    } else {
+      key
+    }
+
+    if (useColumnFamilies) {
+      if (conf.trackTotalNumberOfRows) {
+        val oldValue = db.get(readOptions, keyWithPrefix)
+        if (oldValue != null) {
+          val cfInfo = getColumnFamilyInfo(cfName)
+          if (cfInfo.isInternal) {
+            numInternalKeysOnWritingVersion -= 1
+          } else {
+            numKeysOnWritingVersion -= 1
+          }
+        }
+      }
+    } else {
+      if (conf.trackTotalNumberOfRows) {
+        val value = db.get(readOptions, keyWithPrefix)
+        if (value != null) {
+          numKeysOnWritingVersion -= 1
+        }
       }
     }
-    db.delete(writeOptions, key)
-    changelogWriter.foreach(_.delete(key))
+
+    db.delete(writeOptions, keyWithPrefix)
+    changelogWriter.foreach(_.delete(keyWithPrefix))
   }
 
   /**

Review Comment:
   I'm not sure how we missed this (or perhaps we were relying on prefix scan to 
cover this indirectly)...
   
   I don't think iterator() makes sense for the multiple-column-family use case. 
The resulting iterator doesn't give any information about the column family, so 
it isn't even useful for a global scan over the entire key space. (Different CFs 
can have the same key.)
   
   I think we have to change iterator() to receive cfName so that it scans only 
the given CF. Please let me know if you think we can't (or if you see use cases 
for performing a global scan).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to