anishshri-db commented on code in PR #47104:
URL: https://github.com/apache/spark/pull/47104#discussion_r1662982978


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -313,3 +295,174 @@ class StatefulProcessorHandleImpl(
     }
   }
 }
+
+/**
+ * This DriverStatefulProcessorHandleImpl is used within TransformWithStateExec
+ * on the driver side to collect the columnFamilySchemas before any processing is
+ * actually done. We need this class because the schemas can only be collected
+ * after the StatefulProcessor is initialized.
+ */
+class DriverStatefulProcessorHandleImpl(timeMode: TimeMode)
+  extends StatefulProcessorHandleImplBase(timeMode) {
+
+  // Builds the ColumnFamilySchema for each kind of state variable, using the schema
+  // version given by StateSchemaV3File.COLUMN_FAMILY_SCHEMA_VERSION.
+  private[sql] val columnFamilySchemaFactory =
+    ColumnFamilySchemaFactory.getFactory(StateSchemaV3File.COLUMN_FAMILY_SCHEMA_VERSION)
+
+  // Schemas registered by the get*State methods, appended in call order
+  // (no de-duplication by state name is done here).
+  private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] =
+    new util.ArrayList[ColumnFamilySchema]()
+
+  /**
+   * Checks that the handle is still in the PRE_INIT state and throws otherwise.
+   *
+   * @param operationType - name of the attempted operation, reported in the error
+   */
+  private def verifyStateVarOperations(operationType: String): Unit = currState match {
+    case PRE_INIT => // schema registration is only permitted in PRE_INIT
+    case invalidState =>
+      throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(
+        operationType, invalidState.toString)
+  }
+
+  /**
+   * Records the ValueState schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * No usable state variable is created on the driver: this method only collects the
+   * column family schema and always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
+    verifyStateVarOperations("get_value_state")
+    val colFamilySchema = columnFamilySchemaFactory.getValueStateSchema(stateName, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Records the ValueState-with-TTL schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * `ttlConfig` is not used here: only the schema is recorded on the driver. No usable
+   * state variable is created, so this method always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.); ignored here
+   * @tparam T - type of state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T] = {
+    verifyStateVarOperations("get_value_state")
+    val colFamilySchema = columnFamilySchemaFactory.getValueStateTtlSchema(stateName, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Records the ListState schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * No usable state variable is created on the driver: this method only collects the
+   * column family schema and always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
+    verifyStateVarOperations("get_list_state")
+    val colFamilySchema = columnFamilySchemaFactory.getListStateSchema(stateName, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Records the ListState-with-TTL schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * `ttlConfig` is not used here: only the schema is recorded on the driver. No usable
+   * state variable is created, so this method always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.); ignored here
+   * @tparam T - type of state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getListState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ListState[T] = {
+    verifyStateVarOperations("get_list_state")
+    val colFamilySchema = columnFamilySchemaFactory.getListStateTtlSchema(stateName, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Records the MapState schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * No usable state variable is created on the driver: this method only collects the
+   * column family schema and always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param userKeyEnc - spark sql encoder for the map key
+   * @param valEncoder - spark sql encoder for the map value
+   * @tparam K - type of key for map state variable
+   * @tparam V - type of value for map state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getMapState[K, V](
+      stateName: String,
+      userKeyEnc: Encoder[K],
+      valEncoder: Encoder[V]): MapState[K, V] = {
+    verifyStateVarOperations("get_map_state")
+    val colFamilySchema = columnFamilySchemaFactory.
+      getMapStateSchema(stateName, userKeyEnc, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Records the MapState-with-TTL schema for `stateName` in `columnFamilySchemas`.
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor; calls made outside the PRE_INIT handle state throw an error.
+   *
+   * `ttlConfig` is not used here: only the schema is recorded on the driver. No usable
+   * state variable is created, so this method always returns `null`.
+   *
+   * @param stateName  - name of the state variable
+   * @param userKeyEnc - spark sql encoder for the map key
+   * @param valEncoder - spark sql encoder for the map value
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.); ignored here
+   * @tparam K - type of key for map state variable
+   * @tparam V - type of value for map state variable
+   * @return - always `null`; the returned value must not be used on the driver
+   */
+  override def getMapState[K, V](
+      stateName: String,
+      userKeyEnc: Encoder[K],
+      valEncoder: Encoder[V],
+      ttlConfig: TTLConfig): MapState[K, V] = {
+    verifyStateVarOperations("get_map_state")
+    val colFamilySchema = columnFamilySchemaFactory.
+      getMapStateTtlSchema(stateName, userKeyEnc, valEncoder)
+    columnFamilySchemas.add(colFamilySchema)
+    // Only the schema is needed on the driver; no state variable instance is created.
+    null
+  }
+
+  /**
+   * Returns a placeholder QueryInfo. On the driver, schemas are collected before any
+   * processing is done, so there is no running task to describe. Each call builds a new
+   * QueryInfoImpl from two freshly generated random UUIDs and `0L` (presumably the
+   * batch id -- confirm against QueryInfoImpl). The UUIDs therefore differ between calls.
+   */
+  override def getQueryInfo(): QueryInfo = {
+    new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L)
+  }
+
+  /**
+   * Methods that are only included to satisfy the interface.
+   * These methods are no-ops on the driver side

Review Comment:
   We expect these to fail, right? If so, maybe we can be explicit about that, i.e. say
that the driver handle state does not allow these ops to be executed?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to