anishshri-db commented on code in PR #43425:
URL: https://github.com/apache/spark/pull/43425#discussion_r1380729958


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util
+import java.util.UUID
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.{RuntimeConfig, SparkSession}
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.execution.datasources.v2.state.StateDataSource.JoinSideValues.JoinSideValues
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
+import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
+ */
+class StateDataSource extends TableProvider with DataSourceRegister {
+  import StateDataSource._
+
+  private lazy val session: SparkSession = SparkSession.active
+
+  private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()
+
+  override def shortName(): String = "statestore"
+
+  override def getTable(
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
+    val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+    new StateTable(session, schema, sourceOptions, stateConf)
+  }
+
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
+    val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options)
+    if (sourceOptions.joinSide != JoinSideValues.none &&
+        sourceOptions.storeName != StateStoreId.DEFAULT_STORE_NAME) {
+      throw new IllegalArgumentException(s"The options '$PARAM_JOIN_SIDE' and " +
+        s"'$PARAM_STORE_NAME' cannot be specified together. Please specify either one.")
+    }
+
+    val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
+
+    try {
+      val (keySchema, valueSchema) = sourceOptions.joinSide match {
+        case JoinSideValues.left =>
+          StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
+            sourceOptions.operatorId, LeftSide)
+
+        case JoinSideValues.right =>
+          StreamStreamJoinStateHelper.readKeyValueSchema(session, stateCheckpointLocation.toString,
+            sourceOptions.operatorId, RightSide)
+
+        case JoinSideValues.none =>
+          val storeId = new StateStoreId(stateCheckpointLocation.toString, sourceOptions.operatorId,
+            partitionId, sourceOptions.storeName)
+          val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
+          val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf)
+          manager.readSchemaFile()
+      }
+
+      new StructType()
+        .add("key", keySchema)
+        .add("value", valueSchema)
+    } catch {
+      case NonFatal(e) =>
+        throw new IllegalArgumentException("Fail to read the state schema. Either the file " +

Review Comment:
   Nit: `Failed to read the`
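
For readers skimming the thread, here is a minimal usage sketch of the data source under review, pieced together from the test code further down in this message. The checkpoint path is illustrative, the option constants are referenced the same way the tests do, and the comment about the default batch is an inference from those tests rather than from the implementation itself.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSource

// Sketch only: assumes an active SparkSession and a checkpoint written by the
// streaming aggregation query used in the tests; the path below is illustrative.
val spark = SparkSession.active

val stateDf = spark.read
  .format("statestore")
  // The verification test below sets both options explicitly; the negative
  // tests suggest batchId defaults to the last committed batch when omitted.
  .option(StateDataSource.PARAM_BATCH_ID, 2)
  .option(StateDataSource.PARAM_OPERATOR_ID, 0)
  .load("/tmp/checkpoint-of-streaming-aggregation")

// inferSchema() exposes two struct columns, "key" and "value"; the
// verification test flattens them with selectExpr, e.g.:
stateDf.selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt").show()
```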



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala:
##########
@@ -0,0 +1,779 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.{File, FileWriter}
+
+import org.scalatest.Assertions
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
+  import testImplicits._
+
+  test("ERROR: read the state from stateless query") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[Int]
+      val df = inputData.toDF()
+        .selectExpr("value", "value % 2 AS value2")
+
+      testStream(df)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 1, 2, 3, 4, 5),
+        CheckLastBatch((1, 1), (2, 0), (3, 1), (4, 0), (5, 1)),
+        AddData(inputData, 6, 7, 8),
+        CheckLastBatch((6, 0), (7, 1), (8, 0))
+      )
+
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: no committed batch on default batch ID") {
+    withTempDir { tempDir =>
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      val offsetLog = new OffsetSeqLog(spark,
+        new File(tempDir.getAbsolutePath, "offsets").getAbsolutePath)
+      val commitLog = new CommitLog(spark,
+        new File(tempDir.getAbsolutePath, "commits").getAbsolutePath)
+
+      offsetLog.purgeAfter(0)
+      commitLog.purgeAfter(-1)
+
+      intercept[IllegalStateException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: corrupted state schema file") {
+    withTempDir { tempDir =>
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      def rewriteStateSchemaFileToDummy(): Unit = {
+        // Refer to the StateSchemaCompatibilityChecker for the path of state schema file
+        val pathForSchema = Seq(
+          "state", "0", StateStore.PARTITION_ID_TO_CHECK_SCHEMA.toString,
+          "_metadata", "schema"
+        ).foldLeft(tempDir) { case (file, dirName) =>
+          new File(file, dirName)
+        }
+
+        assert(pathForSchema.exists())
+        assert(pathForSchema.delete())
+
+        val fileWriter = new FileWriter(pathForSchema)
+        fileWriter.write("lol dummy corrupted schema file")
+        fileWriter.close()
+
+        assert(pathForSchema.exists())
+      }
+
+      rewriteStateSchemaFileToDummy()
+
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: path is not specified") {
+    intercept[IllegalArgumentException] {
+      spark.read.format("statestore").load()
+    }
+  }
+
+  test("ERROR: operator ID specified to negative") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_OPERATOR_ID, -1)
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: batch ID specified to negative") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_BATCH_ID, -1)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: store name is empty") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_STORE_NAME, "")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: invalid value for joinSide option") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_JOIN_SIDE, "both")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: both options `joinSide` and `storeName` are specified") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_JOIN_SIDE, "right")
+          .option(StateDataSource.PARAM_STORE_NAME, "right-keyToNumValues")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: trying to read state data as stream") {
+    withTempDir { tempDir =>
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      intercept[SparkUnsupportedOperationException] {
+        spark.readStream.format("statestore").load(tempDir.getAbsolutePath)
+          .writeStream.format("noop").start()
+      }
+    }
+  }
+}
+
+/**
+ * Here we build a combination of test criteria for
+ * 1) number of shuffle partitions
+ * 2) state store provider
+ * 3) compression codec
+ * and run one of the tests to verify that the above configs work.
+ *
+ * We are building 3 x 2 x 4 = 24 different test criteria, and it's probably a waste of time
+ * and resources to run all combinations every time, hence we randomly pick 5 tests per run.
+ */
+class StateDataSourceSQLConfigSuite extends StateDataSourceTestBase {
+
+  private val TEST_SHUFFLE_PARTITIONS = Seq(1, 3, 5)
+  private val TEST_PROVIDERS = Seq(
+    classOf[HDFSBackedStateStoreProvider].getName,
+    classOf[RocksDBStateStoreProvider].getName
+  )
+  private val TEST_COMPRESSION_CODECS = CompressionCodec.ALL_COMPRESSION_CODECS
+
+  private val ALL_COMBINATIONS = {
+    val comb = for (
+      part <- TEST_SHUFFLE_PARTITIONS;
+      provider <- TEST_PROVIDERS;
+      codec <- TEST_COMPRESSION_CODECS
+    ) yield {
+      (part, provider, codec)
+    }
+    scala.util.Random.shuffle(comb)
+  }
+
+  ALL_COMBINATIONS.take(5).foreach { case (part, provider, codec) =>
+    val testName = s"Verify the read with config 
[part=$part][provider=$provider][codec=$codec]"
+    test(testName) {
+      withTempDir { tempDir =>
+        withSQLConf(
+          SQLConf.SHUFFLE_PARTITIONS.key -> part.toString,
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> provider,
+          SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> codec) {
+
+          runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+          verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+        }
+      }
+    }
+  }
+
+  test("Use different configs than session config") {
+    withTempDir { tempDir =>
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "zstd") {
+
+        runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+      }
+
+      // Set different values in the session config, to validate that the state data source
+      // refers to the config in the offset log.
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName,
+        SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "lz4") {
+
+        verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  private def verifyLargeDataStreamingAggregationQuery(checkpointLocation: String): Unit = {
+    val operatorId = 0
+    val batchId = 2
+
+    val stateReadDf = spark.read
+      .format("statestore")
+      .option(StateDataSource.PARAM_PATH, checkpointLocation)
+      // explicitly specifying batch ID and operator ID to test out the functionality
+      .option(StateDataSource.PARAM_BATCH_ID, batchId)
+      .option(StateDataSource.PARAM_OPERATOR_ID, operatorId)
+      .load()
+
+    val resultDf = stateReadDf
+      .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt",
+        "value.sum AS value_sum", "value.max AS value_max", "value.min AS 
value_min")
+
+    checkAnswer(
+      resultDf,
+      Seq(
+        Row(0, 5, 60, 30, 0), // 0, 10, 20, 30
+        Row(1, 5, 65, 31, 1), // 1, 11, 21, 31
+        Row(2, 5, 70, 32, 2), // 2, 12, 22, 32
+        Row(3, 4, 72, 33, 3), // 3, 13, 23, 33
+        Row(4, 4, 76, 34, 4), // 4, 14, 24, 34
+        Row(5, 4, 80, 35, 5), // 5, 15, 25, 35
+        Row(6, 4, 84, 36, 6), // 6, 16, 26, 36
+        Row(7, 4, 88, 37, 7), // 7, 17, 27, 37
+        Row(8, 4, 92, 38, 8), // 8, 18, 28, 38
+        Row(9, 4, 96, 39, 9) // 9, 19, 29, 39
+      )
+    )
+  }
+}
+
+class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[HDFSBackedStateStoreProvider].getName)
+  }
+}
+
+class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+    spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+      "false")
+  }
+}
+
+class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends StateDataSourceReadSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+    spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
+      "true")
+  }
+}
+
+abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {

Review Comment:
   Maybe just add a one-line class comment?
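
One possible shape for such a comment, as a sketch only (the wording here is a suggestion, not taken from the PR):

```scala
/** Base suite verifying read support of the State data source across state store providers. */
abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {
```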



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
