HeartSaVioR commented on code in PR #44961:
URL: https://github.com/apache/spark/pull/44961#discussion_r1487255434


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -78,7 +78,7 @@ class StatePartitionReader(
 
     StateStoreProvider.createAndInit(
       stateStoreProviderId, keySchema, valueSchema, numColsPrefixKey,
-      useColumnFamilies = false, storeConf, hadoopConf.value)
+      useColumnFamilies = false, storeConf, hadoopConf.value, useMultipleValuesPerKey = false)

Review Comment:
   Shall we file a JIRA ticket to support transformWithState in the state data source reader? From what I understand, we are deferring support for this. Please correct me if I'm missing something.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {
+
+   /** Whether state exists or not. */
+   override def exists(): Boolean = {
+     val encodedGroupingKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val stateValue = store.get(encodedGroupingKey, stateName)
+     stateValue != null
+   }
+
+   /** Get the state value if it exists. If the state does not exist in state store, an
+    * empty iterator is returned. */
+   override def get(): Iterator[S] = {
+     val encodedKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+     new Iterator[S] {
+       override def hasNext: Boolean = {
+         unsafeRowValuesIterator.hasNext
+       }
+
+       override def next(): S = {
+         val valueUnsafeRow = unsafeRowValuesIterator.next()
+         StateTypesEncoderUtils.decodeValue(valueUnsafeRow)
+       }
+     }
+   }
+
+   /** Get the list value as an option if it exists and None otherwise. */
+   override def getOption(): Option[Iterator[S]] = {
+     Option(get())
+   }
+
+   /** Update the value of the list. */
+   override def put(newState: Array[S]): Unit = {
+     validateNewState(newState)
+
+     if (newState.isEmpty) {
+       this.clear()

Review Comment:
   I feel there might be use cases that need to distinguish an empty list from no value. Is this tied to a technical issue (e.g. we don't allow an empty list as a value at all), or is it just a matter of UX?

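To make the two options concrete, a minimal in-memory sketch, not the PR's code: `InMemoryListState` and its `keepEmptyList` flag are hypothetical. With `keepEmptyList = false` it mirrors the diff (an empty `put` clears the key); with `true` it keeps an explicit empty list, so `exists()` can tell the two cases apart.

```scala
// Hypothetical in-memory model of a list state for one grouping key.
// It is NOT the StateStore-backed ListStateImpl from the PR.
class InMemoryListState[S](keepEmptyList: Boolean) {
  // None models "no value in the store"; Some(Vector()) models "an empty list".
  private var stored: Option[Vector[S]] = None

  def exists(): Boolean = stored.isDefined

  def get(): Iterator[S] = stored.map(_.iterator).getOrElse(Iterator.empty)

  def put(newState: Array[S]): Unit = {
    if (newState.isEmpty && !keepEmptyList) {
      clear() // behaviour in the diff: an empty list is treated as "no value"
    } else {
      stored = Some(newState.toVector) // an empty list stays distinct from "absent"
    }
  }

  def clear(): Unit = {
    stored = None
  }
}

val clearing = new InMemoryListState[String](keepEmptyList = false)
clearing.put(Array.empty[String])

val keeping = new InMemoryListState[String](keepEmptyList = true)
keeping.put(Array.empty[String])
```

Whether the second variant is representable at all depends on whether the multi-value encoding can store zero values for a key, which is exactly the technical question above.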


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -67,6 +67,16 @@ trait ReadStateStore {
   def get(key: UnsafeRow,
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow
 
+  /**
+   * Provides an iterator containing all values for a particular key. The values are merged

Review Comment:
   The method doc seems too tied to the implementation, specifically the RocksDB state store provider. Shall we simply describe the method contract here, without implementation details? Please have a look at the other existing methods.

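As an illustration of a contract-only doc, a sketch on a simplified, hypothetical trait: `MultiValueReadStore` stands in for `ReadStateStore`, and `String` stands in for `UnsafeRow`. The wording says what callers may rely on and leaves out how RocksDB stores or merges the values; the "empty iterator for a missing key" guarantee is an assumption about the intended contract.

```scala
// Hypothetical, simplified stand-in for ReadStateStore (String replaces UnsafeRow).
trait MultiValueReadStore {
  /**
   * Provides an iterator over all values associated with the given key in the given
   * column family. Returns an empty iterator if the key has no values. Only supported
   * by stores created with support for multiple values per key.
   */
  def valuesIterator(key: String, colFamilyName: String): Iterator[String]
}

// Toy implementation, used only to exercise the contract above.
class MapBackedStore(data: Map[(String, String), Vector[String]]) extends MultiValueReadStore {
  override def valuesIterator(key: String, colFamilyName: String): Iterator[String] =
    data.getOrElse((key, colFamilyName), Vector.empty).iterator
}

val store = new MapBackedStore(Map(("k", "default") -> Vector("a", "b")))
```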


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,118 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {
+
+   /** Whether state exists or not. */
+   override def exists(): Boolean = {
+     val encodedGroupingKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val stateValue = store.get(encodedGroupingKey, stateName)
+     stateValue != null
+   }
+
+   /** Get the state value if it exists. If the state does not exist in state store, an
+    * empty iterator is returned. */
+   override def get(): Iterator[S] = {
+     val encodedKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+     new Iterator[S] {
+       override def hasNext: Boolean = {
+         unsafeRowValuesIterator.hasNext
+       }
+
+       override def next(): S = {
+         val valueUnsafeRow = unsafeRowValuesIterator.next()
+         StateTypesEncoderUtils.decodeValue(valueUnsafeRow)
+       }
+     }
+   }
+
+   /** Get the list value as an option if it exists and None otherwise. */

Review Comment:
   This does not seem to be compatible with the method contract of get().
   
   `If the state does not exist in state store, an empty iterator is returned.`
   
   vs
   
   `Get the list value as an option if it exists and None otherwise.`
   
   `Option(Iterator.empty[String])` will be `Some(Iterator.empty[String])`.
   
   It'd be more natural to do the opposite: implement getOption() and leverage getOption() to cover get() when store.valuesIterator returns null.

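A minimal sketch of the suggested inversion. `ListStateSketch` is hypothetical, and its constructor function stands in for `store.valuesIterator`, assumed here to return `null` when the key is absent; treating an empty iterator as `None` is also a choice made for the sketch. The first line shows why wrapping `get()` with `Option(...)` can never produce `None`.

```scala
// Option(...) only maps null to None, so a non-null empty iterator becomes Some.
val wrapped: Option[Iterator[String]] = Option(Iterator.empty[String])

// Hypothetical list state; valuesIterator stands in for store.valuesIterator and may
// return null when no state exists for the current key.
class ListStateSketch[S](valuesIterator: () => Iterator[S]) {
  /** Returns None when there is no state (null from the store, or no values). */
  def getOption(): Option[Iterator[S]] = {
    val it = valuesIterator()
    if (it == null || !it.hasNext) None else Some(it)
  }

  /** Returns an empty iterator when there is no state, as the get() doc promises. */
  def get(): Iterator[S] = getOption().getOrElse(Iterator.empty)
}

val missing = new ListStateSketch[String](() => null)
val present = new ListStateSketch[String](() => Iterator("a", "b"))
```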


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########


Review Comment:
   I won't review this for now, as I expect this PR to be rebased after #45038 is merged.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.commons.lang3.SerializationUtils
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
+import org.apache.spark.sql.types.{BinaryType, StructType}
+
+/**
+ * Helper object providing APIs to encodes the grouping key, and user provided values
+ * to Spark [[UnsafeRow]].
+ */
+object StateTypesEncoderUtils {
+
+  private val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType)
+  private val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType)
+
+  // TODO: validate places that are trying to encode the key and check if we can eliminate/
+  // add caching for some of these calls.
+  def encodeGroupingKey(stateName: String, keyExprEnc: ExpressionEncoder[Any]): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (keyOption.isEmpty) {
+      throw StateStoreErrors.implicitKeyNotFound(stateName)
+    }
+
+    val toRow = keyExprEnc.createSerializer()
+    val keyByteArr = toRow
+      .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+
+    val keyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA)

Review Comment:
   If we are fully sure we will end up with a binary key and a binary value for every state type, this should be moved up to fields of this object. Same for valueEncoder. If we anticipate different schemas, maybe keyEncoder / valueEncoder should be passed as parameters instead.

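A sketch of the two options, with a hypothetical `Projection.create` (and a creation counter) standing in for `UnsafeProjection.create(KEY_ROW_SCHEMA)`. Option 1 assumes the binary-key layout is fixed and builds the projection once as an object field; option 2 lets the caller supply it.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for UnsafeProjection; counts how many projections are created.
object Projection {
  val created = new AtomicInteger(0)

  def create(schema: String): Array[Byte] => Vector[Byte] = {
    created.incrementAndGet()
    (bytes: Array[Byte]) => bytes.toVector
  }
}

// Option 1: the schema is fixed, so the projection is built once, as a field.
object FixedSchemaEncoder {
  private val keyEncoder = Projection.create("key: binary")

  def encodeKey(bytes: Array[Byte]): Vector[Byte] = keyEncoder(bytes)
}

// Option 2: the schema may vary, so the caller passes the projection in.
def encodeKeyWith(encoder: Array[Byte] => Vector[Byte], bytes: Array[Byte]): Vector[Byte] =
  encoder(bytes)

val first = FixedSchemaEncoder.encodeKey(Array[Byte](1, 2))
val second = FixedSchemaEncoder.encodeKey(Array[Byte](3))
```

Either way, repeated `encodeKey` calls reuse one projection instead of creating one per call.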


##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala:
##########
@@ -46,5 +46,5 @@ private[sql] trait ValueState[S] extends Serializable {
   def update(newState: S): Unit
 
   /** Remove this state. */
-  def remove(): Unit
+  def clear(): Unit

Review Comment:
   Just to confirm, this renaming was discussed, right? cc. @anishshri-db



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,118 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {

Review Comment:
   I think you can create the serializer from keyExprEnc here as a field, and pass the serializer to encodeGroupingKey instead.

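A minimal sketch of the suggestion with hypothetical stand-ins: `KeyEncoderStub.createSerializer()` plays the role of `ExpressionEncoder.createSerializer()` and returns a plain function, and the grouping key is passed explicitly instead of coming from `ImplicitGroupingKeyTracker`. The serializer is built once per state instance and handed to `encodeGroupingKey`.

```scala
import java.nio.charset.StandardCharsets
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for ExpressionEncoder[Any]; counts serializer creations.
class KeyEncoderStub {
  val serializersCreated = new AtomicInteger(0)

  def createSerializer(): Any => Array[Byte] = {
    serializersCreated.incrementAndGet()
    (key: Any) => key.toString.getBytes(StandardCharsets.UTF_8)
  }
}

object EncoderUtilsSketch {
  // Takes the serializer rather than the encoder, so callers control how often it is built.
  def encodeGroupingKey(key: Any, toRow: Any => Array[Byte]): Array[Byte] = toRow(key)
}

class ListStateImplSketch(keyExprEnc: KeyEncoderStub) {
  // Created once per state instance instead of on every exists()/get()/put() call.
  private val keySerializer = keyExprEnc.createSerializer()

  def encodedKey(key: Any): Array[Byte] =
    EncoderUtilsSketch.encodeGroupingKey(key, keySerializer)
}

val keyEnc = new KeyEncoderStub
val listState = new ListStateImplSketch(keyEnc)
val k1 = listState.encodedKey("user-1")
val k2 = listState.encodedKey("user-2")
```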


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##########
@@ -0,0 +1,66 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.commons.lang3.SerializationUtils
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
+import org.apache.spark.sql.types.{BinaryType, StructType}
+
+/**
+ * Helper object providing APIs to encodes the grouping key, and user provided values
+ * to Spark [[UnsafeRow]].
+ */
+object StateTypesEncoderUtils {
+
+  private val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType)
+  private val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType)
+
+  // TODO: validate places that are trying to encode the key and check if we can eliminate/
+  // add caching for some of these calls.
+  def encodeGroupingKey(stateName: String, keyExprEnc: ExpressionEncoder[Any]): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (keyOption.isEmpty) {
+      throw StateStoreErrors.implicitKeyNotFound(stateName)
+    }
+
+    val toRow = keyExprEnc.createSerializer()
+    val keyByteArr = toRow
+      .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+
+    val keyEncoder = UnsafeProjection.create(KEY_ROW_SCHEMA)
+    val keyRow = keyEncoder(InternalRow(keyByteArr))
+    keyRow
+  }
+
+  def encodeValue[S] (value: S): UnsafeRow = {

Review Comment:
   nit: remove the space between `]` and `(`.
   
   (Just thinking out loud, no need to fix) I'd prefer to tie this to the language's type system, e.g. `[S <: Serializable]` here, though I see it would require modifying many places.

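For illustration only, roughly what the bounded signature could look like. It uses `java.io` object streams instead of commons-lang3 `SerializationUtils`, returns raw bytes instead of an `UnsafeRow`, and spells the bound as `java.io.Serializable` so that types such as `String` qualify on every Scala version; all three are assumptions of the sketch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object ValueCodecSketch {
  // The upper bound rejects non-serializable value types at compile time.
  def encodeValue[S <: java.io.Serializable](value: S): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(value) finally out.close()
    bytes.toByteArray
  }

  def decodeValue[S <: java.io.Serializable](bytes: Array[Byte]): S = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try in.readObject().asInstanceOf[S] finally in.close()
  }
}

val encoded = ValueCodecSketch.encodeValue(("a", 1))
val decoded = ValueCodecSketch.decodeValue[(String, Int)](encoded)
// ValueCodecSketch.encodeValue(new Object) would not compile.
```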


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala:
##########
@@ -264,6 +278,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported("HDFSStateStoreProvider")
     }
 
+    if (useMultipleValuesPerKey) {
+      throw new UnsupportedOperationException("Multiple values per key are not supported with " +

Review Comment:
   ditto? 
   
   ```
   throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -316,6 +321,38 @@ class RocksDB(
     }
   }
 
+  /**
+   * Merge the given value for the given key. This is equivalent to the Atomic
+   * Read-Modify-Write operation in RocksDB, known as the "Merge" operation. The
+   * modification is appending the provided value to current list of values for
+   * the given key.
+   *
+   * @note This operation requires that the encoder used can decode multiple values for
+   * a key from the values byte array.
+   *
+   * @note This update is not committed to disk until commit() is called.
+   */
+  def merge(
+      key: Array[Byte],
+      value: Array[Byte],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    if (!useColumnFamilies) {
+      throw new RuntimeException("Merge operation uses changelog checkpointing v2 which" +
+        " requires column families to be enabled.")
+    }
+    verifyColFamilyExists(colFamilyName)
+
+    if (conf.trackTotalNumberOfRows) {
+      val oldValue = db.get(colFamilyNameToHandleMap(colFamilyName), readOptions, key)
+      if (oldValue == null) {
+        numKeysOnWritingVersion += 1
+      }
+    }
+    db.merge(colFamilyNameToHandleMap(colFamilyName), writeOptions, key, value)

Review Comment:
   Does this work the same as put if the key does not exist?

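One way to pin the expected behaviour down is a small in-memory model of append-style merge semantics, which is also consistent with the diff counting a merge into an absent key as a new key when `trackTotalNumberOfRows` is on. `AppendMergeModel` is hypothetical and not the RocksDB API; whether RocksDB's merge operator actually behaves this way for a missing key is what a unit test in `RocksDBSuite` would confirm.

```scala
import scala.collection.mutable

// Hypothetical in-memory model of append-style merge semantics (not the RocksDB API).
class AppendMergeModel {
  private val data = mutable.Map.empty[String, Vector[String]]

  def put(key: String, value: String): Unit = {
    data(key) = Vector(value)
  }

  // Appends to the existing values; with no existing value this is the same as put.
  def merge(key: String, value: String): Unit = {
    data(key) = data.getOrElse(key, Vector.empty[String]) :+ value
  }

  def values(key: String): Vector[String] = data.getOrElse(key, Vector.empty[String])
}

val viaPut = new AppendMergeModel
viaPut.put("a", "1")

val viaMerge = new AppendMergeModel
viaMerge.merge("a", "1")

val twoMerges = new AppendMergeModel
twoMerges.merge("a", "1")
twoMerges.merge("a", "2")
```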


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala:
##########
@@ -94,6 +94,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       Iterator[UnsafeRowPair] = {
       map.prefixScan(prefixKey)
     }
+
+    override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+      throw new UnsupportedOperationException("store does not support multiple values per key")

Review Comment:
   nit: should behave the same as the writable state store
   
   ```
   throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
   ```

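A sketch of the consistency being asked for: the read path and the init path raise the same error from one shared factory. `StoreErrorsSketch`, `HdfsLikeProvider`, and the message format are hypothetical stand-ins; in Spark the shared factory would be `StateStoreErrors.unsupportedOperationException` as suggested above, whose exact signature is not reproduced here.

```scala
import scala.util.Try

// Hypothetical stand-in for a shared error factory such as StateStoreErrors.
object StoreErrorsSketch {
  def unsupportedOperation(operationName: String, entity: String): UnsupportedOperationException =
    new UnsupportedOperationException(s"$operationName is not supported by $entity")
}

// Hypothetical provider: both call sites go through the same factory.
class HdfsLikeProvider {
  def init(useMultipleValuesPerKey: Boolean): Unit = {
    if (useMultipleValuesPerKey) {
      throw StoreErrorsSketch.unsupportedOperation("multipleValuesPerKey", "HDFSStateStore")
    }
  }

  def valuesIterator(key: String): Iterator[String] =
    throw StoreErrorsSketch.unsupportedOperation("multipleValuesPerKey", "HDFSStateStore")
}

val provider = new HdfsLikeProvider
val fromInit = Try(provider.init(useMultipleValuesPerKey = true)).failed.get
val fromRead = Try(provider.valuesIterator("k")).failed.get
```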


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,118 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {
+
+   /** Whether state exists or not. */
+   override def exists(): Boolean = {
+     val encodedGroupingKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val stateValue = store.get(encodedGroupingKey, stateName)
+     stateValue != null
+   }
+
+   /** Get the state value if it exists. If the state does not exist in state store, an
+    * empty iterator is returned. */
+   override def get(): Iterator[S] = {
+     val encodedKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)

Review Comment:
   What is the expected value if the state does not exist for the key? I feel the returned iterator will throw an NPE if the value is null. If it's guaranteed to be an empty Iterator, then this looks fine.

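A sketch of a null-safe decode, assuming a hypothetical length-prefixed layout for multiple values (the PR's real encoding lives in `RocksDBStateEncoder` and may differ). When the store returns `null` for an absent key, decoding yields an empty iterator, so the iterator built in `get()` never touches a `null`.

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

object MultiValueCodecSketch {
  // Hypothetical layout: [length: Int][bytes][length: Int][bytes]...
  def encode(values: Seq[Array[Byte]]): Array[Byte] = {
    val buf = ByteBuffer.allocate(values.map(_.length + 4).sum)
    values.foreach { v => buf.putInt(v.length); buf.put(v) }
    buf.array()
  }

  // A null input (key absent) is treated the same as "no values".
  def decodeValues(bytes: Array[Byte]): Iterator[Array[Byte]] = {
    if (bytes == null) {
      Iterator.empty
    } else {
      val buf = ByteBuffer.wrap(bytes)
      Iterator.continually(buf).takeWhile(_.hasRemaining).map { b =>
        val value = new Array[Byte](b.getInt())
        b.get(value)
        value
      }
    }
  }
}

val absent = MultiValueCodecSketch.decodeValues(null)
val roundTrip = MultiValueCodecSketch
  .decodeValues(MultiValueCodecSketch.encode(Seq("x", "yz").map(_.getBytes(StandardCharsets.UTF_8))))
  .map(new String(_, StandardCharsets.UTF_8))
  .toList
```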


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -127,6 +137,9 @@ trait StateStore extends ReadStateStore {
   def remove(key: UnsafeRow,
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
 
+  def merge(key: UnsafeRow, value: UnsafeRow,

Review Comment:
   Shall we add a method doc?

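One possible contract-style doc for `merge`, shown on a hypothetical, simplified trait (`MergeableStoreSketch`, with `String` in place of `UnsafeRow` and "default" assumed as the default column family name). The wording describes the effect visible through `valuesIterator` rather than RocksDB internals; the toy implementation only exercises that contract.

```scala
// Hypothetical, simplified stand-in for the write side of StateStore.
trait MergeableStoreSketch {
  /**
   * Merges the given value with the existing values for the given key in the given column
   * family, so that a subsequent valuesIterator call on the key also returns this value.
   * Only supported by stores created with support for multiple values per key.
   */
  def merge(key: String, value: String, colFamilyName: String = "default"): Unit

  /** Returns all values for the key, or an empty iterator if there are none. */
  def valuesIterator(key: String, colFamilyName: String = "default"): Iterator[String]
}

// Toy implementation, used only to exercise the contract above.
class InMemoryMergeableStore extends MergeableStoreSketch {
  private var data = Map.empty[(String, String), Vector[String]]

  override def merge(key: String, value: String, colFamilyName: String): Unit = {
    val k = (key, colFamilyName)
    data = data.updated(k, data.getOrElse(k, Vector.empty[String]) :+ value)
  }

  override def valuesIterator(key: String, colFamilyName: String): Iterator[String] =
    data.getOrElse((key, colFamilyName), Vector.empty[String]).iterator
}

val mergeStore = new InMemoryMergeableStore
mergeStore.merge("k", "a")
mergeStore.merge("k", "b")
```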


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InputRow(key: String, action: String, value: String)
+
+class TestListStateProcessor
+  extends StatefulProcessor[String, InputRow, (String, String)] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+  }
+
+  override def handleInputRows(key: String,

Review Comment:
   nit: the first param needs to go on its own line below the `(` if the method definition does not fit in 2 lines.

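For reference, the layout being asked for, following the Databricks Scala style guide that Spark uses: once a declaration does not fit in two lines, each parameter goes on its own line with a four-space indent. `InputRowStub` and `TimerValuesStub` are hypothetical stubs so the snippet compiles on its own.

```scala
// Hypothetical stubs standing in for InputRow and TimerValues.
final case class InputRowStub(key: String, action: String, value: String)
final class TimerValuesStub

class ProcessorLayoutSketch {
  // First parameter on its own line below the "(", each parameter indented four spaces.
  def handleInputRows(
      key: String,
      rows: Iterator[InputRowStub],
      timerValues: TimerValuesStub): Iterator[(String, String)] = {
    rows.filter(_.action == "emit").map(row => (key, row.value))
  }
}

val emitted = new ProcessorLayoutSketch()
  .handleInputRows(
    "k",
    Iterator(InputRowStub("k", "emit", "v1"), InputRowStub("k", "skip", "v2")),
    new TimerValuesStub)
  .toList
```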


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -65,6 +65,43 @@ private[sql] class RocksDBStateStoreProvider
       value
     }
 
+    override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+      verify(key != null, "Key cannot be null")
+      verify(encoder.supportsMultipleValuesPerKey, "valuesIterator requires a encoder " +
+      "that supports multiple values for a single key.")
+      val valueIterator = encoder.decodeValues(rocksDB.get(encoder.encodeKey(key), colFamilyName))
+
+      if (!isValidated && valueIterator.nonEmpty) {
+        new Iterator[UnsafeRow] {
+          override def hasNext: Boolean = {
+            valueIterator.hasNext
+          }
+
+          override def next(): UnsafeRow = {
+            val value = valueIterator.next()
+            if (!isValidated && value != null) {
+              StateStoreProvider.validateStateRowFormat(
+                key, keySchema, value, valueSchema, storeConf)
+              isValidated = true
+            }
+            value
+          }
+        }
+      } else {
+        valueIterator
+      }
+    }
+
+    override def merge(key: UnsafeRow, value: UnsafeRow,
+        colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+      verify(encoder.supportsMultipleValuesPerKey, "Merge operation requires a encoder" +

Review Comment:
   nit: a -> an



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,118 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {
+
+   /** Whether state exists or not. */
+   override def exists(): Boolean = {
+     val encodedGroupingKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val stateValue = store.get(encodedGroupingKey, stateName)
+     stateValue != null
+   }
+
+   /** Get the state value if it exists. If the state does not exist in state store, an
+    * empty iterator is returned. */
+   override def get(): Iterator[S] = {
+     val encodedKey = StateTypesEncoderUtils.encodeGroupingKey(stateName, keyExprEnc)
+     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+     new Iterator[S] {
+       override def hasNext: Boolean = {
+         unsafeRowValuesIterator.hasNext
+       }
+
+       override def next(): S = {
+         val valueUnsafeRow = unsafeRowValuesIterator.next()
+         StateTypesEncoderUtils.decodeValue(valueUnsafeRow)
+       }
+     }
+   }
+
+   /** Get the list value as an option if it exists and None otherwise. */
+   override def getOption(): Option[Iterator[S]] = {
+     Option(get())
+   }
+
+   /** Update the value of the list. */
+   override def put(newState: Array[S]): Unit = {
+     validateNewState(newState)
+
+     if (newState.isEmpty) {
+       this.clear()

Review Comment:
   This also makes the difference between get() and getOption() moot. (Though we could return None for an empty Iterator in getOption() to cover this case.)



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -845,6 +859,76 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     }
   }
 
+  test("ensure merge operation is not supported if column families is not enabled") {
+    withTempDir { dir =>
+      val remoteDir = Utils.createTempDir().toString
+      val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false)
+      new File(remoteDir).delete() // to make sure that the directory gets created
+      withDB(remoteDir, conf = conf, useColumnFamilies = false) { db =>
+        db.load(0)
+        db.put("a", "1")
+        intercept[RuntimeException](
+          db.merge("a", "2")
+        )
+      }
+    }
+  }
+
+  test(s"RocksDB: ensure merge operation correctness") {

Review Comment:
   nit: remove the `s` before `""` (there is nothing to interpolate)



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala:
##########
@@ -0,0 +1,121 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](store: StateStore,

Review Comment:
   nit: the first argument does not seem to have been moved

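The same layout applied to the class header, roughly; per the same style guide, `extends` goes on its own line with a two-space indent. `StateStoreStub`, `ExprEncoderStub`, and `ListStateStub` are hypothetical placeholders for the Spark types, and `Logging` is left out.

```scala
// Hypothetical placeholders for StateStore, ExpressionEncoder[Any], and ListState[S].
trait StateStoreStub
trait ExprEncoderStub
trait ListStateStub[S]

class ListStateImplLayout[S](
    store: StateStoreStub,
    stateName: String,
    keyExprEnc: ExprEncoderStub)
  extends ListStateStub[S] {

  def name: String = stateName
}

val layout = new ListStateImplLayout[String](
  new StateStoreStub {},
  "testListState",
  new ExprEncoderStub {})
```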


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala:
##########
@@ -0,0 +1,282 @@
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InputRow(key: String, action: String, value: String)
+
+class TestListStateProcessor
+  extends StatefulProcessor[String, InputRow, (String, String)] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+  }
+
+  override def handleInputRows(key: String,
+      rows: Iterator[InputRow],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+
+    var output = List[(String, String)]()
+
+    for (row <- rows) {
+      if (row.action == "emit") {
+        output = (key, row.value) :: output
+      } else if (row.action == "emitAllInState") {
+        _listState.get().foreach(v => {

Review Comment:
   nit: change `(v => {` to `{ v =>` to drop the extra parentheses.
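
For illustration, the suggested form applied to the `emitAllInState` loop, pulled into a standalone helper. `emitAll` is a hypothetical name; the snippet only mirrors the quoted loop and is a sketch, not the PR's code.

```scala
// Prepends every stored value to the output list, as the emitAllInState branch does.
def emitAll(key: String, stored: Iterator[String]): List[(String, String)] = {
  var output = List.empty[(String, String)]
  // `{ v => ... }` instead of `(v => { ... })`: one pair of braces, no extra parens.
  stored.foreach { v =>
    output = (key, v) :: output
  }
  output
}
```

Because values are prepended, the result comes out in reverse storage order, which appears harmless here since the suite's `CheckNewAnswer` expectations do not depend on row order.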



##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InputRow(key: String, action: String, value: String)
+
+class TestListStateProcessor
+  extends StatefulProcessor[String, InputRow, (String, String)] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+  }
+
+  override def handleInputRows(key: String,
+      rows: Iterator[InputRow],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+
+    var output = List[(String, String)]()
+
+    for (row <- rows) {
+      if (row.action == "emit") {
+        output = (key, row.value) :: output
+      } else if (row.action == "emitAllInState") {
+        _listState.get().foreach(v => {
+          output = (key, v) :: output
+        })
+        _listState.clear()
+      } else if (row.action == "append") {
+        _listState.appendValue(row.value)
+      } else if (row.action == "appendAll") {
+        _listState.appendList(row.value.split(","))
+      } else if (row.action == "put") {
+        _listState.put(row.value.split(","))
+      } else if (row.action == "remove") {
+        _listState.clear()
+      } else if (row.action == "tryAppendingNull") {
+        _listState.appendValue(null)
+      } else if (row.action == "tryAppendingNullValueInList") {
+        _listState.appendList(Array(null))
+      } else if (row.action == "tryAppendingNullList") {
+        _listState.appendList(null)
+      } else if (row.action == "tryPutNullList") {
+        _listState.put(null)
+      } else if (row.action == "tryPuttingNullInList") {
+        _listState.put(Array(null))
+      }
+    }
+
+    output.iterator
+  }
+
+  override def close(): Unit = {}
+}
+
+class ToggleSaveAndEmitProcessor
+  extends StatefulProcessor[String, String, String] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+  @transient var _valueState: ValueState[Boolean] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+    _valueState = handle.getValueState("testValueState")
+  }
+
+  override def handleInputRows(
+      key: String,
+      rows: Iterator[String],
+      timerValues: TimerValues): Iterator[String] = {
+    val valueStateOption = _valueState.getOption()
+
+    if (valueStateOption.isEmpty || !valueStateOption.get) {
+      _listState.appendList(rows.toArray)
+      _valueState.update(true)
+      Seq().iterator
+    } else {
+      _valueState.clear()
+      val storedValues = _listState.get()
+      _listState.clear()
+
+      new Iterator[String] {
+        override def hasNext: Boolean = {
+          rows.hasNext || storedValues.hasNext
+        }
+
+        override def next(): String = {
+          if (rows.hasNext) {
+            rows.next()
+          } else {
+            storedValues.next()
+          }
+        }
+      }
+    }
+  }
+
+  override def close(): Unit = {}
+}
+
+class TransformWithListStateSuite extends StreamTest
+  with AlsoTestWithChangelogCheckpointingEnabled {
+  import testImplicits._
+
+  test("test appending null value in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update()) (
+        AddData(inputData, InputRow("k1", "tryAppendingNull", "")),
+        ExpectFailure[SparkException](e => {
+          assert(e.getMessage.contains("CANNOT_WRITE_STATE_STORE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test putting null value in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryPuttingNullInList", "")),
+        ExpectFailure[SparkException](e => {
+          assert(e.getMessage.contains("CANNOT_WRITE_STATE_STORE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test putting null list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryPutNullList", "")),
+        ExpectFailure[SparkException](e => {
+          assert(e.getMessage.contains("CANNOT_WRITE_STATE_STORE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test appending null list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryAppendingNullList", "")),
+        ExpectFailure[SparkException](e => {
+          assert(e.getMessage.contains("CANNOT_WRITE_STATE_STORE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test list state correctness") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update()) (
+        // no interaction test
+        AddData(inputData, InputRow("k1", "emit", "v1")),
+        CheckNewAnswer(("k1", "v1")),
+        // check simple append
+        AddData(inputData, InputRow("k1", "append", "v2")),
+        AddData(inputData, InputRow("k1", "emitAllInState", "")),
+        CheckNewAnswer(("k1", "v2")),
+        // multiple appends are correctly stored and emitted
+        AddData(inputData, InputRow("k2", "append", "v1")),
+        AddData(inputData, InputRow("k1", "append", "v4")),
+        AddData(inputData, InputRow("k2", "append", "v2")),
+        AddData(inputData, InputRow("k1", "emit", "v5")),
+        AddData(inputData, InputRow("k2", "emit", "v3")),
+        CheckNewAnswer(("k1", "v5"), ("k2", "v3")),
+        AddData(inputData, InputRow("k1", "emitAllInState", "")),
+        AddData(inputData, InputRow("k2", "emitAllInState", "")),
+        CheckNewAnswer(("k2", "v1"), ("k2", "v2"), ("k1", "v4")),
+        // check appendAll with append
+        AddData(inputData, InputRow("k3", "appendAll", "v1,v2,v3")),
+        AddData(inputData, InputRow("k3", "emit", "v4")),
+        AddData(inputData, InputRow("k3", "append", "v5")),
+        CheckNewAnswer(("k3", "v4")),
+        AddData(inputData, InputRow("k3", "emitAllInState", "")),
+        CheckNewAnswer(("k3", "v1"), ("k3", "v2"), ("k3", "v3"), ("k3", "v5")),
+        // check removal cleans up all data in state
+        AddData(inputData, InputRow("k4", "append", "v2")),
+        AddData(inputData, InputRow("k4", "appendAll", "v3,v4")),
+        AddData(inputData, InputRow("k4", "remove", "")),

Review Comment:
   nit: shall we add an `emitAllInState` for k4 to make sure no state value remains for k4?
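
In the suite, the suggestion would amount to following the `remove` row with `AddData(inputData, InputRow("k4", "emitAllInState", ""))` and an empty `CheckNewAnswer()`. The property that check pins down can be sketched with a small in-memory model. `InMemoryListState` is hypothetical: keys are passed explicitly (unlike the real `ListState`, which is scoped to the current grouping key), and nothing is backed by a state store.

```scala
import scala.collection.mutable

// Hypothetical in-memory model of per-key list state; not Spark's ListStateImpl.
final class InMemoryListState {
  private val data = mutable.Map.empty[String, mutable.ArrayBuffer[String]]

  // Appends a single value to the key's list, creating the list if needed.
  def appendValue(key: String, value: String): Unit =
    data.getOrElseUpdate(key, mutable.ArrayBuffer.empty[String]) += value

  // Appends several values at once, preserving their order.
  def appendList(key: String, values: Seq[String]): Unit =
    data.getOrElseUpdate(key, mutable.ArrayBuffer.empty[String]) ++= values

  // Drops every value stored for the key (the "remove" action).
  def clear(key: String): Unit = data.remove(key)

  // Returns the stored values, or an empty iterator if the key has none.
  def get(key: String): Iterator[String] =
    data.get(key).map(_.iterator).getOrElse(Iterator.empty)
}
```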


