anishshri-db commented on code in PR #45932:
URL: https://github.com/apache/spark/pull/45932#discussion_r1563064461


##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala:
##########
@@ -171,203 +160,15 @@ case class MultipleValueStatesTTLProcessor(
   }
 }
 
-/**
- * Tests that ttl works as expected for Value State for
- * processing time and event time based ttl.
- */
-class TransformWithValueStateTTLSuite
-  extends StreamTest {
-  import testImplicits._
-
-  test("validate state is evicted at ttl expiry") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      withTempDir { dir =>
-        val inputStream = MemoryStream[InputEvent]
-        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-        val result = inputStream.toDS()
-          .groupByKey(x => x.key)
-          .transformWithState(
-            new ValueStateTTLProcessor(ttlConfig),
-            TimeMode.ProcessingTime(),
-            OutputMode.Append())
-
-        val clock = new StreamManualClock
-        testStream(result)(
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          AddData(inputStream, InputEvent("k1", "put", 1)),
-          // advance clock to trigger processing
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // get this state, and make sure we get unexpired value
-          AddData(inputStream, InputEvent("k1", "get", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // ensure ttl values were added correctly
-          AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-          StopStream,
-          StartStream(
-            Trigger.ProcessingTime("1 second"),
-            triggerClock = clock,
-            checkpointLocation = dir.getAbsolutePath),
-          // advance clock so that state expires
-          AdvanceManualClock(60 * 1000),
-          AddData(inputStream, InputEvent("k1", "get", -1, null)),
-          AdvanceManualClock(1 * 1000),
-          // validate expired value is not returned
-          CheckNewAnswer(),
-          // ensure this state does not exist any longer in State
-          AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-          AdvanceManualClock(1 * 1000),
-          CheckNewAnswer(),
-          Execute { q =>
-            // Filter for idle progress events and then verify the custom metrics
-            // for stateful operator
-            val progData = q.recentProgress.filter(prog => prog.stateOperators.size > 0)
-            assert(progData.filter(prog =>
-              prog.stateOperators(0).customMetrics.get("numValueStateWithTTLVars") > 0).size > 0)
-            assert(progData.filter(prog =>
-              prog.stateOperators(0).customMetrics
-                .get("numValuesRemovedDueToTTLExpiry") > 0).size > 0)
-          }
-        )
-      }
-    }
-  }
-
-  test("validate state update updates the expiration timestamp") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      val inputStream = MemoryStream[InputEvent]
-      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-      val result = inputStream.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(
-          new ValueStateTTLProcessor(ttlConfig),
-          TimeMode.ProcessingTime(),
-          OutputMode.Append())
+class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest {
 
-      val clock = new StreamManualClock
-      testStream(result)(
-        StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(),
-        // get this state, and make sure we get unexpired value
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // ensure ttl values were added correctly
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000)),
-        // advance clock and update expiration time
-        AdvanceManualClock(30 * 1000),
-        AddData(inputStream, InputEvent("k1", "put", 1)),
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        // advance clock to trigger processing
-        AdvanceManualClock(1 * 1000),
-        // validate value is not expired
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // validate ttl value is updated in the state
-        AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000)),
-        // validate ttl state has both ttl values present
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 61000),
-          OutputEvent("k1", -1, isTTLValue = true, 95000)
-        ),
-        // advance clock after older expiration value
-        AdvanceManualClock(30 * 1000),
-        // ensure unexpired value is still present in the state
-        AddData(inputStream, InputEvent("k1", "get", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", 1, isTTLValue = false, -1)),
-        // validate that the older expiration value is removed from ttl state
-        AddData(inputStream, InputEvent("k1", "get_values_in_ttl_state", -1)),
-        AdvanceManualClock(1 * 1000),
-        CheckNewAnswer(OutputEvent("k1", -1, isTTLValue = true, 95000))
-      )
-    }
+  import testImplicits._
+  override def getProcessor(ttlConfig: TTLConfig):

Review Comment:
   Same here 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to