HeartSaVioR commented on code in PR #47133: URL: https://github.com/apache/spark/pull/47133#discussion_r1714675981
########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala: ########## @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream +import java.net.ServerSocket + +import scala.concurrent.ExecutionContext + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType} +import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ThreadUtils + +/** + * Python runner implementation for TransformWithStateInPandas. + */ +class TransformWithStateInPandasPythonRunner( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argOffsets: Array[Array[Int]], + _schema: StructType, + processorHandle: StatefulProcessorHandleImpl, + _timeZoneId: String, + initialWorkerConf: Map[String, String], + override val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + groupingKeySchema: StructType) + extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) + with PythonArrowInput[InType] + with BasicPythonArrowOutput + with Logging { + + private val sqlConf = SQLConf.get + private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + + private var stateSocketSocketPort: Int = 0 Review Comment: nit: Probably one of `Socket` should be `Server`? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the Review Comment: nit: `It run` `It processes`? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). + */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, Review Comment: nit: in many cases, having `private val` in constructor param in class is redundant. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). + */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest Review Comment: outputStreamForTest <= is this really used? We always assign the output stream from run() ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). + */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest + private val valueStates = if (valueStateMapForTest != null) { + valueStateMapForTest + } else { + new mutable.HashMap[String, ValueState[Row]]() + } + + def run(): Unit = { + val listeningSocket = stateServerSocket.accept() + inputStream = new DataInputStream( + new BufferedInputStream(listeningSocket.getInputStream)) + outputStream = new DataOutputStream( + new BufferedOutputStream(listeningSocket.getOutputStream) + ) + + while (listeningSocket.isConnected && + statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + try { + val version = inputStream.readInt() + if (version != -1) { + assert(version == 0) + val message = parseProtoMessage() + handleRequest(message) + outputStream.flush() + } + } catch { + case _: EOFException => + logWarning(log"No more data to read from the socket") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + case e: Exception => + logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) + sendResponse(1, e.getMessage) + outputStream.flush() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + } + } + logInfo(log"Done from the state server thread") + } + + private def parseProtoMessage(): StateRequest = { + val messageLen = inputStream.readInt() + val messageBytes = new Array[Byte](messageLen) + inputStream.read(messageBytes) + StateRequest.parseFrom(ByteString.copyFrom(messageBytes)) + } + + private def handleRequest(message: StateRequest): Unit = { + message.getMethodCase match { 
+ case StateRequest.MethodCase.IMPLICITGROUPINGKEYREQUEST => + handleImplicitGroupingKeyRequest(message.getImplicitGroupingKeyRequest) + case StateRequest.MethodCase.STATEFULPROCESSORCALL => + handleStatefulProcessorCall(message.getStatefulProcessorCall) + case StateRequest.MethodCase.STATEVARIABLEREQUEST => + handleStateVariableRequest(message.getStateVariableRequest) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = { + message.getMethodCase match { + case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY => + val keyBytes = message.getSetImplicitKey.getKey.toByteArray + // The key row is serialized as a byte array, we need to convert it back to a Row + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, + ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()) + ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + sendResponse(0) + case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => + ImplicitGroupingKeyTracker.removeImplicitKey() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleStatefulProcessorCall(message: StatefulProcessorCall): Unit = { + message.getMethodCase match { + case StatefulProcessorCall.MethodCase.SETHANDLESTATE => + val requestedState = message.getSetHandleState.getState + requestedState match { + case HandleState.CREATED => + logInfo(log"set handle state to Created") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CREATED) + case HandleState.INITIALIZED => + logInfo(log"set handle state to Initialized") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + case HandleState.CLOSED => + logInfo(log"set handle state to Closed") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + case _ => + } + sendResponse(0) + case StatefulProcessorCall.MethodCase.GETVALUESTATE => + val stateName = message.getGetValueState.getStateName + val schema = message.getGetValueState.getSchema + initializeValueState(stateName, schema) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleStateVariableRequest(message: StateVariableRequest): Unit = { + message.getMethodCase match { + case StateVariableRequest.MethodCase.VALUESTATECALL => + handleValueStateRequest(message.getValueStateCall) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = { + val stateName = message.getStateName + message.getMethodCase match { + case ValueStateCall.MethodCase.EXISTS => + if (valueStates.contains(stateName) && valueStates(stateName).exists()) { + sendResponse(0) + } else { + sendResponse(1, s"state $stateName doesn't exist") + } + case ValueStateCall.MethodCase.GET => + if (valueStates.contains(stateName)) { + val valueOption = valueStates(stateName).getOption() + if (valueOption.isDefined) { + sendResponse(0) + // Serialize the value row as a byte array + val valueBytes = PythonSQLUtils.toPyRow(valueOption.get) + outputStream.writeInt(valueBytes.length) + outputStream.write(valueBytes) + } else { + logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't exist") + sendResponse(1, s"state $stateName doesn't exist") + } + } else { + logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't exist") + sendResponse(1, s"state 
$stateName doesn't exist") + } + case ValueStateCall.MethodCase.VALUESTATEUPDATE => + val byteArray = message.getValueStateUpdate.getValue.toByteArray + val schema = StructType.fromString(message.getValueStateUpdate.getSchema) Review Comment: Any reason we allow schema to be presented at every update? What is the expectation of the behavior if the schema differs from the initialization vs updating value? ########## python/pyspark/sql/pandas/serializers.py: ########## @@ -1116,3 +1123,70 @@ def init_stream_yield_batches(batches): batches_to_write = init_stream_yield_batches(serialize_batches()) return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) + + +class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for + :meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`. + + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. + """ + + def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch): + super(TransformWithStateInPandasSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name + ) + self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.key_offsets = None + + def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and + convert the data into a list of pandas.Series. + + Please refer the doc of inner function `generate_data_batches` for more details how + this function works in overall. + """ + import pyarrow as pa + + def generate_data_batches(batches): + """ + Deserialize ArrowRecordBatches and return a generator of pandas.Series list. + + The deserialization logic assumes that Arrow RecordBatches contain the data with the + ordering that data chunks for same grouping key will appear sequentially. + + This function must avoid materializing multiple Arrow RecordBatches into memory at the + same time. And data chunks from the same grouping key should appear sequentially. + """ + for batch in batches: + data_pandas = [ + self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns() + ] + key_series = [data_pandas[o] for o in self.key_offsets] + batch_key = tuple(s[0] for s in key_series) + yield (batch_key, data_pandas) + + _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + data_batches = generate_data_batches(_batches) + + for k, g in groupby(data_batches, key=lambda x: x[0]): + yield (k, g) + + def dump_stream(self, iterator, stream): + """ + Read through an iterator of (iterator of pandas DataFram), serialize them to Arrow Review Comment: nit: DataFram`e` ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). 
+ */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest + private val valueStates = if (valueStateMapForTest != null) { + valueStateMapForTest + } else { + new mutable.HashMap[String, ValueState[Row]]() + } + + def run(): Unit = { + val listeningSocket = stateServerSocket.accept() + inputStream = new DataInputStream( + new BufferedInputStream(listeningSocket.getInputStream)) + outputStream = new DataOutputStream( + new BufferedOutputStream(listeningSocket.getOutputStream) + ) + + while (listeningSocket.isConnected && + statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + try { + val version = inputStream.readInt() + if (version != -1) { + assert(version == 0) + val message = parseProtoMessage() + handleRequest(message) + outputStream.flush() + } + } catch { + case _: EOFException => + logWarning(log"No more data to read from the socket") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + case e: Exception => + logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) + sendResponse(1, e.getMessage) + outputStream.flush() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + } + } + logInfo(log"Done from the state server thread") + } + + private def parseProtoMessage(): StateRequest = { + val messageLen = inputStream.readInt() + val messageBytes = new Array[Byte](messageLen) + inputStream.read(messageBytes) + StateRequest.parseFrom(ByteString.copyFrom(messageBytes)) + } + + private def handleRequest(message: StateRequest): Unit = { + message.getMethodCase match { + case StateRequest.MethodCase.IMPLICITGROUPINGKEYREQUEST => + handleImplicitGroupingKeyRequest(message.getImplicitGroupingKeyRequest) + case StateRequest.MethodCase.STATEFULPROCESSORCALL => + handleStatefulProcessorCall(message.getStatefulProcessorCall) + case StateRequest.MethodCase.STATEVARIABLEREQUEST => + handleStateVariableRequest(message.getStateVariableRequest) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = { + message.getMethodCase match { + case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY => + val keyBytes = message.getSetImplicitKey.getKey.toByteArray + // The key row is serialized as a byte array, we need to convert it back to a Row + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, + ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()) Review Comment: Can this be initialized only once at the initialization phase of server? 
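Something along these lines would work (just a sketch reusing the names already present in this diff; the field name `groupingKeyDeserializer` is illustrative, not part of the PR): since `groupingKeySchema` is a constructor parameter and does not change for the lifetime of the server, the deserializer can be built once instead of per SETIMPLICITKEY request.

// Sketch: create the grouping-key deserializer once per server instance instead of per request.
private val groupingKeyDeserializer =
  ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()

private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = {
  message.getMethodCase match {
    case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY =>
      val keyBytes = message.getSetImplicitKey.getKey.toByteArray
      // The key row is serialized as a byte array; reuse the prebuilt deserializer here.
      val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, groupingKeyDeserializer)
      ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
      sendResponse(0)
    case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY =>
      ImplicitGroupingKeyTracker.removeImplicitKey()
      sendResponse(0)
    case _ =>
      throw new IllegalArgumentException("Invalid method call")
  }
}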
########## python/pyspark/worker.py: ########## @@ -832,6 +849,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) + elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: + argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it Review Comment: doesn't seem to be used anywhere, blindly copied? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). 
+ */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest + private val valueStates = if (valueStateMapForTest != null) { + valueStateMapForTest + } else { + new mutable.HashMap[String, ValueState[Row]]() + } + + def run(): Unit = { + val listeningSocket = stateServerSocket.accept() + inputStream = new DataInputStream( + new BufferedInputStream(listeningSocket.getInputStream)) + outputStream = new DataOutputStream( + new BufferedOutputStream(listeningSocket.getOutputStream) + ) + + while (listeningSocket.isConnected && + statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + try { + val version = inputStream.readInt() + if (version != -1) { + assert(version == 0) + val message = parseProtoMessage() + handleRequest(message) + outputStream.flush() + } + } catch { + case _: EOFException => + logWarning(log"No more data to read from the socket") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + case e: Exception => + logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) + sendResponse(1, e.getMessage) + outputStream.flush() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + } + } + logInfo(log"Done from the state server thread") + } + + private def parseProtoMessage(): StateRequest = { + val messageLen = inputStream.readInt() + val messageBytes = new Array[Byte](messageLen) + inputStream.read(messageBytes) + StateRequest.parseFrom(ByteString.copyFrom(messageBytes)) + } + + private def handleRequest(message: StateRequest): Unit = { + message.getMethodCase match { + case StateRequest.MethodCase.IMPLICITGROUPINGKEYREQUEST => + handleImplicitGroupingKeyRequest(message.getImplicitGroupingKeyRequest) + case StateRequest.MethodCase.STATEFULPROCESSORCALL => + handleStatefulProcessorCall(message.getStatefulProcessorCall) + case StateRequest.MethodCase.STATEVARIABLEREQUEST => + handleStateVariableRequest(message.getStateVariableRequest) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = { + message.getMethodCase match { + case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY => + val keyBytes = message.getSetImplicitKey.getKey.toByteArray + // The key row is serialized as a byte array, we need to convert it back to a Row + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, + ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()) + ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + sendResponse(0) + case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => + ImplicitGroupingKeyTracker.removeImplicitKey() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleStatefulProcessorCall(message: StatefulProcessorCall): Unit = { + message.getMethodCase match { + case StatefulProcessorCall.MethodCase.SETHANDLESTATE => + val requestedState = message.getSetHandleState.getState + requestedState match { + case 
HandleState.CREATED => + logInfo(log"set handle state to Created") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CREATED) + case HandleState.INITIALIZED => + logInfo(log"set handle state to Initialized") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + case HandleState.CLOSED => + logInfo(log"set handle state to Closed") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + case _ => + } + sendResponse(0) + case StatefulProcessorCall.MethodCase.GETVALUESTATE => + val stateName = message.getGetValueState.getStateName + val schema = message.getGetValueState.getSchema + initializeValueState(stateName, schema) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleStateVariableRequest(message: StateVariableRequest): Unit = { + message.getMethodCase match { + case StateVariableRequest.MethodCase.VALUESTATECALL => + handleValueStateRequest(message.getValueStateCall) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = { + val stateName = message.getStateName + message.getMethodCase match { + case ValueStateCall.MethodCase.EXISTS => + if (valueStates.contains(stateName) && valueStates(stateName).exists()) { Review Comment: I guess we want to distinguish the case of "no value state is defined for the state variable name" vs "the value state is defined but not having a value yet". Applies to all method cases. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. 
It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). + */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest + private val valueStates = if (valueStateMapForTest != null) { + valueStateMapForTest + } else { + new mutable.HashMap[String, ValueState[Row]]() + } + + def run(): Unit = { + val listeningSocket = stateServerSocket.accept() + inputStream = new DataInputStream( + new BufferedInputStream(listeningSocket.getInputStream)) + outputStream = new DataOutputStream( + new BufferedOutputStream(listeningSocket.getOutputStream) + ) + + while (listeningSocket.isConnected && + statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + try { + val version = inputStream.readInt() + if (version != -1) { + assert(version == 0) + val message = parseProtoMessage() + handleRequest(message) + outputStream.flush() + } + } catch { + case _: EOFException => + logWarning(log"No more data to read from the socket") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + case e: Exception => + logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) + sendResponse(1, e.getMessage) + outputStream.flush() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + } + } + logInfo(log"Done from the state server thread") + } + + private def parseProtoMessage(): StateRequest = { + val messageLen = inputStream.readInt() + val messageBytes = new Array[Byte](messageLen) + inputStream.read(messageBytes) + StateRequest.parseFrom(ByteString.copyFrom(messageBytes)) + } + + private def handleRequest(message: StateRequest): Unit = { + message.getMethodCase match { + case StateRequest.MethodCase.IMPLICITGROUPINGKEYREQUEST => + handleImplicitGroupingKeyRequest(message.getImplicitGroupingKeyRequest) + case StateRequest.MethodCase.STATEFULPROCESSORCALL => + handleStatefulProcessorCall(message.getStatefulProcessorCall) + case StateRequest.MethodCase.STATEVARIABLEREQUEST => + handleStateVariableRequest(message.getStateVariableRequest) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = { + message.getMethodCase match { + case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY => + val keyBytes = message.getSetImplicitKey.getKey.toByteArray + // The key row is serialized as a byte array, we need to convert it back to a Row + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, + ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()) + ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + sendResponse(0) + 
case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => + ImplicitGroupingKeyTracker.removeImplicitKey() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleStatefulProcessorCall(message: StatefulProcessorCall): Unit = { + message.getMethodCase match { + case StatefulProcessorCall.MethodCase.SETHANDLESTATE => + val requestedState = message.getSetHandleState.getState + requestedState match { + case HandleState.CREATED => + logInfo(log"set handle state to Created") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CREATED) + case HandleState.INITIALIZED => + logInfo(log"set handle state to Initialized") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + case HandleState.CLOSED => + logInfo(log"set handle state to Closed") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + case _ => + } + sendResponse(0) + case StatefulProcessorCall.MethodCase.GETVALUESTATE => + val stateName = message.getGetValueState.getStateName + val schema = message.getGetValueState.getSchema + initializeValueState(stateName, schema) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleStateVariableRequest(message: StateVariableRequest): Unit = { + message.getMethodCase match { + case StateVariableRequest.MethodCase.VALUESTATECALL => + handleValueStateRequest(message.getValueStateCall) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = { + val stateName = message.getStateName + message.getMethodCase match { + case ValueStateCall.MethodCase.EXISTS => + if (valueStates.contains(stateName) && valueStates(stateName).exists()) { + sendResponse(0) + } else { + sendResponse(1, s"state $stateName doesn't exist") + } + case ValueStateCall.MethodCase.GET => + if (valueStates.contains(stateName)) { + val valueOption = valueStates(stateName).getOption() + if (valueOption.isDefined) { + sendResponse(0) + // Serialize the value row as a byte array Review Comment: Shall we ship this within the proto message? This is the only thing we communicate in this server without using a proto message - it's a single exception for now, but it will be repeated for other kinds of state. Let's not push people to memorize the protocol. ########## sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala: ########## @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream +import java.net.ServerSocket + +import scala.collection.mutable + +import org.mockito.ArgumentMatchers.{any, argThat} +import org.mockito.Mockito.{mock, times, verify, when} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Encoder, Row} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{Clear, Exists, Get, HandleState, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach { + val valueStateName = "test" + var statefulProcessorHandle: StatefulProcessorHandleImpl = _ + var outputStream: DataOutputStream = _ + var valueState: ValueState[Row] = _ + var stateServer: TransformWithStateInPandasStateServer = _ + + override def beforeEach(): Unit = { + val serverSocket = mock(classOf[ServerSocket]) + statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl]) + val groupingKeySchema = StructType(Seq()) + outputStream = mock(classOf[DataOutputStream]) + valueState = mock(classOf[ValueState[Row]]) + val valueStateMap = mutable.HashMap[String, ValueState[Row]](valueStateName -> valueState) + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, outputStream, valueStateMap) + } + + test("set handle state") { + val message = StatefulProcessorCall.newBuilder().setSetHandleState( + SetHandleState.newBuilder().setState(HandleState.CREATED).build()).build() + stateServer.handleStatefulProcessorCall(message) + verify(statefulProcessorHandle).setHandleState(StatefulProcessorHandleState.CREATED) + verify(outputStream).writeInt(0) + } + + test("get value state") { + val message = StatefulProcessorCall.newBuilder().setGetValueState( + StateCallCommand.newBuilder() + .setStateName("newName") + .setSchema("StructType(List(StructField(value,IntegerType,true)))")).build() + stateServer.handleStatefulProcessorCall(message) + verify(statefulProcessorHandle).getValueState[Row](any[String], any[Encoder[Row]]) + verify(outputStream).writeInt(0) + } + + test("value state exists") { + val message = ValueStateCall.newBuilder().setStateName(valueStateName) + .setExists(Exists.newBuilder().build()).build() + stateServer.handleValueStateRequest(message) + verify(valueState).exists() + } + + test("value state get") { + val message = ValueStateCall.newBuilder().setStateName(valueStateName) + .setGet(Get.newBuilder().build()).build() + val schema = new StructType().add("value", "int") + when(valueState.getOption()).thenReturn(Some(new GenericRowWithSchema(Array(1), schema))) + stateServer.handleValueStateRequest(message) + verify(valueState).getOption() + verify(outputStream).writeInt(0) + } + + test("value state get - not exist") { + val message = ValueStateCall.newBuilder().setStateName(valueStateName) + .setGet(Get.newBuilder().build()).build() + when(valueState.getOption()).thenReturn(None) + stateServer.handleValueStateRequest(message) + verify(valueState).getOption() + verify(outputStream).writeInt(argThat((x: Int) => x > 0)) + } + + test("value state get - not initialized") { + val nonExistMessage = 
ValueStateCall.newBuilder().setStateName("nonExist") + .setGet(Get.newBuilder().build()).build() + stateServer.handleValueStateRequest(nonExistMessage) + verify(valueState, times(0)).getOption() + verify(outputStream).writeInt(argThat((x: Int) => x > 0)) + } + + test("value state clear") { Review Comment: Looks like missing the test for set? ########## python/pyspark/sql/pandas/group_ops.py: ########## @@ -358,6 +364,172 @@ def applyInPandasWithState( ) return DataFrame(jdf, self.session) + def transformWithStateInPandas( + self, + statefulProcessor: StatefulProcessor, + outputStructType: Union[StructType, str], + outputMode: str, + timeMode: str, + ) -> DataFrame: + """ + Invokes methods defined in the stateful processor used in arbitrary state API v2. It + requires protobuf, pandas and pyarrow as dependencies to process input/state data. We + allow the user to act on per-group set of input rows along with keyed state and the user + can choose to output/return 0 or more rows. + + For a streaming dataframe, we will repeatedly invoke the interface methods for new rows + in each trigger and the user's state/state variables will be stored persistently across + invocations. + + The `statefulProcessor` should be a Python class that implements the interface defined in + :class:`StatefulProcessor`. + + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The + number of `pandas.DataFrame` in both the input and output can also be arbitrary. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` + Instance of StatefulProcessor whose functions will be invoked by the operator. + outputStructType : :class:`pyspark.sql.types.DataType` or str + The type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + The output mode of the stateful processor. + timeMode : str + The time mode semantics of the stateful processor for timers and TTL. + + Examples + -------- + >>> from typing import Iterator + ... + >>> import pandas as pd # doctest: +SKIP + ... + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import col, split + >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle + >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType + ... + >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass", + ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") + ... # Below is a simple example of a stateful processor that counts the number of violations + ... # for a set of temperature sensors. A violation is defined when the temperature is above + ... # 100. + ... # The input data is a DataFrame with the following schema: + ... # `id: string, temperature: long`. + ... # The output schema and state schema are defined as below. + >>> output_schema = StructType([ + ... StructField("id", StringType(), True), + ... StructField("count", IntegerType(), True) + ... ]) + >>> state_schema = StructType([ + ... 
StructField("value", IntegerType(), True) + ... ]) + >>> class SimpleStatefulProcessor(StatefulProcessor): + ... def init(self, handle: StatefulProcessorHandle): + ... self.num_violations_state = handle.getValueState("numViolations", state_schema) + ... + ... def handleInputRows(self, key, rows): + ... new_violations = 0 + ... count = 0 + ... exists = self.num_violations_state.exists() + ... if exists: + ... existing_violations_pdf = self.num_violations_state.get() + ... existing_violations = existing_violations_pdf.get("value")[0] + ... else: + ... existing_violations = 0 + ... for pdf in rows: + ... pdf_count = pdf.count() + ... count += pdf_count.get('temperature') + ... violations_pdf = pdf.loc[pdf['temperature'] > 100] + ... new_violations += violations_pdf.count().get('temperature') + ... updated_violations = new_violations + existing_violations + ... self.num_violations_state.update((updated_violations,)) + ... yield pd.DataFrame({'id': key, 'count': count}) + ... + ... def close(self) -> None: + ... pass + + Input DataFrame: + +---+-----------+ + | id|temperature| + +---+-----------+ + | 0| 123| + | 0| 23| + | 1| 33| + | 1| 188| + | 1| 88| + +---+-----------+ + + >>> df.groupBy("value").transformWithStateInPandas(statefulProcessor = + ... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update", + ... timeMode="None") # doctest: +SKIP + + Output DataFrame: + +---+-----+ + | id|count| + +---+-----+ + | 0| 2| Review Comment: Isn't the desired output (0, 1), (1, 1)? ########## python/pyspark/sql/pandas/group_ops.py: ########## @@ -358,6 +364,172 @@ def applyInPandasWithState( ) return DataFrame(jdf, self.session) + def transformWithStateInPandas( + self, + statefulProcessor: StatefulProcessor, + outputStructType: Union[StructType, str], + outputMode: str, + timeMode: str, + ) -> DataFrame: + """ + Invokes methods defined in the stateful processor used in arbitrary state API v2. It + requires protobuf, pandas and pyarrow as dependencies to process input/state data. We + allow the user to act on per-group set of input rows along with keyed state and the user + can choose to output/return 0 or more rows. + + For a streaming dataframe, we will repeatedly invoke the interface methods for new rows + in each trigger and the user's state/state variables will be stored persistently across + invocations. + + The `statefulProcessor` should be a Python class that implements the interface defined in + :class:`StatefulProcessor`. + + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The + number of `pandas.DataFrame` in both the input and output can also be arbitrary. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` + Instance of StatefulProcessor whose functions will be invoked by the operator. + outputStructType : :class:`pyspark.sql.types.DataType` or str + The type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + The output mode of the stateful processor. 
+ timeMode : str + The time mode semantics of the stateful processor for timers and TTL. + + Examples + -------- + >>> from typing import Iterator + ... + >>> import pandas as pd # doctest: +SKIP + ... + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import col, split + >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle + >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType + ... + >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass", + ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") + ... # Below is a simple example of a stateful processor that counts the number of violations + ... # for a set of temperature sensors. A violation is defined when the temperature is above + ... # 100. + ... # The input data is a DataFrame with the following schema: + ... # `id: string, temperature: long`. + ... # The output schema and state schema are defined as below. + >>> output_schema = StructType([ + ... StructField("id", StringType(), True), + ... StructField("count", IntegerType(), True) + ... ]) + >>> state_schema = StructType([ + ... StructField("value", IntegerType(), True) + ... ]) + >>> class SimpleStatefulProcessor(StatefulProcessor): + ... def init(self, handle: StatefulProcessorHandle): + ... self.num_violations_state = handle.getValueState("numViolations", state_schema) + ... + ... def handleInputRows(self, key, rows): + ... new_violations = 0 + ... count = 0 + ... exists = self.num_violations_state.exists() + ... if exists: + ... existing_violations_pdf = self.num_violations_state.get() + ... existing_violations = existing_violations_pdf.get("value")[0] + ... else: + ... existing_violations = 0 + ... for pdf in rows: + ... pdf_count = pdf.count() + ... count += pdf_count.get('temperature') + ... violations_pdf = pdf.loc[pdf['temperature'] > 100] + ... new_violations += violations_pdf.count().get('temperature') + ... updated_violations = new_violations + existing_violations + ... self.num_violations_state.update((updated_violations,)) + ... yield pd.DataFrame({'id': key, 'count': count}) Review Comment: I guess the explanation intends this to produce the number of violations rather than the number of inputs, but the code doesn't follow that explanation. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.DataOutputStream +import java.nio.file.{Files, Path} +import java.util.concurrent.TimeUnit + +import scala.concurrent.ExecutionContext + +import jnr.unixsocket.UnixServerSocketChannel +import jnr.unixsocket.UnixSocketAddress +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType} +import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.ThreadUtils + +/** + * Python runner implementation for TransformWithStateInPandas. + */ +class TransformWithStateInPandasPythonRunner( + funcs: Seq[(ChainedPythonFunctions, Long)], + evalType: Int, + argOffsets: Array[Array[Int]], + _schema: StructType, + processorHandle: StatefulProcessorHandleImpl, + _timeZoneId: String, + initialWorkerConf: Map[String, String], + override val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + groupingKeySchema: StructType) + extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID) + with PythonArrowInput[InType] + with BasicPythonArrowOutput + with Logging { + + private val sqlConf = SQLConf.get + private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + + private val serverId = TransformWithStateInPandasStateServer.allocateServerId() + + private val socketPath = s"./uds_$serverId.sock" + + override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + + private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root) + + // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s + // constructor. 
+ override protected lazy val schema: StructType = _schema + override protected lazy val timeZoneId: String = _timeZoneId + override protected val errorOnDuplicatedFieldNames: Boolean = true + override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes + + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the port number for state server + stream.writeInt(serverId) + } + + override def compute( + inputIterator: Iterator[InType], + partitionIndex: Int, + context: TaskContext): Iterator[OutType] = { + var serverChannel: UnixServerSocketChannel = null + var failed = false + try { + val socketFile = Path.of(socketPath) + Files.deleteIfExists(socketFile) + val serverAddress = new UnixSocketAddress(socketPath) + serverChannel = UnixServerSocketChannel.open() + serverChannel.socket().bind(serverAddress) + } catch { + case e: Exception => + failed = true + throw e + } finally { + if (failed) { + closeServerSocketChannelSilently(serverChannel) + } + } + + val executor = ThreadUtils.newDaemonSingleThreadExecutor("stateConnectionListenerThread") + val executionContext = ExecutionContext.fromExecutor(executor) + + executionContext.execute( + new TransformWithStateInPandasStateServer(serverChannel, processorHandle, + groupingKeySchema)) + + context.addTaskCompletionListener[Unit] { _ => + logWarning(s"completion listener called") + executor.awaitTermination(10, TimeUnit.SECONDS) + executor.shutdownNow() + val socketFile = Path.of(socketPath) + Files.deleteIfExists(socketFile) + } + + super.compute(inputIterator, partitionIndex, context) + } + + private def closeServerSocketChannelSilently(serverChannel: UnixServerSocketChannel): Unit = { + try { + logWarning(s"closing the state server socket") + serverChannel.close() + } catch { + case e: Exception => + logError(s"failed to close state server socket", e) + } + } + + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None) + } + + override protected def writeNextInputToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Boolean = { + + if (inputIterator.hasNext) { + val startData = dataOut.size() + val next = inputIterator.next() + val nextBatch = next._2 + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } Review Comment: Mind giving JIRA ticket(s) for the followup? Even better to put code comment mentioning JIRA ticket. ########## python/pyspark/sql/pandas/group_ops.py: ########## @@ -358,6 +364,172 @@ def applyInPandasWithState( ) return DataFrame(jdf, self.session) + def transformWithStateInPandas( + self, + statefulProcessor: StatefulProcessor, + outputStructType: Union[StructType, str], + outputMode: str, + timeMode: str, + ) -> DataFrame: + """ + Invokes methods defined in the stateful processor used in arbitrary state API v2. It + requires protobuf, pandas and pyarrow as dependencies to process input/state data. We + allow the user to act on per-group set of input rows along with keyed state and the user + can choose to output/return 0 or more rows. + + For a streaming dataframe, we will repeatedly invoke the interface methods for new rows + in each trigger and the user's state/state variables will be stored persistently across + invocations. + + The `statefulProcessor` should be a Python class that implements the interface defined in + :class:`StatefulProcessor`. 
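Going back to the runner hunk above: `compute()` binds a Unix domain socket at `./uds_<serverId>.sock` and `handleMetadataBeforeExec` sends `serverId` to the Python worker as a 4-byte int. A hypothetical sketch of what the worker-side counterpart might look like; the function name and the `infile` stream argument are illustrative, not the actual worker code:

    import socket
    import struct

    def connect_to_state_server(infile, cwd="."):
        # The runner writes the state server id with DataOutputStream.writeInt,
        # i.e. as a 4-byte big-endian integer, after the usual worker metadata.
        (server_id,) = struct.unpack(">i", infile.read(4))
        # Connect to the per-task Unix domain socket bound by compute().
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.connect(f"{cwd}/uds_{server_id}.sock")
        return sock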
+ + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The + number of `pandas.DataFrame` in both the input and output can also be arbitrary. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor` + Instance of StatefulProcessor whose functions will be invoked by the operator. + outputStructType : :class:`pyspark.sql.types.DataType` or str + The type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + The output mode of the stateful processor. + timeMode : str + The time mode semantics of the stateful processor for timers and TTL. + + Examples + -------- + >>> from typing import Iterator + ... + >>> import pandas as pd # doctest: +SKIP + ... + >>> from pyspark.sql import Row + >>> from pyspark.sql.functions import col, split + >>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle + >>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType + ... + >>> spark.conf.set("spark.sql.streaming.stateStore.providerClass", + ... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider") + ... # Below is a simple example of a stateful processor that counts the number of violations + ... # for a set of temperature sensors. A violation is defined when the temperature is above + ... # 100. + ... # The input data is a DataFrame with the following schema: + ... # `id: string, temperature: long`. + ... # The output schema and state schema are defined as below. + >>> output_schema = StructType([ + ... StructField("id", StringType(), True), + ... StructField("count", IntegerType(), True) + ... ]) + >>> state_schema = StructType([ + ... StructField("value", IntegerType(), True) + ... ]) + >>> class SimpleStatefulProcessor(StatefulProcessor): + ... def init(self, handle: StatefulProcessorHandle): + ... self.num_violations_state = handle.getValueState("numViolations", state_schema) + ... + ... def handleInputRows(self, key, rows): + ... new_violations = 0 + ... count = 0 + ... exists = self.num_violations_state.exists() + ... if exists: + ... existing_violations_pdf = self.num_violations_state.get() Review Comment: What's the expectation of the type of this state "value"? From the variable name `pdf` and also the way we get the number, I suspect this to be a `pandas DataFrame`, while the right type should be `Row`. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala: ########## @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} +import java.net.ServerSocket + +import scala.collection.mutable + +import com.google.protobuf.ByteString + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types.StructType + +/** + * This class is used to handle the state requests from the Python side. It runs on a separate + * thread spawned by TransformWithStateInPandasStateRunner per task. It opens a dedicated socket + * to process/transfer state related info which is shut down when task finishes or there's an error + * on opening the socket. It run It processes following state requests and return responses to the + * Python side: + * - Requests for managing explicit grouping key. + * - Stateful processor requests. + * - Requests for managing state variables (e.g. valueState). 
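On the wire, each of these requests follows the framing that `run()` and `parseProtoMessage()` below expect: a 4-byte big-endian protocol version (currently 0), a 4-byte length prefix, then the serialized `StateRequest` protobuf. A minimal client-side framing helper as a sketch, assuming `request_bytes` is an already-serialized `StateRequest`; the function name is illustrative:

    import struct

    def send_state_request(sock_file, request_bytes: bytes) -> None:
        # Framing expected by the state server: version, length, then payload.
        sock_file.write(struct.pack(">i", 0))                   # protocol version
        sock_file.write(struct.pack(">i", len(request_bytes)))  # message length
        sock_file.write(request_bytes)                          # StateRequest bytes
        sock_file.flush()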
+ */ +class TransformWithStateInPandasStateServer( + private val stateServerSocket: ServerSocket, + private val statefulProcessorHandle: StatefulProcessorHandleImpl, + private val groupingKeySchema: StructType, + private val outputStreamForTest: DataOutputStream = null, + private val valueStateMapForTest: mutable.HashMap[String, ValueState[Row]] = null) + extends Runnable with Logging { + private var inputStream: DataInputStream = _ + private var outputStream: DataOutputStream = outputStreamForTest + private val valueStates = if (valueStateMapForTest != null) { + valueStateMapForTest + } else { + new mutable.HashMap[String, ValueState[Row]]() + } + + def run(): Unit = { + val listeningSocket = stateServerSocket.accept() + inputStream = new DataInputStream( + new BufferedInputStream(listeningSocket.getInputStream)) + outputStream = new DataOutputStream( + new BufferedOutputStream(listeningSocket.getOutputStream) + ) + + while (listeningSocket.isConnected && + statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) { + try { + val version = inputStream.readInt() + if (version != -1) { + assert(version == 0) + val message = parseProtoMessage() + handleRequest(message) + outputStream.flush() + } + } catch { + case _: EOFException => + logWarning(log"No more data to read from the socket") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + case e: Exception => + logError(log"Error reading message: ${MDC(LogKeys.ERROR, e.getMessage)}", e) + sendResponse(1, e.getMessage) + outputStream.flush() + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + return + } + } + logInfo(log"Done from the state server thread") + } + + private def parseProtoMessage(): StateRequest = { + val messageLen = inputStream.readInt() + val messageBytes = new Array[Byte](messageLen) + inputStream.read(messageBytes) + StateRequest.parseFrom(ByteString.copyFrom(messageBytes)) + } + + private def handleRequest(message: StateRequest): Unit = { + message.getMethodCase match { + case StateRequest.MethodCase.IMPLICITGROUPINGKEYREQUEST => + handleImplicitGroupingKeyRequest(message.getImplicitGroupingKeyRequest) + case StateRequest.MethodCase.STATEFULPROCESSORCALL => + handleStatefulProcessorCall(message.getStatefulProcessorCall) + case StateRequest.MethodCase.STATEVARIABLEREQUEST => + handleStateVariableRequest(message.getStateVariableRequest) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleImplicitGroupingKeyRequest(message: ImplicitGroupingKeyRequest): Unit = { + message.getMethodCase match { + case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY => + val keyBytes = message.getSetImplicitKey.getKey.toByteArray + // The key row is serialized as a byte array, we need to convert it back to a Row + val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, + ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()) + ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + sendResponse(0) + case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => + ImplicitGroupingKeyTracker.removeImplicitKey() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleStatefulProcessorCall(message: StatefulProcessorCall): Unit = { + message.getMethodCase match { + case StatefulProcessorCall.MethodCase.SETHANDLESTATE => + val requestedState = message.getSetHandleState.getState + requestedState match { + case 
HandleState.CREATED => + logInfo(log"set handle state to Created") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CREATED) + case HandleState.INITIALIZED => + logInfo(log"set handle state to Initialized") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED) + case HandleState.CLOSED => + logInfo(log"set handle state to Closed") + statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + case _ => + } + sendResponse(0) + case StatefulProcessorCall.MethodCase.GETVALUESTATE => + val stateName = message.getGetValueState.getStateName + val schema = message.getGetValueState.getSchema + initializeValueState(stateName, schema) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private def handleStateVariableRequest(message: StateVariableRequest): Unit = { + message.getMethodCase match { + case StateVariableRequest.MethodCase.VALUESTATECALL => + handleValueStateRequest(message.getValueStateCall) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleValueStateRequest(message: ValueStateCall): Unit = { + val stateName = message.getStateName + message.getMethodCase match { + case ValueStateCall.MethodCase.EXISTS => + if (valueStates.contains(stateName) && valueStates(stateName).exists()) { + sendResponse(0) + } else { + sendResponse(1, s"state $stateName doesn't exist") + } + case ValueStateCall.MethodCase.GET => + if (valueStates.contains(stateName)) { + val valueOption = valueStates(stateName).getOption() + if (valueOption.isDefined) { + sendResponse(0) + // Serialize the value row as a byte array + val valueBytes = PythonSQLUtils.toPyRow(valueOption.get) + outputStream.writeInt(valueBytes.length) + outputStream.write(valueBytes) + } else { + logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't exist") + sendResponse(1, s"state $stateName doesn't exist") + } + } else { + logWarning(log"state ${MDC(LogKeys.STATE_NAME, stateName)} doesn't exist") + sendResponse(1, s"state $stateName doesn't exist") + } + case ValueStateCall.MethodCase.VALUESTATEUPDATE => + val byteArray = message.getValueStateUpdate.getValue.toByteArray + val schema = StructType.fromString(message.getValueStateUpdate.getSchema) + // The value row is serialized as a byte array, we need to convert it back to a Row + val valueRow = PythonSQLUtils.toJVMRow(byteArray, schema, + ExpressionEncoder(schema).resolveAndBind().createDeserializer()) Review Comment: You can maintain the association between state variable and deserializer; let's just initialize once and reuse. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
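On the last note above (maintaining the association between state variable and deserializer): the fix belongs in the Scala server, e.g. storing the deserializer next to the `ValueState` when `initializeValueState` runs and reusing it in `handleValueStateRequest`, instead of calling `ExpressionEncoder(schema).resolveAndBind().createDeserializer()` on every update. The same initialize-once-and-reuse shape, sketched in Python purely for illustration with hypothetical names:

    class StateVariableRegistry:
        """Caches one deserializer per state variable instead of rebuilding it per call."""

        def __init__(self, make_deserializer):
            # Stand-in for building ExpressionEncoder(schema).resolveAndBind().createDeserializer().
            self._make_deserializer = make_deserializer
            self._deserializers = {}  # state name -> cached deserializer

        def register(self, state_name, schema):
            # Called once, when the state variable is created.
            self._deserializers[state_name] = self._make_deserializer(schema)

        def deserialize(self, state_name, payload):
            # Called on every get/update; reuses the cached deserializer.
            return self._deserializers[state_name](payload)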
