bogao007 commented on code in PR #47133:
URL: https://github.com/apache/spark/pull/47133#discussion_r1685163048


##########
sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto:
##########
@@ -0,0 +1,86 @@
+syntax = "proto3";
+
+package org.apache.spark.sql.execution.streaming.state;

Review Comment:
   Yeah, we can keep a single file; I removed the one under python.



##########
python/pyspark/sql/streaming/state_api_client.py:
##########
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+import os
+import socket
+from typing import Any, Union, cast
+
+import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+from pyspark.serializers import write_int, read_int, UTF8Deserializer
+from pyspark.sql.types import StructType, _parse_datatype_string
+
+
+class StatefulProcessorHandleState(Enum):
+    CREATED = 1
+    INITIALIZED = 2
+    DATA_PROCESSED = 3
+    CLOSED = 4
+
+
+class StateApiClient:
+    def __init__(
+            self,
+            state_server_port: int) -> None:
+        self._client_socket = socket.socket()
+        self._client_socket.connect(("localhost", state_server_port))
+        self.sockfile = self._client_socket.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE",
+                                                                               
65536)))
+        print(f"client is ready - connection established")
+        self.handle_state = StatefulProcessorHandleState.CREATED
+        self.utf8_deserializer = UTF8Deserializer()
+        # place holder, will remove when actual implementation is done
+        # self.setHandleState(StatefulProcessorHandleState.CLOSED)
+
+    def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
+        print(f"setting handle state to: {state}")
+        proto_state = self._get_proto_state(state)
+        set_handle_state = stateMessage.SetHandleState(state=proto_state)
+        handle_call = 
stateMessage.StatefulProcessorCall(setHandleState=set_handle_state)
+        message = stateMessage.StateRequest(statefulProcessorCall=handle_call)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+
+        if (status == 0):
+            self.handle_state = state
+        print(f"setHandleState status= {status}")
+
+    def get_value_state(self, state_name: str, schema: Union[StructType, str]) 
-> None:
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
+
+        print(f"initializing value state: {state_name}")
+
+        state_call_command = stateMessage.StateCallCommand()
+        state_call_command.stateName = state_name
+        state_call_command.schema = schema.json()
+        call = 
stateMessage.StatefulProcessorCall(getValueState=state_call_command)
+
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"getValueState status= {status}")
+
+    def value_state_exists(self, state_name: str) -> bool:
+        print(f"checking value state exists: {state_name}")
+        exists_call = stateMessage.Exists(stateName=state_name)
+        value_state_call = stateMessage.ValueStateCall(exists=exists_call)
+        state_variable_request = 
stateMessage.StateVariableRequest(valueStateCall=value_state_call)
+        message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"valueStateExists status= {status}")
+        if (status == 0):
+            return True
+        else:
+            return False
+
+    def value_state_get(self, state_name: str) -> Any:
+        print(f"getting value state: {state_name}")
+        get_call = stateMessage.Get(stateName=state_name)
+        value_state_call = stateMessage.ValueStateCall(get=get_call)
+        state_variable_request = 
stateMessage.StateVariableRequest(valueStateCall=value_state_call)
+        message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"valueStateGet status= {status}")
+        if (status == 0):
+            return self.utf8_deserializer.loads(self.sockfile)
+        else:
+            return None
+
+    def value_state_update(self, state_name: str, schema: Union[StructType, 
str], value: str) -> None:
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
+        print(f"updating value state: {state_name}")
+        byteStr = value.encode('utf-8')
+        update_call = stateMessage.Update(stateName=state_name, 
schema=schema.json(), value=byteStr)
+        value_state_call = stateMessage.ValueStateCall(update=update_call)
+        state_variable_request = 
stateMessage.StateVariableRequest(valueStateCall=value_state_call)
+        message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"valueStateUpdate status= {status}")
+
+    def value_state_clear(self, state_name: str) -> None:
+        print(f"clearing value state: {state_name}")
+        clear_call = stateMessage.Clear(stateName=state_name)
+        value_state_call = stateMessage.ValueStateCall(clear=clear_call)
+        state_variable_request = 
stateMessage.StateVariableRequest(valueStateCall=value_state_call)
+        message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"valueStateClear status= {status}")
+
+    def set_implicit_key(self, key: str) -> None:
+        print(f"setting implicit key: {key}")
+        set_implicit_key = stateMessage.SetImplicitKey(key=key)
+        request = 
stateMessage.ImplicitGroupingKeyRequest(setImplicitKey=set_implicit_key)
+        message = stateMessage.StateRequest(implicitGroupingKeyRequest=request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"setImplicitKey status= {status}")
+
+    def remove_implicit_key(self) -> None:
+        print(f"removing implicit key")
+        remove_implicit_key = stateMessage.RemoveImplicitKey()
+        request = 
stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key)
+        message = stateMessage.StateRequest(implicitGroupingKeyRequest=request)
+
+        self._send_proto_message(message)
+        status = read_int(self.sockfile)
+        print(f"removeImplicitKey status= {status}")
+
+    def _get_proto_state(self,
+                         state: StatefulProcessorHandleState) -> 
stateMessage.HandleState.ValueType:
+        if (state == StatefulProcessorHandleState.CREATED):
+            return stateMessage.CREATED
+        elif (state == StatefulProcessorHandleState.INITIALIZED):
+            return stateMessage.INITIALIZED
+        elif (state == StatefulProcessorHandleState.DATA_PROCESSED):
+            return stateMessage.DATA_PROCESSED
+        else:
+            return stateMessage.CLOSED
+
+    def _send_proto_message(self, message: stateMessage.StateRequest) -> None:
+        serialized_msg = message.SerializeToString()
+        print(f"sending message -- len = {len(serialized_msg)} 
{str(serialized_msg)}")
+        write_int(0, self.sockfile)

Review Comment:
   Added



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -358,6 +362,141 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)
 
+
+    def transformWithStateInPandas(self, 

Review Comment:
   As discussed with @HyukjinKwon, we will keep the name 
`transformWithStateInPandas`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.DataOutputStream
+import java.nio.file.{Files, Path}
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.ExecutionContext
+
+import jnr.unixsocket.UnixServerSocketChannel
+import jnr.unixsocket.UnixSocketAddress
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
+import 
org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType,
 OutType}
+import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Python runner implementation for TransformWithStateInPandas.
+ */
+class TransformWithStateInPandasPythonRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    _schema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType)
+  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, 
argOffsets, jobArtifactUUID)
+    with PythonArrowInput[InType]
+    with BasicPythonArrowOutput
+    with Logging {
+
+  private val sqlConf = SQLConf.get
+  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+
+  private val serverId = 
TransformWithStateInPandasStateServer.allocateServerId()
+
+  private val socketPath = s"./uds_$serverId.sock"
+
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString)
+
+  private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+
+  // Use lazy val to initialize the fields before these are accessed in 
[[PythonArrowInput]]'s
+  // constructor.
+  override protected lazy val schema: StructType = _schema
+  override protected lazy val timeZoneId: String = _timeZoneId
+  override protected val errorOnDuplicatedFieldNames: Boolean = true
+  override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
+
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): 
Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the port number for state server
+    stream.writeInt(serverId)
+  }
+
+  override def compute(
+      inputIterator: Iterator[InType],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[OutType] = {
+    var serverChannel: UnixServerSocketChannel = null
+    var failed = false
+    try {
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
+      val serverAddress = new UnixSocketAddress(socketPath)
+      serverChannel = UnixServerSocketChannel.open()
+      serverChannel.socket().bind(serverAddress)
+    } catch {
+      case e: Exception =>
+        failed = true
+        throw e
+    } finally {
+      if (failed) {
+        closeServerSocketChannelSilently(serverChannel)
+      }
+    }
+
+    val executor = 
ThreadUtils.newDaemonSingleThreadExecutor("stateConnectionListenerThread")
+    val executionContext = ExecutionContext.fromExecutor(executor)
+
+    executionContext.execute(
+      new TransformWithStateInPandasStateServer(serverChannel, processorHandle,
+        groupingKeySchema))
+
+    context.addTaskCompletionListener[Unit] { _ =>
+      logWarning(s"completion listener called")
+      executor.awaitTermination(10, TimeUnit.SECONDS)
+      executor.shutdownNow()
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
+    }
+
+    super.compute(inputIterator, partitionIndex, context)
+  }
+
+  private def closeServerSocketChannelSilently(serverChannel: 
UnixServerSocketChannel): Unit = {
+    try {
+      logWarning(s"closing the state server socket")
+      serverChannel.close()
+    } catch {
+      case e: Exception =>
+        logError(s"failed to close state server socket", e)
+    }
+  }
+
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
+  }
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val nextBatch = next._2
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }

Review Comment:
   As we discussed, this work will be captured in a follow-up PR.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, 
stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for 
transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(

Review Comment:
   We need a dedicated serializer like this for each operator.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
+
+import scala.collection.mutable
+
+import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Encoder, Encoders, Row}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, 
StateVariableRequest, ValueStateCall}
+import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, 
FloatType, IntegerType, LongType, StructType}
+
+/**
+ * This class is used to handle the state requests from the Python side.
+ */
+class TransformWithStateInPandasStateServer(
+    private val serverChannel: UnixServerSocketChannel,
+    private val statefulProcessorHandle: StatefulProcessorHandleImpl,
+    private val groupingKeySchema: StructType)
+  extends Runnable
+  with Logging{
+
+  private var inputStream: DataInputStream = _
+  private var outputStream: DataOutputStream = _
+
+  private val valueStates = mutable.HashMap[String, ValueState[Any]]()
+
+  def run(): Unit = {
+    logWarning(s"Waiting for connection from Python worker")
+    val channel = serverChannel.accept()
+    logWarning(s"listening on channel - ${channel.getLocalAddress}")
+
+    inputStream = new DataInputStream(
+      Channels.newInputStream(channel))
+    outputStream = new DataOutputStream(
+      Channels.newOutputStream(channel)
+    )
+
+    while (channel.isConnected &&
+      statefulProcessorHandle.getHandleState != 
StatefulProcessorHandleState.CLOSED) {
+
+      try {
+        logWarning(s"reading the version")
+        val version = inputStream.readInt()
+
+        if (version != -1) {
+          logWarning(s"version = ${version}")
+          assert(version == 0)
+          val messageLen = inputStream.readInt()
+          logWarning(s"parsing a message of ${messageLen} bytes")
+
+          val messageBytes = new Array[Byte](messageLen)
+          inputStream.read(messageBytes)
+          logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", 
")")}")
+
+          val message = 
StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
+
+          logWarning(s"read message = $message")
+          handleRequest(message)
+          logWarning(s"flush output stream")
+
+          outputStream.flush()
+        }
+      } catch {
+        case _: EOFException =>
+          logWarning(s"No more data to read from the socket")
+          
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+          return
+        case e: Exception =>
+          logWarning(s"Error reading message: ${e.getMessage}")
+          sendResponse(1, e.getMessage)
+          outputStream.flush()
+          
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+          return
+      }
+    }
+    logWarning(s"done from the state server thread")
+  }
+
+  private def handleRequest(message: StateRequest): Unit = {

Review Comment:
   Done.



##########
python/pyspark/sql/streaming/state_api_client.py:
##########
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+import os
+import socket
+from typing import Any, Union, cast
+
+import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+from pyspark.serializers import write_int, read_int, UTF8Deserializer
+from pyspark.sql.types import StructType, _parse_datatype_string
+
+
+class StatefulProcessorHandleState(Enum):
+    CREATED = 1
+    INITIALIZED = 2
+    DATA_PROCESSED = 3
+    CLOSED = 4
+
+

Review Comment:
   Yep



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -161,6 +161,41 @@ case class FlatMapGroupsInPandasWithState(
     newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = 
newChild)
 }
 
+object TransformWithStateInPandas {

Review Comment:
   Since we are keeping the name `TransformWithStateInPandas` on the Python side, 
we will use the same name here.



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -0,0 +1,152 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import shutil
+import string
+import sys
+import tempfile
+import pandas as pd
+from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
+from typing import Iterator
+
+import unittest
+from typing import cast
+
+from pyspark import SparkConf
+from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
+from pyspark.sql.types import (
+    LongType,
+    StringType,
+    StructType,
+    StructField,
+    Row,
+)
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import eventually
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class TransformWithStateInPandasTestsMixin:
+    @classmethod
+    def conf(cls):
+        cfg = SparkConf()
+        cfg.set("spark.sql.shuffle.partitions", "5")
+        cfg.set("spark.sql.streaming.stateStore.providerClass",
+                
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
+        return cfg
+
+    def _test_apply_in_pandas_with_state_basic(self, func, check_results):
+        input_path = tempfile.mkdtemp()
+
+        def prepare_test_resource():
+            with open(input_path + "/text-test.txt", "w") as fw:
+                fw.write("hello\n")
+                fw.write("this\n")
+
+        prepare_test_resource()
+
+        df = self.spark.readStream.format("text").load(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_type = StructType(
+            [StructField("key", StringType()), StructField("countAsString", 
StringType())]
+        )
+        state_type = StructType([StructField("c", LongType())])
+
+        q = (
+            df.groupBy(df["value"])
+            .transformWithStateInPandas(stateful_processor = 
SimpleStatefulProcessor(),
+                                        outputStructType=output_type,
+                                        outputMode="Update",
+                                        timeMode="None")
+            .writeStream.queryName("this_query")
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        self.assertTrue(q.exception() is None)
+
+    def test_apply_in_pandas_with_state_basic(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+
+            total_len = 0
+            for pdf in pdf_iter:
+                total_len += len(pdf)
+
+            state.update((total_len,))
+            assert state.get[0] == 1
+            yield pd.DataFrame({"key": [key[0]], "countAsString": 
[str(total_len)]})
+
+        def check_results(batch_df, _):
+            assert set(batch_df.sort("key").collect()) == {
+                Row(key="hello", countAsString="1"),
+                Row(key="this", countAsString="1"),
+            }
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+
+class SimpleStatefulProcessor(StatefulProcessor):
+  def init(self, handle: StatefulProcessorHandle) -> None:
+    state_schema = StructType([
+      StructField("value", StringType(), True)
+    ])
+    self.value_state = handle.getValueState("testValueState", state_schema)
+  def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    self.value_state.update("test_value")
+    exists = self.value_state.exists()
+    value = self.value_state.get()
+    self.value_state.clear()

Review Comment:
   Will add more tests here



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to