jingz-db commented on code in PR #47574:
URL: https://github.com/apache/spark/pull/47574#discussion_r1710383400


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -52,32 +54,163 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   override def shortName(): String = "statestore"
 
+  private var stateStoreMetadata: Option[Array[StateMetadataTableEntry]] = None
+
+  private var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
+
+  private var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None
+
+  private def runStateVarChecks(
+      sourceOptions: StateSourceOptions,
+      stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
+    val twsShortName = "transformWithStateExec"
+    if (sourceOptions.stateVarName.isDefined) {
+      // Perform checks for transformWithState operator in case state variable name is provided
+      require(stateStoreMetadata.size == 1)
+      val opMetadata = stateStoreMetadata.head
+      if (opMetadata.operatorName != twsShortName) {
+        // if we are trying to query state source with state variable name, then the operator
+        // should be transformWithState
+        val errorMsg = "Providing state variable names is only supported with the " +
+          s"transformWithState operator. Found operator=${opMetadata.operatorName}. " +
+          s"Please remove this option and re-run the query."
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg)
+      }
+
+      // if the operator is transformWithState, but the operator properties are empty, then
+      // the user has not defined any state variables for the operator
+      val operatorProperties = opMetadata.operatorPropertiesJson
+      if (operatorProperties.isEmpty) {
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
+          "No state variable names are defined for the transformWithState operator")
+      }
+
+      // if the state variable is not one of the defined/available state variables, then we
+      // fail the query
+      val stateVarName = sourceOptions.stateVarName.get
+      val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties)
+      val stateVars = twsOperatorProperties.stateVariables
+      if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) {
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
+          s"State variable $stateVarName is not defined for the transformWithState operator.")
+      }
+    } else {
+      // if the operator is transformWithState, then a state variable argument is mandatory
+      if (stateStoreMetadata.size == 1 &&
+        stateStoreMetadata.head.operatorName == twsShortName) {
+        throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName")
+      }
+    }
+  }
+
+  private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions):
+    Array[StateMetadataTableEntry] = {
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      stateSourceOptions.stateCheckpointLocation.getParent.toString,
+      serializedHadoopConf, stateSourceOptions.batchId).stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == stateSourceOptions.operatorId &&
+        entry.stateStoreName == stateSourceOptions.storeName
+    }
+    stateStoreMetadata
+  }
+
+  private def getStoreMetadataAndRunChecks(stateSourceOptions: StateSourceOptions): Unit = {
+    if (stateStoreMetadata.isEmpty) {
+      stateStoreMetadata = Some(getStateStoreMetadata(stateSourceOptions))
+      runStateVarChecks(stateSourceOptions, stateStoreMetadata.get)
+    }
+  }
+
   override def getTable(
       schema: StructType,
       partitioning: Array[Transform],
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    // Read the operator metadata once to see if we can find the information for prefix scan
-    // encoder used in session window aggregation queries.
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf,
-      sourceOptions.batchId)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == sourceOptions.operatorId &&
-        entry.stateStoreName == sourceOptions.storeName
+    getStoreMetadataAndRunChecks(sourceOptions)
+
+    val keyStateEncoderSpec = if (keyStateEncoderSpecOpt.isDefined) {
+      keyStateEncoderSpecOpt.get
+    } else {
+      val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+      NoPrefixKeyStateEncoderSpec(keySchema)
     }
 
-    new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
+    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+      transformWithStateVariableInfoOpt)
+  }
+
+  private def getKeyStateEncoderSpec(colFamilySchema: StateStoreColFamilySchema):
+    KeyStateEncoderSpec = {
+    val storeMetadata = stateStoreMetadata.get
+    // If operator metadata is not found, then log a warning and continue with using the no-prefix
+    // key state encoder
+    val keyStateEncoderSpec = if (storeMetadata.length == 0) {
+      logWarning("Metadata for state store not found, possible cause is this checkpoint " +
+        "is created by older version of spark. If the query has session window aggregation, " +
+        "the state can't be read correctly and runtime exception will be thrown. " +
+        "Run the streaming query in newer spark version to generate state metadata " +
+        "can fix the issue.")
+      NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)
+    } else {
+      require(storeMetadata.length == 1)
+      val storeMetadataEntry = storeMetadata.head
+      // if version has metadata info, then use numColsPrefixKey as specified
+      if (storeMetadataEntry.version == 1 && storeMetadataEntry.numColsPrefixKey == 0) {
+        NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)
+      } else if (storeMetadataEntry.version == 1 && storeMetadataEntry.numColsPrefixKey > 0) {
+        PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema,
+          storeMetadataEntry.numColsPrefixKey)
+      } else if (storeMetadataEntry.version == 2) {
+        // for version 2, we have the encoder spec recorded to the state schema file. so we just
+        // use that directly
+        require(colFamilySchema.keyStateEncoderSpec.isDefined)
+        colFamilySchema.keyStateEncoderSpec.get
+      } else {
+        throw StateDataSourceErrors.internalError(s"Failed to read " +
+          s"key state encoder spec for operator=${storeMetadataEntry.operatorId}")
+      }
+    }
+    keyStateEncoderSpec
   }
 
+  private def generateSchemaForStateVar(
+      stateVarInfo: TransformWithStateVariableInfo,
+      stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
+    val stateVarType = stateVarInfo.stateVariableType
+    val hasTTLEnabled = stateVarInfo.ttlEnabled
+
+    stateVarType match {
+      case StateVariableType.ValueState =>
+        if (hasTTLEnabled) {
+          new StructType()
+            .add("key", stateStoreColFamilySchema.keySchema)
+            .add("value", stateStoreColFamilySchema.valueSchema)
+            .add("expiration_timestamp", LongType)
+            .add("partition_id", IntegerType)
+        } else {
+          new StructType()
+            .add("key", stateStoreColFamilySchema.keySchema)
+            .add("value", stateStoreColFamilySchema.valueSchema)
+            .add("partition_id", IntegerType)
+        }
+
+      case _ =>
+        throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
+    }
+  }
+
+
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options)
 
+    getStoreMetadataAndRunChecks(sourceOptions)

Review Comment:
   Do we call `getStoreMetadataAndRunChecks` in both `inferSchema` and 
`getTable`, so that `StateMetadataPartitionReader` is run twice? Is it 
possible to run this only once and store the result as a class variable?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to