HeartSaVioR commented on code in PR #47133:
URL: https://github.com/apache/spark/pull/47133#discussion_r1716823332


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -166,73 +170,69 @@ class TransformWithStateInPandasStateServer(
 
   private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = {
     val stateName = message.getStateName
+    if (!valueStates.contains(stateName)) {
+      logWarning(log"Value state ${MDC(LogKeys.STATE_NAME, stateName)} is not 
initialized.")
+      sendResponse(1, s"Value state $stateName is not initialized.")
+      return
+    }
     message.getMethodCase match {
       case ValueStateCall.MethodCase.EXISTS =>
-        if (valueStates.contains(stateName) && 
valueStates(stateName).exists()) {
+        if (valueStates(stateName)._1.exists()) {
           sendResponse(0)
         } else {
           sendResponse(1, s"state $stateName doesn't exist")

Review Comment:
   How do we distinguish the case of "no value state is defined for the state 
variable name" vs "the value state is defined but does not have a value yet" if 
we use the same status code? 



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -48,20 +50,19 @@ def exists(self) -> bool:
         """
         return self._value_state_client.exists(self._state_name)
 
-    def get(self) -> Any:
-        import pandas as pd
-
+    def get(self) -> Row:

Review Comment:
   nit: `Optional[Row]`?



##########
python/pyspark/sql/streaming/value_state_client.py:
##########
@@ -40,12 +40,15 @@ def exists(self, state_name: str) -> bool:
         status = response_message[0]
         if status == 0:
             return True
-        elif status == 1:
-            # server returns 1 if the state does not exist
+        elif status == 1 and "doesn't exist" in response_message[1]:
             return False
         else:
             raise PySparkRuntimeError(
-                f"Error checking value state exists: " f"{response_message[1]}"
+                errorClass="CALL_BEFORE_INITIALIZE",

Review Comment:
   ditto, explicitly define a dedicated error class



##########
python/pyspark/sql/streaming/value_state_client.py:
##########
@@ -95,4 +113,10 @@ def clear(self, state_name: str) -> None:
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
-            raise PySparkRuntimeError(f"Error clearing value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(

Review Comment:
   ditto, same error class as above



##########
python/pyspark/sql/streaming/value_state_client.py:
##########
@@ -40,12 +40,15 @@ def exists(self, state_name: str) -> bool:
         status = response_message[0]
         if status == 0:
             return True
-        elif status == 1:
-            # server returns 1 if the state does not exist
+        elif status == 1 and "doesn't exist" in response_message[1]:

Review Comment:
   I'd recommend using a different status code instead of parsing the message. 
Please consider changes that rely on string matching/hardcoded text to be 
unacceptable except for specific needs. 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -166,73 +170,69 @@ class TransformWithStateInPandasStateServer(
 
   private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = {
     val stateName = message.getStateName
+    if (!valueStates.contains(stateName)) {
+      logWarning(log"Value state ${MDC(LogKeys.STATE_NAME, stateName)} is not 
initialized.")
+      sendResponse(1, s"Value state $stateName is not initialized.")
+      return
+    }
     message.getMethodCase match {
       case ValueStateCall.MethodCase.EXISTS =>
-        if (valueStates.contains(stateName) && 
valueStates(stateName).exists()) {
+        if (valueStates(stateName)._1.exists()) {
           sendResponse(0)
         } else {
           sendResponse(1, s"state $stateName doesn't exist")
         }
       case ValueStateCall.MethodCase.GET =>
-        if (valueStates.contains(stateName)) {
-          val valueOption = valueStates(stateName).getOption()
-          if (valueOption.isDefined) {
-            sendResponse(0)
-            // Serialize the value row as a byte array
-            val valueBytes = PythonSQLUtils.toPyRow(valueOption.get)
-            outputStream.writeInt(valueBytes.length)
-            outputStream.write(valueBytes)
-          } else {
-            logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-            sendResponse(1, s"state $stateName doesn't exist")
-          }
+        val valueOption = valueStates(stateName)._1.getOption()
+        if (valueOption.isDefined) {
+          // Serialize the value row as a byte array
+          val valueBytes = PythonSQLUtils.toPyRow(valueOption.get)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
         } else {
-          logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-          sendResponse(1, s"state $stateName doesn't exist")
+          logWarning(log"Value state ${MDC(LogKeys.STATE_NAME, stateName)} 
doesn't contain" +
+            log" a value.")
+          sendResponse(0)
         }
       case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
         val byteArray = message.getValueStateUpdate.getValue.toByteArray
-        val schema = 
StructType.fromString(message.getValueStateUpdate.getSchema)
+        val valueStateTuple = valueStates(stateName)
         // The value row is serialized as a byte array, we need to convert it 
back to a Row
-        val valueRow = PythonSQLUtils.toJVMRow(byteArray, schema,
-          ExpressionEncoder(schema).resolveAndBind().createDeserializer())
-        if (valueStates.contains(stateName)) {
-          valueStates(stateName).update(valueRow)
-          sendResponse(0)
-        } else {
-          logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-          sendResponse(1, s"state $stateName doesn't exist")
-        }
+        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, 
valueStateTuple._3)
+        valueStates(stateName)._1.update(valueRow)

Review Comment:
   nit: valueStateTuple



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -115,7 +117,13 @@ def get_value_state(self, state_name: str, schema: 
Union[StructType, str]) -> No
         response_message = self._receive_proto_message()
         status = response_message[0]
         if status != 0:
-            raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(

Review Comment:
   Shall we give a better error class, since it's a user-facing error? You can 
revert this and file a JIRA ticket for it as well to defer the change.



##########
python/pyspark/sql/streaming/value_state_client.py:
##########
@@ -60,17 +63,26 @@ def get(self, state_name: str) -> Any:
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status == 0:
-            return 
self._stateful_processor_api_client._receive_and_deserialize()
+            if len(response_message[2]) == 0:
+                return None
+            row = 
self._stateful_processor_api_client._deserialize_from_bytes(response_message[2])
+            return row
         else:
-            raise PySparkRuntimeError(f"Error getting value state: 
{response_message[1]}")
+            raise PySparkRuntimeError(

Review Comment:
   ditto, probably the same error class as above



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -166,73 +170,69 @@ class TransformWithStateInPandasStateServer(
 
   private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = {
     val stateName = message.getStateName
+    if (!valueStates.contains(stateName)) {
+      logWarning(log"Value state ${MDC(LogKeys.STATE_NAME, stateName)} is not 
initialized.")
+      sendResponse(1, s"Value state $stateName is not initialized.")
+      return
+    }
     message.getMethodCase match {
       case ValueStateCall.MethodCase.EXISTS =>
-        if (valueStates.contains(stateName) && 
valueStates(stateName).exists()) {
+        if (valueStates(stateName)._1.exists()) {
           sendResponse(0)
         } else {
           sendResponse(1, s"state $stateName doesn't exist")
         }
       case ValueStateCall.MethodCase.GET =>
-        if (valueStates.contains(stateName)) {
-          val valueOption = valueStates(stateName).getOption()
-          if (valueOption.isDefined) {
-            sendResponse(0)
-            // Serialize the value row as a byte array
-            val valueBytes = PythonSQLUtils.toPyRow(valueOption.get)
-            outputStream.writeInt(valueBytes.length)
-            outputStream.write(valueBytes)
-          } else {
-            logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-            sendResponse(1, s"state $stateName doesn't exist")
-          }
+        val valueOption = valueStates(stateName)._1.getOption()
+        if (valueOption.isDefined) {
+          // Serialize the value row as a byte array
+          val valueBytes = PythonSQLUtils.toPyRow(valueOption.get)
+          val byteString = ByteString.copyFrom(valueBytes)
+          sendResponse(0, null, byteString)
         } else {
-          logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-          sendResponse(1, s"state $stateName doesn't exist")
+          logWarning(log"Value state ${MDC(LogKeys.STATE_NAME, stateName)} 
doesn't contain" +
+            log" a value.")
+          sendResponse(0)
         }
       case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
         val byteArray = message.getValueStateUpdate.getValue.toByteArray
-        val schema = 
StructType.fromString(message.getValueStateUpdate.getSchema)
+        val valueStateTuple = valueStates(stateName)
         // The value row is serialized as a byte array, we need to convert it 
back to a Row
-        val valueRow = PythonSQLUtils.toJVMRow(byteArray, schema,
-          ExpressionEncoder(schema).resolveAndBind().createDeserializer())
-        if (valueStates.contains(stateName)) {
-          valueStates(stateName).update(valueRow)
-          sendResponse(0)
-        } else {
-          logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-          sendResponse(1, s"state $stateName doesn't exist")
-        }
+        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, 
valueStateTuple._3)
+        valueStates(stateName)._1.update(valueRow)
+        sendResponse(0)
       case ValueStateCall.MethodCase.CLEAR =>
-        if (valueStates.contains(stateName)) {
-          valueStates(stateName).clear()
-          sendResponse(0)
-        } else {
-          logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't 
exist")
-          sendResponse(1, s"state $stateName doesn't exist")
-        }
+        valueStates(stateName)._1.clear()
+        sendResponse(0)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
-  private def sendResponse(status: Int, errorMessage: String = null): Unit = {
+  private def sendResponse(
+    status: Int,

Review Comment:
   nit: 2 more spaces (while we are here)



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -115,7 +117,13 @@ def get_value_state(self, state_name: str, schema: 
Union[StructType, str]) -> No
         response_message = self._receive_proto_message()
         status = response_message[0]
         if status != 0:
-            raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(

Review Comment:
   I'd expect a dedicated error class here. If the Scala version of the 
implementation uses an error class, then use the same one; otherwise, define a 
new one.



##########
python/pyspark/sql/streaming/value_state_client.py:
##########
@@ -81,7 +93,13 @@ def update(self, state_name: str, schema: Union[StructType, 
str], value: Tuple)
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
-            raise PySparkRuntimeError(f"Error updating value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(

Review Comment:
   ditto, same error class as above



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to