anishshri-db commented on code in PR #47104:
URL: https://github.com/apache/spark/pull/47104#discussion_r1659407298


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateSchemaV3File.scala:
##########
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{InputStream, OutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+
+import scala.io.{Source => IOSource}
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchema, 
ColumnFamilySchemaV1}
+import org.apache.spark.sql.internal.SQLConf
+
+class StateSchemaV3File(
+    hadoopConf: Configuration,
+    path: String,
+    metadataCacheEnabled: Boolean = false)
+  extends HDFSMetadataLog[List[ColumnFamilySchema]](hadoopConf, path, 
metadataCacheEnabled) {
+
+  val VERSION = 3
+  private val COLUMN_FAMILY_SCHEMA_VERSION = 1
+
+  def this(sparkSession: SparkSession, path: String) = {
+    this(
+      sparkSession.sessionState.newHadoopConf(),
+      path,
+      metadataCacheEnabled = sparkSession.sessionState.conf.getConf(
+        SQLConf.STREAMING_METADATA_CACHE_ENABLED)
+    )
+  }
+
+  override def deserialize(in: InputStream): List[ColumnFamilySchema] = {
+    val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
+
+    if (!lines.hasNext) {
+      throw new IllegalStateException("Incomplete log file in the offset 
commit log")
+    }
+
+    val version = lines.next().trim
+    validateVersion(version, VERSION)
+
+    val columnFamilySchemaVersion = lines.next().trim
+
+    columnFamilySchemaVersion match {
+      case "v1" => lines.map(ColumnFamilySchemaV1.fromJson).toList
+      case _ =>
+        throw new IllegalStateException(
+          s"Unsupported column family schema version: 
$columnFamilySchemaVersion")
+    }
+  }
+
+  override def serialize(schemas: List[ColumnFamilySchema], out: 
OutputStream): Unit = {
+    out.write(s"v${VERSION}".getBytes(UTF_8))
+    out.write('\n')
+    out.write(s"v${COLUMN_FAMILY_SCHEMA_VERSION}".getBytes(UTF_8))
+    out.write('\n')
+    out.write(schemas.map(_.json).mkString("\n").getBytes(UTF_8))
+  }
+
+  override def add(batchId: Long, metadata: List[ColumnFamilySchema]): Boolean 
= {
+    require(metadata != null, "'null' metadata cannot written to a metadata 
log")
+    val batchMetadataFile = batchIdToPath(batchId)
+    if (fileManager.exists(batchMetadataFile)) {
+      fileManager.delete(batchMetadataFile)
+    }
+    val res = addNewBatchByStream(batchId) { output => serialize(metadata, 
output) }
+    if (metadataCacheEnabled && res) batchCache.put(batchId, metadata)
+    res
+  }
+
+  override def addNewBatchByStream(batchId: Long)(fn: OutputStream => Unit): 
Boolean = {

Review Comment:
   Could we add some function-level comments here too?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -313,3 +313,146 @@ class StatefulProcessorHandleImpl(
     }
   }
 }
+
+/**
+ * This DriverStatefulProcessorHandleImpl is used within TransformWithExec
+ * on the driver side to collect the columnFamilySchemas before any processing 
is
+ * actually done. We need this class because we can only collect the schemas 
after
+ * the StatefulProcessor is initialized.
+ */
+class DriverStatefulProcessorHandleImpl extends StatefulProcessorHandle {
+
+  private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] =
+    new util.ArrayList[ColumnFamilySchema]()
+
+  /**
+   * Function to add the ValueState schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - instance of ValueState of type T that can be used to store 
state persistently
+   */
+  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T] = {
+    val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the ValueStateWithTTL schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ValueState of type T that can be used to store 
state persistently
+   */
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T] = {
+    val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the ListState schema to the list of column family schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state 
persistently
+   */
+  override def getListState[T](stateName: String, valEncoder: Encoder[T]): 
ListState[T] = {
+    val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)

Review Comment:
   Do we not need to pass the encoder here?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -313,3 +313,146 @@ class StatefulProcessorHandleImpl(
     }
   }
 }
+
+/**
+ * This DriverStatefulProcessorHandleImpl is used within TransformWithExec
+ * on the driver side to collect the columnFamilySchemas before any processing 
is
+ * actually done. We need this class because we can only collect the schemas 
after
+ * the StatefulProcessor is initialized.
+ */
+class DriverStatefulProcessorHandleImpl extends StatefulProcessorHandle {
+
+  private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] =
+    new util.ArrayList[ColumnFamilySchema]()
+
+  /**
+   * Function to add the ValueState schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - instance of ValueState of type T that can be used to store 
state persistently
+   */
+  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T] = {
+    val colFamilySchema = ValueStateImpl.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the ValueStateWithTTL schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ValueState of type T that can be used to store 
state persistently
+   */
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T] = {
+    val colFamilySchema = ValueStateImplWithTTL.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the ListState schema to the list of column family schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state 
persistently
+   */
+  override def getListState[T](stateName: String, valEncoder: Encoder[T]): 
ListState[T] = {
+    val colFamilySchema = ListStateImpl.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the ListStateWithTTL schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state 
persistently
+   */
+  override def getListState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ListState[T] = {
+    val colFamilySchema = ListStateImplWithTTL.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the MapState schema to the list of column family schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   * @param stateName  - name of the state variable
+   * @param userKeyEnc - spark sql encoder for the map key
+   * @param valEncoder - spark sql encoder for the map value
+   * @tparam K - type of key for map state variable
+   * @tparam V - type of value for map state variable
+   * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+   */
+  override def getMapState[K, V](
+      stateName: String,
+      userKeyEnc: Encoder[K],
+      valEncoder: Encoder[V]): MapState[K, V] = {
+    val colFamilySchema = MapStateImpl.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /**
+   * Function to add the MapStateWithTTL schema to the list of column family 
schemas.
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor.
+   * @param stateName  - name of the state variable
+   * @param userKeyEnc - spark sql encoder for the map key
+   * @param valEncoder - SQL encoder for state variable
+   * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+   * @tparam K - type of key for map state variable
+   * @tparam V - type of value for map state variable
+   * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+   */
+  override def getMapState[K, V](
+      stateName: String,
+      userKeyEnc: Encoder[K],
+      valEncoder: Encoder[V],
+      ttlConfig: TTLConfig): MapState[K, V] = {
+    val colFamilySchema = MapStateImplWithTTL.columnFamilySchema(stateName)
+    columnFamilySchemas.add(colFamilySchema)
+    null
+  }
+
+  /** Function to return queryInfo for currently running task */
+  override def getQueryInfo(): QueryInfo = {
+    new QueryInfoImpl(UUID.randomUUID(), UUID.randomUUID(), 0L)
+  }
+
+  /**
+   * Methods that are only included to satisfy the interface.
+   * These methods are no-ops on the driver side
+   */
+  override def registerTimer(expiryTimestampMs: Long): Unit = {}

Review Comment:
   Maybe just throw errors in all these places?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -92,6 +95,24 @@ case class TransformWithStateExec(
     }
   }
 
+  private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = {
+    val driverProcessorHandle = new DriverStatefulProcessorHandleImpl
+    statefulProcessor.setHandle(driverProcessorHandle)
+    statefulProcessor.init(outputMode, timeMode)
+    driverProcessorHandle
+  }
+  def getNewSchema(): List[ColumnFamilySchema] = {

Review Comment:
   nit: add a blank line before this method?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -92,6 +95,24 @@ case class TransformWithStateExec(
     }
   }
 
+  private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = {
+    val driverProcessorHandle = new DriverStatefulProcessorHandleImpl
+    statefulProcessor.setHandle(driverProcessorHandle)
+    statefulProcessor.init(outputMode, timeMode)

Review Comment:
   Are we also adding a pre-init phase to track this?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -92,6 +95,24 @@ case class TransformWithStateExec(
     }
   }
 
+  private def getDriverProcessorHandle: DriverStatefulProcessorHandleImpl = {

Review Comment:
   Let's add a comment to explain where this is run and what we are trying to 
collect.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala:
##########
@@ -28,6 +33,49 @@ import org.apache.spark.util.Utils
 /**
  * Helper classes for reading/writing state schema.
  */
+sealed trait ColumnFamilySchema extends Serializable {
+  def jsonValue: JsonAST.JObject
+
+  def json: String
+}
+
+case class ColumnFamilySchemaV1(
+    val columnFamilyName: String,

Review Comment:
   I guess we don't need `val` in these places?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -278,25 +282,63 @@ case class StateStoreCustomTimingMetric(name: String, 
desc: String) extends Stat
     SQLMetrics.createTimingMetric(sparkContext, desc)
 }
 
-sealed trait KeyStateEncoderSpec
+sealed trait KeyStateEncoderSpec {
+  def jsonValue: JsonAST.JObject
+  def json: String = compact(render(jsonValue))
+}
+
+object KeyStateEncoderSpec {
+  def fromJson(m: Map[String, Any]): KeyStateEncoderSpec = {
+    // match on type
+    val keySchema = StructType.fromString(m("keySchema").asInstanceOf[String])
+    m("keyStateEncoderType").asInstanceOf[String] match {
+      case "NoPrefixKeyStateEncoderSpec" =>
+        NoPrefixKeyStateEncoderSpec(keySchema)
+      case "RangeKeyScanStateEncoderSpec" =>
+        val orderingOrdinals = m("orderingOrdinals").
+          asInstanceOf[List[_]].map(_.asInstanceOf[Int])
+        RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals)
+      case "PrefixKeyScanStateEncoderSpec" =>
+        val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int]
+        PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
+    }
+  }
+}
 
-case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends 
KeyStateEncoderSpec
+case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends 
KeyStateEncoderSpec {
+  override def jsonValue: JsonAST.JObject = {
+    ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) ~
+      ("keySchema" -> JString(keySchema.json))
+  }
+}
 
 case class PrefixKeyScanStateEncoderSpec(
     keySchema: StructType,
     numColsPrefixKey: Int) extends KeyStateEncoderSpec {
   if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) {
     throw 
StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString)
   }
+
+  override def jsonValue: JsonAST.JObject = {

Review Comment:
   Can we not return `JValue`? Maybe we can follow the format here:
   https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala#L137
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to