anishshri-db commented on code in PR #48710: URL: https://github.com/apache/spark/pull/48710#discussion_r1823635648
########## sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala:
##########
@@ -997,4 +997,88 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 }
+
+  /**
+   * Note that we cannot use the golden files approach for transformWithState. The new schema
+   * format keeps track of the schema file path as an absolute path, which cannot be used with
+   * the getResource model used in other similar tests. Hence, we force snapshot creation for
+   * the given versions and ensure that the state data is loaded from the given start snapshot
+   * version.
+   */
+  testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
+    class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
+      @transient protected var _countState: ValueState[Long] = _
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
+      }
+
+      override def handleInputRows(
+          key: Int,
+          inputRows: Iterator[(Int, Long)],
+          timerValues: TimerValues): Iterator[(Int, Long)] = {
+        val count = _countState.getOption().getOrElse(0L)
+        var totalSum = 0L
+        inputRows.foreach { entry =>
+          totalSum += entry._2
+        }
+        _countState.update(count + totalSum)
+        Iterator((key, count + totalSum))
+      }
+    }
+
+    withTempDir { tmpDir =>
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+        val inputData = MemoryStream[(Int, Long)]
+        val query = inputData
+          .toDS()
+          .groupByKey(_._1)
+          .transformWithState(new AggregationStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L)),
+          ProcessAllAvailable(),
+          Execute { _ => Thread.sleep(2000) },

Review Comment:
   Removed some of these. We still need to keep one to give the maintenance thread a chance to upload a snapshot. Unfortunately, we can't use the golden files approach for this operator.
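For context, a minimal sketch of the read path this test exercises, assuming the standard state data source reader options (`snapshotStartBatchId` paired with `snapshotPartitionId`, and `stateVarName` selecting the processor's state variable); the concrete batch and partition values below are illustrative assumptions, not the exact continuation of the test:

    // Sketch only: read the transformWithState state back through the state data
    // source, forcing the load to start from a previously uploaded snapshot.
    // Option names follow the state data source reader; the batch and partition
    // values are illustrative assumptions.
    val stateDf = spark.read
      .format("statestore")
      .option("path", tmpDir.getCanonicalPath)  // checkpoint dir used by the query
      .option("snapshotStartBatchId", 1L)       // snapshot version to start loading from
      .option("snapshotPartitionId", 0L)        // required alongside snapshotStartBatchId
      .option("stateVarName", "countState")     // state variable registered in init()
      .load()

With changelog checkpointing enabled, the provider loads the specified snapshot and replays the remaining changelog files on top of it. That is why the test keeps one sleep: with the maintenance interval set to 100 ms and min-deltas-for-snapshot set to 1, the maintenance thread needs a brief window to upload at least one snapshot before the read.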