anishshri-db commented on code in PR #47574:
URL: https://github.com/apache/spark/pull/47574#discussion_r1735576088
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -52,30 +54,136 @@ class StateDataSource extends TableProvider with
DataSourceRegister {
override def shortName(): String = "statestore"
+ private var stateStoreMetadata: Option[Array[StateMetadataTableEntry]] = None
Review Comment:
Made this stateless too
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -52,30 +54,136 @@ class StateDataSource extends TableProvider with
DataSourceRegister {
override def shortName(): String = "statestore"
+ private var stateStoreMetadata: Option[Array[StateMetadataTableEntry]] = None
+
+ private var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
+
+ private var transformWithStateVariableInfoOpt:
Option[TransformWithStateVariableInfo] = None
+
+ private var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]
= None
+
/**
 * Validates the `stateVarName` source option against the operator metadata.
 *
 * If a state variable name is provided, the (single) matching operator must be
 * `transformWithStateExec`, it must have operator properties defined, and the
 * named variable must be one of its declared state variables. Conversely, if
 * the operator is transformWithState, a state variable name is mandatory.
 *
 * @param sourceOptions options supplied to the state data source
 * @param stateStoreMetadata metadata entries already filtered to the target
 *                           operator id / store name
 * @throws org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
 *         via invalidOptionValue / requiredOptionUnspecified on any violation
 */
private def runStateVarChecks(
    sourceOptions: StateSourceOptions,
    stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
  val twsShortName = "transformWithStateExec"
  if (sourceOptions.stateVarName.isDefined) {
    // Perform checks for transformWithState operator in case state variable name is provided.
    // Give the precondition a message so a failure is diagnosable (bare require reports nothing).
    require(stateStoreMetadata.length == 1,
      s"Expected exactly one metadata entry, but found ${stateStoreMetadata.length}")
    val opMetadata = stateStoreMetadata.head
    if (opMetadata.operatorName != twsShortName) {
      // if we are trying to query state source with state variable name, then the operator
      // should be transformWithState
      val errorMsg = "Providing state variable names is only supported with the " +
        s"transformWithState operator. Found operator=${opMetadata.operatorName}. " +
        "Please remove this option and re-run the query."
      throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg)
    }

    // if the operator is transformWithState, but the operator properties are empty, then
    // the user has not defined any state variables for the operator
    val operatorProperties = opMetadata.operatorPropertiesJson
    if (operatorProperties.isEmpty) {
      throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
        "No state variable names are defined for the transformWithState operator")
    }

    // if the state variable is not one of the defined/available state variables, then we
    // fail the query. count(...) is the idiomatic (and non-allocating) form of
    // filter(...).size for a membership-cardinality check.
    val stateVarName = sourceOptions.stateVarName.get
    val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties)
    if (twsOperatorProperties.stateVariables.count(_.stateName == stateVarName) != 1) {
      throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
        s"State variable $stateVarName is not defined for the transformWithState operator.")
    }
  } else {
    // if the operator is transformWithState, then a state variable argument is mandatory
    if (stateStoreMetadata.length == 1 &&
        stateStoreMetadata.head.operatorName == twsShortName) {
      throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName")
    }
  }
}
+
/**
 * Reads all operator metadata for the checkpoint referenced by the given
 * options and returns only the entries whose operator id and store name match
 * the options. May be empty if no matching operator/store exists.
 */
private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions):
  Array[StateMetadataTableEntry] = {
  val metadataReader = new StateMetadataPartitionReader(
    stateSourceOptions.stateCheckpointLocation.getParent.toString,
    serializedHadoopConf, stateSourceOptions.batchId)
  metadataReader.stateMetadata.toArray.filter { entry =>
    entry.operatorId == stateSourceOptions.operatorId &&
      entry.stateStoreName == stateSourceOptions.storeName
  }
}
+
/**
 * Populates the cached `stateStoreMetadata` for the given options (only on the
 * first call — subsequent calls are no-ops) and runs the state-variable
 * validation checks against the freshly loaded metadata.
 */
private def getStoreMetadataAndRunChecks(stateSourceOptions: StateSourceOptions): Unit = {
  if (stateStoreMetadata.isEmpty) {
    val loadedMetadata = getStateStoreMetadata(stateSourceOptions)
    stateStoreMetadata = Some(loadedMetadata)
    runStateVarChecks(stateSourceOptions, loadedMetadata)
  }
}
+
override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val sourceOptions = StateSourceOptions.apply(session, hadoopConf,
properties)
val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation,
sourceOptions.batchId)
- // Read the operator metadata once to see if we can find the information
for prefix scan
- // encoder used in session window aggregation queries.
- val allStateStoreMetadata = new StateMetadataPartitionReader(
- sourceOptions.stateCheckpointLocation.getParent.toString,
serializedHadoopConf,
- sourceOptions.batchId)
- .stateMetadata.toArray
- val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
- entry.operatorId == sourceOptions.operatorId &&
- entry.stateStoreName == sourceOptions.storeName
+ getStoreMetadataAndRunChecks(sourceOptions)
+
+ // The key state encoder spec should be available for all operators except
stream-stream joins
+ val keyStateEncoderSpec = if (keyStateEncoderSpecOpt.isDefined) {
Review Comment:
Yea removed this dependence
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]