anishshri-db commented on code in PR #48710:
URL: https://github.com/apache/spark/pull/48710#discussion_r1823635648


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala:
##########
@@ -997,4 +997,88 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  /**
+   * Note that we cannot use the golden files approach for transformWithState. The new schema
+   * format tracks the schema file path as an absolute path, which cannot be used with the
+   * getResource model used in other similar tests. Hence, we force snapshot creation for the
+   * given versions and ensure that the state data is loaded from the given start snapshot
+   * version.
+   */
+  testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
+    class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
+      @transient protected var _countState: ValueState[Long] = _
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
+      }
+
+      override def handleInputRows(
+          key: Int,
+          inputRows: Iterator[(Int, Long)],
+          timerValues: TimerValues): Iterator[(Int, Long)] = {
+        val count = _countState.getOption().getOrElse(0L)
+        var totalSum = 0L
+        inputRows.foreach { entry =>
+          totalSum += entry._2
+        }
+        _countState.update(count + totalSum)
+        Iterator((key, count + totalSum))
+      }
+    }
+
+    withTempDir { tmpDir =>
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+        val inputData = MemoryStream[(Int, Long)]
+        val query = inputData
+          .toDS()
+          .groupByKey(_._1)
+          .transformWithState(new AggregationStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L)),
+          ProcessAllAvailable(),
+          Execute { _ => Thread.sleep(2000) },

Review Comment:
   Removed some of these. We still need to keep one to give the maintenance thread a chance to upload a snapshot. Unfortunately, we can't use the golden files approach for this operator.
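   
   For readers following along, here is a minimal sketch (not part of this PR's diff) of how the state could then be read back starting from one of the forced snapshots. The checkpoint path, batch ids, and the stateVarName value are placeholders; the snapshotStartBatchId / snapshotPartitionId options reflect the feature under test:
   
       // Hypothetical read-side usage; checkpointDir and ids are placeholders.
       val stateDf = spark.read
         .format("statestore")
         .option("path", checkpointDir)           // the query's checkpoint location
         .option("stateVarName", "countState")    // transformWithState state variable
         .option("snapshotStartBatchId", 1)       // load from this snapshot version
         .option("snapshotPartitionId", 0)        // partition to read alongside it
         .load()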


