HeartSaVioR commented on code in PR #48290:
URL: https://github.com/apache/spark/pull/48290#discussion_r1802300062


##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -180,6 +256,41 @@ def getListState(
         self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
         return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
 
+    def getMapState(
+        self,
+        state_name: str,
+        key_schema: Union[StructType, str],

Review Comment:
   Shall we use `user_key`/`userKey` for the map key, as we do in the Scala impl? A plain `key` is easily confused with the grouping key. Please apply this everywhere.
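Each map state entry is addressed by two keys. One is the implicit grouping key, which the server sets before user code runs. The other is the map key that the user passes in. The minimal in-memory sketch below shows why the map key needs its own name. It uses a hypothetical `InMemoryMapState` class, which is not the PySpark `MapState` API, and names the map key `user_key` as suggested:

```python
from typing import Dict, Iterator, Optional, Tuple


class InMemoryMapState:
    """Hypothetical in-memory stand-in for a map state (not the PySpark API).

    Entries are scoped by the implicit grouping key. Within one grouping key,
    they are addressed by a separate user key (the map key).
    """

    def __init__(self) -> None:
        # grouping key -> {user key -> value}
        self._store: Dict[Tuple, Dict[Tuple, Tuple]] = {}
        self._grouping_key: Optional[Tuple] = None

    def set_grouping_key(self, grouping_key: Tuple) -> None:
        # Plays the role of the server setting the implicit key per group.
        self._grouping_key = grouping_key

    def _entries(self) -> Dict[Tuple, Tuple]:
        if self._grouping_key is None:
            raise RuntimeError("grouping key is not set")
        return self._store.setdefault(self._grouping_key, {})

    def update_value(self, user_key: Tuple, value: Tuple) -> None:
        self._entries()[user_key] = value

    def get_value(self, user_key: Tuple) -> Optional[Tuple]:
        return self._entries().get(user_key)

    def contains_key(self, user_key: Tuple) -> bool:
        return user_key in self._entries()

    def keys(self) -> Iterator[Tuple]:
        # Snapshot of the user keys under the current grouping key only.
        return iter(list(self._entries()))


state = InMemoryMapState()
state.set_grouping_key(("user-1",))
state.update_value(("clicks",), (3,))
state.set_grouping_key(("user-2",))
# Same user key, different grouping key: no entry is visible here.
print(state.contains_key(("clicks",)))  # False
```

With the `user_key` name, a signature like `update_value(user_key, value)` can no longer be misread as taking the grouping key.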



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -154,6 +154,36 @@ def get_list_state(
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}")
 
+    def get_map_state(
+        self,
+        state_name: str,
+        key_schema: Union[StructType, str],

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -150,12 +172,12 @@ class TransformWithStateInPandasStateServer(
         val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
         // Reset the list state iterators for a new grouping key.

Review Comment:
   nit: list/map?
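Some context for this nit. This PR adds map state iterators on the server. `keyValueIterators` backs `iterator()`, and the existing `iterators` map (which this comment calls the list state iterators) is reused for `keys()`/`values()` (see `handleMapStateRequest` below). The reset on a new grouping key therefore covers map state too, not only list state. The minimal Python sketch below shows that bookkeeping. It uses a hypothetical `IteratorRegistry` class and assumes the reset clears both maps:

```python
from typing import Any, Callable, Dict, Iterator, Optional, Tuple


class IteratorRegistry:
    """Hypothetical sketch of the server's iterator bookkeeping (not Spark code).

    `iterators` holds list state iterators and map keys()/values() iterators.
    `key_value_iterators` holds map iterator() iterators. Both are keyed by
    the iterator id that the Python client sends with each request.
    """

    def __init__(self) -> None:
        self.iterators: Dict[str, Iterator[Any]] = {}
        self.key_value_iterators: Dict[str, Iterator[Any]] = {}
        self.grouping_key: Optional[Tuple] = None

    def set_grouping_key(self, grouping_key: Tuple) -> None:
        # Reset the list/map state iterators for a new grouping key, so an
        # iterator opened under the previous key is never resumed.
        self.grouping_key = grouping_key
        self.iterators.clear()
        self.key_value_iterators.clear()

    def get_or_create(
        self,
        registry: Dict[str, Iterator[Any]],
        iterator_id: str,
        factory: Callable[[], Iterator[Any]],
    ) -> Iterator[Any]:
        # Reuse the iterator for a known id; otherwise open a new one.
        if iterator_id not in registry:
            registry[iterator_id] = factory()
        return registry[iterator_id]
```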



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        get_value_call = stateMessage.GetValue(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            if len(response_message[2]) == 0:
+                return None
+            row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting value: {response_message[1]}")
+
+    def contains_key(self, state_name: str, key: Tuple) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        contains_key_call = stateMessage.ContainsKey(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, containsKey=contains_key_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when the given key doesn't exist in the map state.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error checking if map state contains key: {response_message[1]}"
+            )
+
+    def update_value(
+        self,
+        state_name: str,
+        key: Tuple,

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -336,6 +345,127 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!mapStates.contains(stateName)) {
+      logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"Map state $stateName is not initialized.")
+      return
+    }
+    val mapStateInfo = mapStates(stateName)
+    message.getMethodCase match {
+      case MapStateCall.MethodCase.EXISTS =>
+        if (mapStateInfo.mapState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case MapStateCall.MethodCase.GETVALUE =>
+        val keyBytes = message.getGetValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val value = mapStateInfo.mapState.getValue(keyRow)
+        if (value != null) {
+          val valueBytes = PythonSQLUtils.toPyRow(value)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
+        } else {
+          logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't contain" +
+            log" key ${MDC(LogKeys.KEY, keyRow.toString)}.")
+          sendResponse(0)
+        }
+      case MapStateCall.MethodCase.CONTAINSKEY =>
+        val keyBytes = message.getContainsKey.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        if (mapStateInfo.mapState.containsKey(keyRow)) {
+          sendResponse(0)
+        } else {
+          sendResponse(2, s"Map state $stateName doesn't contain key ${keyRow.toString()}")
+        }
+      case MapStateCall.MethodCase.UPDATEVALUE =>
+        val keyBytes = message.getUpdateValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val valueBytes = message.getUpdateValue.getValue.toByteArray
+        val valueRow = PythonSQLUtils.toJVMRow(valueBytes, mapStateInfo.valueSchema,
+          mapStateInfo.valueDeserializer)
+        mapStateInfo.mapState.updateValue(keyRow, valueRow)
+        sendResponse(0)
+      case MapStateCall.MethodCase.ITERATOR =>
+        val iteratorId = message.getIterator.getIteratorId
+        var iteratorOption = keyValueIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.iterator())
+          keyValueIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any entry.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        val keyValueStateSchema: StructType = StructType(
+          Array(
+            // key row serialized as a byte array.
+            StructField("keyRow", BinaryType),
+            // value row serialized as a byte array.
+            StructField("valueRow", BinaryType)
+          )
+        )
+        sendIteratorAsArrowBatches(iteratorOption.get, keyValueStateSchema,
+          arrowStreamWriterForTest) {tuple =>
+          val keyBytes = PythonSQLUtils.toPyRow(tuple._1)
+          val valueBytes = PythonSQLUtils.toPyRow(tuple._2)
+          new GenericInternalRow(
+            Array[Any](
+              keyBytes,
+              valueBytes
+            )
+          )
+        }
+      case MapStateCall.MethodCase.KEYS =>
+        val iteratorId = message.getKeys.getIteratorId
+        var iteratorOption = iterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.keys())
+          iterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any key.")
+          return

Review Comment:
   same



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -336,6 +345,127 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!mapStates.contains(stateName)) {
+      logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"Map state $stateName is not initialized.")
+      return
+    }
+    val mapStateInfo = mapStates(stateName)
+    message.getMethodCase match {
+      case MapStateCall.MethodCase.EXISTS =>
+        if (mapStateInfo.mapState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case MapStateCall.MethodCase.GETVALUE =>
+        val keyBytes = message.getGetValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val value = mapStateInfo.mapState.getValue(keyRow)
+        if (value != null) {
+          val valueBytes = PythonSQLUtils.toPyRow(value)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
+        } else {
+          logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't contain" +
+            log" key ${MDC(LogKeys.KEY, keyRow.toString)}.")
+          sendResponse(0)
+        }
+      case MapStateCall.MethodCase.CONTAINSKEY =>
+        val keyBytes = message.getContainsKey.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        if (mapStateInfo.mapState.containsKey(keyRow)) {
+          sendResponse(0)
+        } else {
+          sendResponse(2, s"Map state $stateName doesn't contain key ${keyRow.toString()}")
+        }
+      case MapStateCall.MethodCase.UPDATEVALUE =>
+        val keyBytes = message.getUpdateValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val valueBytes = message.getUpdateValue.getValue.toByteArray
+        val valueRow = PythonSQLUtils.toJVMRow(valueBytes, mapStateInfo.valueSchema,
+          mapStateInfo.valueDeserializer)
+        mapStateInfo.mapState.updateValue(keyRow, valueRow)
+        sendResponse(0)
+      case MapStateCall.MethodCase.ITERATOR =>
+        val iteratorId = message.getIterator.getIteratorId
+        var iteratorOption = keyValueIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.iterator())
+          keyValueIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any entry.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        val keyValueStateSchema: StructType = StructType(
+          Array(
+            // key row serialized as a byte array.
+            StructField("keyRow", BinaryType),
+            // value row serialized as a byte array.
+            StructField("valueRow", BinaryType)
+          )
+        )
+        sendIteratorAsArrowBatches(iteratorOption.get, keyValueStateSchema,
+          arrowStreamWriterForTest) {tuple =>
+          val keyBytes = PythonSQLUtils.toPyRow(tuple._1)
+          val valueBytes = PythonSQLUtils.toPyRow(tuple._2)
+          new GenericInternalRow(
+            Array[Any](
+              keyBytes,
+              valueBytes
+            )
+          )
+        }
+      case MapStateCall.MethodCase.KEYS =>
+        val iteratorId = message.getKeys.getIteratorId
+        var iteratorOption = iterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.keys())
+          iterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any key.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        sendIteratorAsArrowBatches(iteratorOption.get, mapStateInfo.keySchema,
+          arrowStreamWriterForTest) {data => mapStateInfo.keySerializer(data)}
+      case MapStateCall.MethodCase.VALUES =>
+        val iteratorId = message.getValues.getIteratorId
+        var iteratorOption = iterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.values())
+          iterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any value.")
+          return

Review Comment:
   same



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}

Review Comment:
   nit: shall we use a longer but self-describing name? For example, `key_value_pair_iterator_cursors` or so.
   
   Same with the one below, e.g. `key_or_value_iterator_cursors` or so. I'm even OK with using separate dicts for keys and values requests.
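Here is what these dicts do, for reference when picking names. Each one maps an iterator id to a fetched batch plus a cursor (the index of the next row to return), and it drops the entry once the batch is consumed, as `get_key_value_pair` below does. The minimal sketch below uses the suggested names and keeps separate dicts for keys and values. It assumes plain lists of rows in place of the Arrow-to-pandas batches, and a hypothetical `fetch_batch` callable in place of the proto round trip to the server:

```python
from typing import Callable, Dict, List, Optional, Tuple

Row = Tuple
# A fetched batch of rows plus the index of the next row to return.
Cursor = Tuple[List[Row], int]


class MapIteratorCursors:
    """Hypothetical sketch of client-side iterator cursors (not the PySpark code).

    `fetch_batch(kind, iterator_id)` stands in for the server round trip. It
    returns the next batch of rows, or None/empty once the server is exhausted.
    """

    def __init__(self, fetch_batch: Callable[[str, str], Optional[List[Row]]]) -> None:
        self._fetch_batch = fetch_batch
        self.key_value_pair_iterator_cursors: Dict[str, Cursor] = {}
        self.key_iterator_cursors: Dict[str, Cursor] = {}
        self.value_iterator_cursors: Dict[str, Cursor] = {}

    def _next(self, cursors: Dict[str, Cursor], kind: str, iterator_id: str) -> Row:
        if iterator_id in cursors:
            batch, index = cursors[iterator_id]
        else:
            fetched = self._fetch_batch(kind, iterator_id)
            if not fetched:
                raise StopIteration()
            batch, index = fetched, 0
        if index + 1 < len(batch):
            # Advance the cursor within the cached batch.
            cursors[iterator_id] = (batch, index + 1)
        else:
            # Last row of the batch: drop the cursor so the next call fetches again.
            cursors.pop(iterator_id, None)
        return batch[index]

    def next_key_value_pair(self, iterator_id: str) -> Row:
        return self._next(self.key_value_pair_iterator_cursors, "iterator", iterator_id)

    def next_key(self, iterator_id: str) -> Row:
        return self._next(self.key_iterator_cursors, "keys", iterator_id)

    def next_value(self, iterator_id: str) -> Row:
        return self._next(self.value_iterator_cursors, "values", iterator_id)
```

With separate dicts, the dict name alone shows which request kind owns each cursor. That is the readability gain this nit is after.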
   



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:

Review Comment:
   ditto (user_key)



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        get_value_call = stateMessage.GetValue(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            if len(response_message[2]) == 0:
+                return None
+            row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting value: {response_message[1]}")
+
+    def contains_key(self, state_name: str, key: Tuple) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        contains_key_call = stateMessage.ContainsKey(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, containsKey=contains_key_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when the given key doesn't exist in the map state.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error checking if map state contains key: {response_message[1]}"
+            )
+
+    def update_value(
+        self,
+        state_name: str,
+        key: Tuple,
+        value: Tuple,
+    ) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        key_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.user_key_schema, key
+        )
+        value_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.value_schema, value
+        )
+        update_value_call = stateMessage.UpdateValue(userKey=key_bytes, value=value_bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, updateValue=update_value_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error updating map state value: {response_message[1]}")
+
+    def get_key_value_pair(self, state_name: str, iterator_id: str) -> Tuple[Tuple, Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.key_value_dict:
+            # If the state is already in the dictionary, return the next row.
+            pandas_df, index = self.key_value_dict[iterator_id]
+        else:
+            # If the state is not in the dictionary, fetch the state from the server.
+            iterator_call = stateMessage.Iterator(iteratorId=iterator_id)
+            map_state_call = stateMessage.MapStateCall(stateName=state_name, iterator=iterator_call)
+            state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                data_batch = None
+                for batch in iterator:
+                    if data_batch is None:
+                        data_batch = batch
+                if data_batch is None:
+                    # TODO(SPARK-49233): Classify user facing errors.
+                    raise PySparkRuntimeError("Error getting map state entry.")
+                pandas_df = data_batch.to_pandas()
+                index = 0
+            else:
+                raise StopIteration()
+
+        new_index = index + 1
+        if new_index < len(pandas_df):
+            # Update the index in the dictionary.
+            self.key_value_dict[iterator_id] = (pandas_df, new_index)
+        else:
+            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            self.key_value_dict.pop(iterator_id, None)
+        key_row_bytes = pandas_df.iloc[index, 0]
+        value_row_bytes = pandas_df.iloc[index, 1]
+        key_row = self._stateful_processor_api_client._deserialize_from_bytes(key_row_bytes)
+        value_row = self._stateful_processor_api_client._deserialize_from_bytes(value_row_bytes)
+        return tuple(key_row), tuple(value_row)
+
+    def get_row(self, state_name: str, iterator_id: str, is_key: bool) -> Tuple:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.dict:
+            # If the state is already in the dictionary, return the next row.
+            pandas_df, index = self.dict[iterator_id]
+        else:
+            # If the state is not in the dictionary, fetch the state from the server.
+            if is_key:
+                keys_call = stateMessage.Keys(iteratorId=iterator_id)
+                map_state_call = stateMessage.MapStateCall(stateName=state_name, keys=keys_call)
+            else:
+                values_call = stateMessage.Values(iteratorId=iterator_id)
+                map_state_call = stateMessage.MapStateCall(stateName=state_name, values=values_call)
+            state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                data_batch = None

Review Comment:
   ditto



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        get_value_call = stateMessage.GetValue(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            if len(response_message[2]) == 0:
+                return None
+            row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting value: {response_message[1]}")
+
+    def contains_key(self, state_name: str, key: Tuple) -> bool:

Review Comment:
   ditto



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        get_value_call = stateMessage.GetValue(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            if len(response_message[2]) == 0:
+                return None
+            row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting value: {response_message[1]}")
+
+    def contains_key(self, state_name: str, key: Tuple) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        contains_key_call = stateMessage.ContainsKey(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, containsKey=contains_key_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when the given key doesn't exist in the map state.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error checking if map state contains key: {response_message[1]}"
+            )
+
+    def update_value(
+        self,
+        state_name: str,
+        key: Tuple,
+        value: Tuple,
+    ) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        key_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.user_key_schema, key
+        )
+        value_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.value_schema, value
+        )
+        update_value_call = stateMessage.UpdateValue(userKey=key_bytes, value=value_bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, updateValue=update_value_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error updating map state value: {response_message[1]}")
+
+    def get_key_value_pair(self, state_name: str, iterator_id: str) -> Tuple[Tuple, Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.key_value_dict:
+            # If the state is already in the dictionary, return the next row.
+            pandas_df, index = self.key_value_dict[iterator_id]
+        else:
+            # If the state is not in the dictionary, fetch the state from the server.
+            iterator_call = stateMessage.Iterator(iteratorId=iterator_id)
+            map_state_call = stateMessage.MapStateCall(stateName=state_name, iterator=iterator_call)
+            state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                data_batch = None

Review Comment:
   same here: let's leave a code comment as Jing and I suggested above.
   
   Btw, I see the approach you newly introduced is also used in list state. I'm OK with fixing list state as well, but please explicitly mention the change in the PR description.
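To make the paging behaviour easier to review, here is a minimal standalone sketch of the pattern `get_key_value_pair` and `get_row` appear to follow: cache one fetched batch per iterator id together with the index of the next row, evict the entry after its last row so the next call fetches a new batch, and stop once the server has nothing left. `PagedRowCursor` and `fake_fetch` are hypothetical names; `fake_fetch` stands in for the proto request plus Arrow read and returns `None` where the real client sees a non-zero status. Unlike the PR, an empty batch is simply treated as exhaustion here.

```python
from typing import Callable, Dict, List, Optional, Tuple

Row = Tuple
# Hypothetical stand-in for the proto request + Arrow read: returns the next
# batch of rows for an iterator id, or None once the server side is exhausted.
FetchBatch = Callable[[str], Optional[List[Row]]]


class PagedRowCursor:
    def __init__(self, fetch_batch: FetchBatch) -> None:
        self._fetch_batch = fetch_batch
        # iterator id -> (cached batch, index of the next row to return)
        self._cache: Dict[str, Tuple[List[Row], int]] = {}

    def next_row(self, iterator_id: str) -> Row:
        if iterator_id in self._cache:
            batch, index = self._cache[iterator_id]
        else:
            batch = self._fetch_batch(iterator_id)
            if not batch:
                raise StopIteration()
            index = 0
        if index + 1 < len(batch):
            # More rows remain in this batch: remember where to resume.
            self._cache[iterator_id] = (batch, index + 1)
        else:
            # Last row of the batch: evict it so the next call fetches a new batch.
            self._cache.pop(iterator_id, None)
        return batch[index]


# Usage with a fake server that serves two batches for "it-1".
pending = {"it-1": [[(1,), (2,)], [(3,)]]}


def fake_fetch(iterator_id: str) -> Optional[List[Row]]:
    batches = pending.get(iterator_id, [])
    return batches.pop(0) if batches else None


cursor = PagedRowCursor(fake_fetch)
rows: List[Row] = []
while True:
    try:
        rows.append(cursor.next_row("it-1"))
    except StopIteration:
        break
```

Evicting the cache entry at the end of each batch is what lets the same iterator id transparently request the next batch, so the client never needs to know how many batches the server will send.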



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -150,12 +172,12 @@ class TransformWithStateInPandasStateServer(
         val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
         // Reset the list state iterators for a new grouping key.
-        listStateIterators = new mutable.HashMap[String, Iterator[Row]]()
+        iterators = new mutable.HashMap[String, Iterator[Row]]()
         sendResponse(0)
       case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY =>
         ImplicitGroupingKeyTracker.removeImplicitKey()
         // Reset the list state iterators for a new grouping key.

Review Comment:
   ditto
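A small model of why the iterator map is reset on `SETIMPLICITKEY` and `REMOVEIMPLICITKEY`: server-side iterators are opened against the current grouping key, so they should not outlive a key change. This is a Python sketch under stated assumptions, not the Scala server: `ServerIteratorRegistry`, `on_grouping_key_change`, and the fixed `batch_size` are hypothetical, and returning `None` plays the role of status code 2.

```python
import itertools
from typing import Callable, Dict, Iterator, List, Optional, Tuple

Row = Tuple


class ServerIteratorRegistry:
    """Hypothetical model of a server keeping one live iterator per client iterator id."""

    def __init__(self, batch_size: int) -> None:
        self._batch_size = batch_size  # assumption: not taken from the PR
        self._iterators: Dict[str, Iterator[Row]] = {}

    def on_grouping_key_change(self) -> None:
        # Drop every open iterator: they were opened against the previous
        # grouping key and must not leak into the next one.
        self._iterators = {}

    def next_batch(
        self, iterator_id: str, open_iterator: Callable[[], Iterator[Row]]
    ) -> Optional[List[Row]]:
        it = self._iterators.get(iterator_id)
        if it is None:
            it = open_iterator()
            self._iterators[iterator_id] = it
        batch = list(itertools.islice(it, self._batch_size))
        # None plays the role of status code 2 ("no more entries").
        return batch or None


state = {"a": [(1,), (2,), (3,)], "b": [(10,)]}
current_key = "a"
registry = ServerIteratorRegistry(batch_size=2)
first = registry.next_batch("it-1", lambda: iter(state[current_key]))

current_key = "b"
registry.on_grouping_key_change()
# After the reset, the id no longer maps to the iterator opened under key "a".
after_reset = registry.next_batch("it-1", lambda: iter(state[current_key]))
```

Without the reset, the second call would resume the stale iterator and return `[(3,)]` from the previous grouping key.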



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -195,6 +217,15 @@ class TransformWithStateInPandasStateServer(
             None
         }
         initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.GETMAPSTATE =>
+        val stateName = message.getGetMapState.getStateName
+        val keySchema = message.getGetMapState.getSchema

Review Comment:
   nit: userKeySchema



##########
python/pyspark/sql/streaming/map_state_client.py:
##########
@@ -0,0 +1,295 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, Iterator, Union, cast, Tuple, Optional
+
+from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string
+from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+__all__ = ["MapStateClient"]
+
+
+class MapStateClient:
+    def __init__(
+        self,
+        stateful_processor_api_client: StatefulProcessorApiClient,
+        user_key_schema: Union[StructType, str],
+        value_schema: Union[StructType, str],
+    ) -> None:
+        self._stateful_processor_api_client = stateful_processor_api_client
+        if isinstance(user_key_schema, str):
+            self.user_key_schema = cast(StructType, _parse_datatype_string(user_key_schema))
+        else:
+            self.user_key_schema = user_key_schema
+        if isinstance(value_schema, str):
+            self.value_schema = cast(StructType, _parse_datatype_string(value_schema))
+        else:
+            self.value_schema = value_schema
+        # Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
+        # and the index of the last row that was read.
+        self.key_value_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when state variable doesn't have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error checking map state exists: {response_message[1]}")
+
+    def get_value(self, state_name: str, key: Tuple) -> Optional[Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        get_value_call = stateMessage.GetValue(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, getValue=get_value_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            if len(response_message[2]) == 0:
+                return None
+            row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting value: {response_message[1]}")
+
+    def contains_key(self, state_name: str, key: Tuple) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        contains_key_call = stateMessage.ContainsKey(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, containsKey=contains_key_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Expect status code is 2 when the given key doesn't exist in the map state.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error checking if map state contains key: {response_message[1]}"
+            )
+
+    def update_value(
+        self,
+        state_name: str,
+        key: Tuple,
+        value: Tuple,
+    ) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        key_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.user_key_schema, key
+        )
+        value_bytes = self._stateful_processor_api_client._serialize_to_bytes(
+            self.value_schema, value
+        )
+        update_value_call = stateMessage.UpdateValue(userKey=key_bytes, value=value_bytes)
+        map_state_call = stateMessage.MapStateCall(
+            stateName=state_name, updateValue=update_value_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error updating map state value: {response_message[1]}")
+
+    def get_key_value_pair(self, state_name: str, iterator_id: str) -> Tuple[Tuple, Tuple]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.key_value_dict:
+            # If the state is already in the dictionary, return the next row.
+            pandas_df, index = self.key_value_dict[iterator_id]
+        else:
+            # If the state is not in the dictionary, fetch the state from the server.
+            iterator_call = stateMessage.Iterator(iteratorId=iterator_id)
+            map_state_call = stateMessage.MapStateCall(stateName=state_name, iterator=iterator_call)
+            state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                data_batch = None
+                for batch in iterator:
+                    if data_batch is None:
+                        data_batch = batch
+                if data_batch is None:
+                    # TODO(SPARK-49233): Classify user facing errors.
+                    raise PySparkRuntimeError("Error getting map state entry.")
+                pandas_df = data_batch.to_pandas()
+                index = 0
+            else:
+                raise StopIteration()
+
+        new_index = index + 1
+        if new_index < len(pandas_df):
+            # Update the index in the dictionary.
+            self.key_value_dict[iterator_id] = (pandas_df, new_index)
+        else:
+            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            self.key_value_dict.pop(iterator_id, None)
+        key_row_bytes = pandas_df.iloc[index, 0]
+        value_row_bytes = pandas_df.iloc[index, 1]
+        key_row = self._stateful_processor_api_client._deserialize_from_bytes(key_row_bytes)
+        value_row = self._stateful_processor_api_client._deserialize_from_bytes(value_row_bytes)
+        return tuple(key_row), tuple(value_row)
+
+    def get_row(self, state_name: str, iterator_id: str, is_key: bool) -> Tuple:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.dict:
+            # If the state is already in the dictionary, return the next row.
+            pandas_df, index = self.dict[iterator_id]
+        else:
+            # If the state is not in the dictionary, fetch the state from the server.
+            if is_key:
+                keys_call = stateMessage.Keys(iteratorId=iterator_id)
+                map_state_call = stateMessage.MapStateCall(stateName=state_name, keys=keys_call)
+            else:
+                values_call = stateMessage.Values(iteratorId=iterator_id)
+                map_state_call = stateMessage.MapStateCall(stateName=state_name, values=values_call)
+            state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                data_batch = None
+                for batch in iterator:
+                    if data_batch is None:
+                        data_batch = batch
+                if data_batch is None:
+                    entry_name = "key"
+                    if not is_key:
+                        entry_name = "value"
+                    # TODO(SPARK-49233): Classify user facing errors.
+                    raise PySparkRuntimeError(f"Error getting map state {entry_name}.")
+                pandas_df = data_batch.to_pandas()
+                index = 0
+            else:
+                raise StopIteration()
+
+        new_index = index + 1
+        if new_index < len(pandas_df):
+            # Update the index in the dictionary.
+            self.dict[iterator_id] = (pandas_df, new_index)
+        else:
+            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            self.dict.pop(iterator_id, None)
+        pandas_row = pandas_df.iloc[index]
+        return tuple(pandas_row)
+
+    def remove_key(self, state_name: str, key: Tuple) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, key)
+        remove_key_call = stateMessage.RemoveKey(userKey=bytes)
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, removeKey=remove_key_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error removing key from map state: {response_message[1]}")
+
+    def clear(self, state_name: str) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        clear_call = stateMessage.Clear()
+        map_state_call = stateMessage.MapStateCall(stateName=state_name, clear=clear_call)
+        state_variable_request = stateMessage.StateVariableRequest(mapStateCall=map_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error clearing map state: " f"{response_message[1]}")
+
+
+class MapStateIterator:
+    def __init__(self, map_state_client: MapStateClient, state_name: str, is_key: bool):
+        self.map_state_client = map_state_client
+        self.state_name = state_name
+        # Generate a unique identifier for the iterator to make sure iterators from the same

Review Comment:
   Somehow I put my review comment about the stateful iterator problem in #47878, and I see you've already addressed the issue. Awesome!
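For readers following along, the reason each iterator needs its own id can be shown with a toy model. This is a minimal sketch, assuming a hypothetical `CursorStore` that tracks one read position per iterator id (standing in for the client-side cache plus the server-side iterator map); `StateIterator` is likewise hypothetical. The quoted `MapStateIterator.__init__` is cut off before the id is generated, so the `uuid4` call here is an assumption. Nested iteration over the same state only works when the ids differ.

```python
import uuid
from typing import Dict, List, Optional, Tuple

Row = Tuple


class CursorStore:
    """Hypothetical stand-in for the state client: one read position per iterator id."""

    def __init__(self, rows: List[Row]) -> None:
        self._rows = rows
        self._positions: Dict[str, int] = {}

    def next_row(self, iterator_id: str) -> Row:
        pos = self._positions.get(iterator_id, 0)
        if pos >= len(self._rows):
            raise StopIteration()
        self._positions[iterator_id] = pos + 1
        return self._rows[pos]


class StateIterator:
    def __init__(self, store: CursorStore, iterator_id: Optional[str] = None) -> None:
        self._store = store
        # By default every iterator gets a fresh id, so its read position is
        # independent of any other iterator over the same state.
        self._iterator_id = iterator_id or str(uuid.uuid4())

    def __iter__(self) -> "StateIterator":
        return self

    def __next__(self) -> Row:
        return self._store.next_row(self._iterator_id)


store = CursorStore([(1,), (2,), (3,)])
unique = [(a[0], b[0]) for a in StateIterator(store) for b in StateIterator(store)]

# Reusing one id (e.g. the state name) makes the inner loop consume the outer one.
shared_store = CursorStore([(1,), (2,), (3,)])
shared = [
    (a[0], b[0])
    for a in StateIterator(shared_store, "s")
    for b in StateIterator(shared_store, "s")
]
```

With distinct ids the nested loop yields all nine pairs; with a shared id the inner loop drains the outer cursor and only `[(1, 2), (1, 3)]` comes out.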



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -336,6 +345,127 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!mapStates.contains(stateName)) {
+      logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"Map state $stateName is not initialized.")
+      return
+    }
+    val mapStateInfo = mapStates(stateName)
+    message.getMethodCase match {
+      case MapStateCall.MethodCase.EXISTS =>
+        if (mapStateInfo.mapState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case MapStateCall.MethodCase.GETVALUE =>
+        val keyBytes = message.getGetValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val value = mapStateInfo.mapState.getValue(keyRow)
+        if (value != null) {
+          val valueBytes = PythonSQLUtils.toPyRow(value)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
+        } else {
+          logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't contain" +
+            log" key ${MDC(LogKeys.KEY, keyRow.toString)}.")
+          sendResponse(0)
+        }
+      case MapStateCall.MethodCase.CONTAINSKEY =>
+        val keyBytes = message.getContainsKey.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        if (mapStateInfo.mapState.containsKey(keyRow)) {
+          sendResponse(0)
+        } else {
+          sendResponse(2, s"Map state $stateName doesn't contain key ${keyRow.toString()}")
+        }
+      case MapStateCall.MethodCase.UPDATEVALUE =>
+        val keyBytes = message.getUpdateValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val valueBytes = message.getUpdateValue.getValue.toByteArray
+        val valueRow = PythonSQLUtils.toJVMRow(valueBytes, mapStateInfo.valueSchema,
+          mapStateInfo.valueDeserializer)
+        mapStateInfo.mapState.updateValue(keyRow, valueRow)
+        sendResponse(0)
+      case MapStateCall.MethodCase.ITERATOR =>
+        val iteratorId = message.getIterator.getIteratorId
+        var iteratorOption = keyValueIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.iterator())
+          keyValueIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any entry.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        val keyValueStateSchema: StructType = StructType(

Review Comment:
   nit: Let's move this into the `else` branch, as commented above.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -336,6 +345,127 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!mapStates.contains(stateName)) {
+      logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"Map state $stateName is not initialized.")
+      return
+    }
+    val mapStateInfo = mapStates(stateName)
+    message.getMethodCase match {
+      case MapStateCall.MethodCase.EXISTS =>
+        if (mapStateInfo.mapState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the map state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case MapStateCall.MethodCase.GETVALUE =>
+        val keyBytes = message.getGetValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val value = mapStateInfo.mapState.getValue(keyRow)
+        if (value != null) {
+          val valueBytes = PythonSQLUtils.toPyRow(value)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
+        } else {
+          logWarning(log"Map state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't contain" +
+            log" key ${MDC(LogKeys.KEY, keyRow.toString)}.")
+          sendResponse(0)
+        }
+      case MapStateCall.MethodCase.CONTAINSKEY =>
+        val keyBytes = message.getContainsKey.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        if (mapStateInfo.mapState.containsKey(keyRow)) {
+          sendResponse(0)
+        } else {
+          sendResponse(2, s"Map state $stateName doesn't contain key ${keyRow.toString()}")
+        }
+      case MapStateCall.MethodCase.UPDATEVALUE =>
+        val keyBytes = message.getUpdateValue.getUserKey.toByteArray
+        val keyRow = PythonSQLUtils.toJVMRow(keyBytes, mapStateInfo.keySchema,
+          mapStateInfo.keyDeserializer)
+        val valueBytes = message.getUpdateValue.getValue.toByteArray
+        val valueRow = PythonSQLUtils.toJVMRow(valueBytes, mapStateInfo.valueSchema,
+          mapStateInfo.valueDeserializer)
+        mapStateInfo.mapState.updateValue(keyRow, valueRow)
+        sendResponse(0)
+      case MapStateCall.MethodCase.ITERATOR =>
+        val iteratorId = message.getIterator.getIteratorId
+        var iteratorOption = keyValueIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(mapStateInfo.mapState.iterator())
+          keyValueIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"Map state $stateName doesn't contain any entry.")
+          return

Review Comment:
   nit: Let's avoid using `return` until the linter can aggressively block non-local returns. One more level of indentation for the remaining code should be OK.
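   For reference, the request/response convention `handleMapStateRequest` follows (status 0 for success, 1 for an uninitialized state variable, 2 for "no value" or "no entry") can be modeled in a few lines of Python. This is an illustrative sketch only: `MapStateServerModel` and `handle` are hypothetical, a single entry stands in for an Arrow batch, and the `iterator` branch keeps its follow-up work in an explicit `else` instead of returning early, as the two nits above suggest.

```python
from typing import Any, Dict, Iterator, Optional, Tuple

# (status, message, payload); status 0 = OK, 1 = error, 2 = no value / no entry.
Response = Tuple[int, Optional[str], Any]


class MapStateServerModel:
    """Hypothetical in-memory model of the map state request handler."""

    def __init__(self) -> None:
        self.map_states: Dict[str, Dict[Any, Any]] = {}
        self.iterators: Dict[str, Iterator[Tuple[Any, Any]]] = {}

    def handle(self, state_name: str, method: str, *args: Any) -> Response:
        if state_name not in self.map_states:
            return (1, f"Map state {state_name} is not initialized.", None)
        state = self.map_states[state_name]
        response: Response
        if method == "exists":
            if state:
                response = (0, None, None)
            else:
                response = (2, f"state {state_name} doesn't exist", None)
        elif method == "get_value":
            # A missing key still yields status 0, just without a payload.
            response = (0, None, state.get(args[0]))
        elif method == "contains_key":
            if args[0] in state:
                response = (0, None, None)
            else:
                response = (2, f"Map state {state_name} doesn't contain key {args[0]}", None)
        elif method == "update_value":
            state[args[0]] = args[1]
            response = (0, None, None)
        elif method == "iterator":
            iterator_id = args[0]
            it = self.iterators.get(iterator_id)
            if it is None:
                # In this model, snapshot the entries so later updates don't disturb it.
                it = iter(list(state.items()))
                self.iterators[iterator_id] = it
            entry = next(it, None)
            if entry is None:
                response = (2, f"Map state {state_name} doesn't contain any entry.", None)
            else:
                # Work that follows a successful check stays in this branch (no early return).
                response = (0, None, entry)
        else:
            # Hypothetical fallback for an unrecognized method.
            response = (1, f"Unknown method: {method}", None)
        return response
```

   Assigning a single `response` per branch keeps every path visible in one `if`/`elif` chain, which is the shape the reviewer is asking for in the Scala `match`.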



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -389,6 +520,63 @@ class TransformWithStateInPandasStateServer(
         } else {
           sendResponse(1, s"List state $stateName already exists")
         }
+        case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
+          val valueSchema = StructType.fromString(mapStateValueSchemaString)
+          val valueExpressionEncoder = ExpressionEncoder(valueSchema).resolveAndBind()
+          val state = if (ttlDurationMs.isEmpty) {
+            statefulProcessorHandle.getMapState[Row, Row](stateName,
+              Encoders.row(schema), Encoders.row(valueSchema))
+          } else {
+            statefulProcessorHandle.getMapState[Row, Row](stateName, Encoders.row(schema),
+              Encoders.row(valueSchema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+          }
+          mapStates.put(stateName,
+            MapStateInfo(state, schema, valueSchema, expressionEncoder.createDeserializer(),
+              expressionEncoder.createSerializer(), valueExpressionEncoder.createDeserializer(),
+              valueExpressionEncoder.createSerializer()))
+          sendResponse(0)
+        } else {
+          sendResponse(1, s"Map state $stateName already exists")
+        }
+    }
+  }
+
+  private def sendIteratorAsArrowBatches[T](

Review Comment:
   I see this is also added in #47878, but with a different implementation. Do you plan to rebase this PR and apply the new implementation to timer as well once #47878 is merged?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -281,43 +314,19 @@ class TransformWithStateInPandasStateServer(
         sendResponse(0)
       case ListStateCall.MethodCase.LISTSTATEGET =>
         val iteratorId = message.getListStateGet.getIteratorId
-        var iteratorOption = listStateIterators.get(iteratorId)
+        var iteratorOption = iterators.get(iteratorId)
         if (iteratorOption.isEmpty) {
           iteratorOption = Some(listStateInfo.listState.get())
-          listStateIterators.put(iteratorId, iteratorOption.get)
+          iterators.put(iteratorId, iteratorOption.get)
         }
         if (!iteratorOption.get.hasNext) {
           sendResponse(2, s"List state $stateName doesn't contain any value.")
           return
         } else {
           sendResponse(0)
         }
-        outputStream.flush()
-        val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
-          arrowStreamWriterForTest
-        } else {
-          val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId,
-            errorOnDuplicatedFieldNames, largeVarTypes)
-          val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-          s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
-          val root = VectorSchemaRoot.create(arrowSchema, allocator)
-          new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream),
-            arrowTransformWithStateInPandasMaxRecordsPerBatch)
-        }
-        val listRowSerializer = listStateInfo.serializer
-        // Only write a single batch in each GET request. Stops writing row if rowCount reaches
-        // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
-        // when there are multiple state variables, user tries to access a different state variable
-        // while the current state variable is not exhausted yet.
-        var rowCount = 0
-        while (iteratorOption.get.hasNext &&
-          rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
-          val row = iteratorOption.get.next()
-          val internalRow = listRowSerializer(row)
-          arrowStreamWriter.writeRow(internalRow)
-          rowCount += 1
-        }
-        arrowStreamWriter.finalizeCurrentArrowBatch()
+        sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
+          arrowStreamWriterForTest) {data => listStateInfo.serializer(data)}

Review Comment:
   nit: `{ data =>` (add a space after `{`)
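   For context, the inline loop removed in this hunk wrote at most `arrowTransformWithStateInPandasMaxRecordsPerBatch` rows per GET request and left the iterator open for the next request, so a user can switch to a different state variable before the current one is exhausted. A minimal Python sketch of that idea, using `itertools.islice` on plain Python iterators in place of Arrow batches (`take_one_batch` and `max_records` are hypothetical names, not part of this PR):

```python
from itertools import islice
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def take_one_batch(rows: Iterator[T], max_records: int) -> List[T]:
    # Pull at most max_records rows; the iterator keeps its position, so the
    # next request for the same state variable resumes where this one stopped.
    return list(islice(rows, max_records))


# Two state variables read alternately, as a user's handler might do.
list_state = iter(range(5))
map_state = iter([("a", 1), ("b", 2), ("c", 3)])
print(take_one_batch(list_state, 2))  # [0, 1]
print(take_one_batch(map_state, 2))   # [('a', 1), ('b', 2)]
print(take_one_batch(list_state, 2))  # [2, 3]
print(take_one_batch(map_state, 2))   # [('c', 3)]
```

   Bounding each response, rather than streaming the whole iterator, is what lets the two reads interleave without one state variable blocking the other.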



##########
python/pyspark/sql/streaming/list_state_client.py:
##########
@@ -78,8 +78,11 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
             status = response_message[0]
             if status == 0:
                 iterator = self._stateful_processor_api_client._read_arrow_state()
-                batch = next(iterator)
-                pandas_df = batch.to_pandas()
+                data_batch = None

Review Comment:
   So we iterate over (consume) all batches when there is more than one, but only take the first batch? Please leave a code comment as @jingz-db suggested; otherwise, people reading the code may think it's a bug.
   
   Also, when does the iterator have multiple batches, and why is it safe to ignore the rest and take only the first one?
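   One way to make the assumption behind this question explicit, rather than silently discarding data, is to drain the stream and fail loudly if more than one batch arrives. A pure-Python sketch, assuming (per the removed server-side comment in the hunk above) that the server writes a single batch per GET request; `read_single_batch` is hypothetical, and plain lists of tuples stand in for Arrow record batches:

```python
from typing import Iterable, List, Optional, Sequence, Tuple

Row = Tuple[object, ...]


def read_single_batch(batches: Iterable[Sequence[Row]]) -> Optional[List[Row]]:
    """Consume every batch in the stream, but return only the first one.

    Draining the whole stream, rather than calling next() once, leaves nothing
    unread behind the data; this presumably matters if the same socket is
    reused for the next request (an assumption, not confirmed by the PR).
    A second batch is treated as a protocol mismatch instead of being dropped.
    """
    first: Optional[List[Row]] = None
    count = 0
    for batch in batches:
        count += 1
        if first is None:
            first = list(batch)
    if count > 1:
        raise RuntimeError(f"Expected at most one batch per request, got {count}")
    return first
```

   Whether to raise or to merge extra batches is a design choice; raising surfaces a broken invariant immediately, while silently keeping only the first batch would hide it.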



##########
python/pyspark/sql/streaming/list_state_client.py:
##########
@@ -78,8 +78,11 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
             status = response_message[0]
             if status == 0:
                 iterator = self._stateful_processor_api_client._read_arrow_state()
-                batch = next(iterator)
-                pandas_df = batch.to_pandas()
+                data_batch = None

Review Comment:
   I'd like to see this fix in a separate PR, with a relevant test that fails on the master branch and passes with the fix. Let's scope this PR properly: it aims to add MapState with TTL.



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -399,21 +431,30 @@ def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
             iter = self.ttl_list_state.get()
             for s in iter:
                 ttl_list_state_count += s[0]
+        if self.ttl_map_state.exists():
+            ttl_map_state_count = self.ttl_map_state.get_value(key)[0]

Review Comment:
   Just to double-confirm: the key here doesn't necessarily need to be the same as the grouping key, right? It's just to simplify the test.
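   To make the distinction in this question concrete: map state is scoped to the current grouping key (which the server sets implicitly, as the `ImplicitGroupingKeyTracker.setImplicitKey` call earlier in the review shows), and within that scope the map's own keys are arbitrary user keys. A toy in-memory model follows; `GroupScopedMapState` and its methods are hypothetical and are not the PySpark `MapState` API.

```python
from typing import Any, Dict, Hashable, Optional


class GroupScopedMapState:
    """Hypothetical model: one user-key -> value map per grouping key."""

    def __init__(self) -> None:
        self._store: Dict[Optional[Hashable], Dict[Hashable, Any]] = {}
        self._current_group: Optional[Hashable] = None

    def set_grouping_key(self, grouping_key: Hashable) -> None:
        # Stands in for the engine setting the implicit key before calling the handler.
        self._current_group = grouping_key

    def _scope(self) -> Dict[Hashable, Any]:
        return self._store.setdefault(self._current_group, {})

    def update_value(self, user_key: Hashable, value: Any) -> None:
        self._scope()[user_key] = value

    def get_value(self, user_key: Hashable) -> Any:
        return self._scope().get(user_key)


state = GroupScopedMapState()
state.set_grouping_key("group-1")
state.update_value("clicks", 3)   # user key unrelated to the grouping key
state.update_value("group-1", 7)  # reusing the grouping key as a user key also works
state.set_grouping_key("group-2")
print(state.get_value("clicks"))  # None: each grouping key has its own map
```

   Under this model, using the grouping key as the user key (as the test does) is just one convenient choice; any key matching the declared user-key schema would do.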



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
