sahnib commented on code in PR #45674:
URL: https://github.com/apache/spark/pull/45674#discussion_r1554114021


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala:
##########
@@ -312,187 +312,109 @@ class ValueStateSuite extends StateVariableSuiteBase {
     }
   }
 
-  Seq(TTLMode.ProcessingTimeTTL(), TTLMode.EventTimeTTL()).foreach { ttlMode =>
-    test(s"test Value state TTL for $ttlMode") {
-      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
-        val store = provider.getStore(0)
-        val timestampMs = 10
-        val handle = createHandleForTtlMode(ttlMode, store, timestampMs)
-
-        val testState: ValueStateImplWithTTL[String] = 
handle.getValueState[String]("testState",
-          Encoders.STRING).asInstanceOf[ValueStateImplWithTTL[String]]
-        ImplicitGroupingKeyTracker.setImplicitKey("test_key")
-        testState.update("v1")
-        assert(testState.get() === "v1")
-        assert(testState.getWithoutEnforcingTTL().get === "v1")
-
-        var ttlValue = testState.getTTLValue()
-        assert(ttlValue.isEmpty)
-        var ttlStateValueIterator = testState.getValuesInTTLState()
-        assert(ttlStateValueIterator.isEmpty)
-
-        testState.clear()
-        assert(!testState.exists())
-        assert(testState.get() === null)
-
-        val ttlExpirationMs = timestampMs + 60000
-
-        if (ttlMode == TTLMode.ProcessingTimeTTL()) {
-          testState.update("v1", Duration.ofMinutes(1))
-        } else {
-          testState.update("v1", ttlExpirationMs)
-        }
-        assert(testState.get() === "v1")
-        assert(testState.getWithoutEnforcingTTL().get === "v1")
-
-        ttlValue = testState.getTTLValue()
-        assert(ttlValue.isDefined)
-        assert(ttlValue.get === ttlExpirationMs)
-        ttlStateValueIterator = testState.getValuesInTTLState()
-        assert(ttlStateValueIterator.hasNext)
-        assert(ttlStateValueIterator.next() === ttlExpirationMs)
-        assert(ttlStateValueIterator.isEmpty)
-
-        // increment batchProcessingTime, or watermark and ensure expired 
value is not returned
-        val nextBatchHandle = createHandleForTtlMode(ttlMode, store, 
ttlExpirationMs)
-
-        val nextBatchTestState: ValueStateImplWithTTL[String] = nextBatchHandle
-          .getValueState[String]("testState", Encoders.STRING)
-          .asInstanceOf[ValueStateImplWithTTL[String]]
-        ImplicitGroupingKeyTracker.setImplicitKey("test_key")
-
-        // ensure get does not return the expired value
-        assert(!nextBatchTestState.exists())
-        assert(nextBatchTestState.get() === null)
-
-        // ttl value should still exist in state
-        ttlValue = nextBatchTestState.getTTLValue()
-        assert(ttlValue.isDefined)
-        assert(ttlValue.get === ttlExpirationMs)
-        ttlStateValueIterator = nextBatchTestState.getValuesInTTLState()
-        assert(ttlStateValueIterator.hasNext)
-        assert(ttlStateValueIterator.next() === ttlExpirationMs)
-        assert(ttlStateValueIterator.isEmpty)
-
-        // getWithoutTTL should still return the expired value
-        assert(nextBatchTestState.getWithoutEnforcingTTL().get === "v1")
-
-        nextBatchTestState.clear()
-        assert(!nextBatchTestState.exists())
-        assert(nextBatchTestState.get() === null)
-
-        nextBatchTestState.clear()
-        assert(!nextBatchTestState.exists())
-        assert(nextBatchTestState.get() === null)
-      }
-    }
-  }
 
-  test("test TTL duration throws error for event time") {
+  test(s"test Value state TTL") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
-      val eventTimeWatermarkMs = 10
+      val timestampMs = 10
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
-        TTLMode.EventTimeTTL(), TimeoutMode.NoTimeouts(),
-        eventTimeWatermarkMs = Some(eventTimeWatermarkMs))
+        TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(),
+        batchTimestampMs = Some(timestampMs))
 
+      val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
       val testState: ValueStateImplWithTTL[String] = 
handle.getValueState[String]("testState",
-        Encoders.STRING).asInstanceOf[ValueStateImplWithTTL[String]]
+        Encoders.STRING, ttlConfig).asInstanceOf[ValueStateImplWithTTL[String]]
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+      testState.update("v1")
+      assert(testState.get() === "v1")
+      assert(testState.getWithoutEnforcingTTL().get === "v1")
 
-      val ex = intercept[SparkUnsupportedOperationException] {
-        testState.update("v1", Duration.ofMinutes(1))
-      }
+      var ttlValue = testState.getTTLValue()
+      assert(ttlValue.isEmpty)
+      var ttlStateValueIterator = testState.getValuesInTTLState()
+      assert(ttlStateValueIterator.isEmpty)
 
-      checkError(
-        ex,
-        errorClass = 
"STATEFUL_PROCESSOR_CANNOT_USE_TTL_DURATION_IN_EVENT_TIME_TTL_MODE",
-        parameters = Map(
-          "operationType" -> "update",
-          "stateName" -> "testState"
-        ),
-        matchPVals = true
-      )
-    }
-  }
+      testState.clear()
+      assert(!testState.exists())
+      assert(testState.get() === null)
 
-  test("test negative TTL duration throws error") {
-    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
-      val store = provider.getStore(0)
-      val batchTimestampMs = 10
-      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+      val ttlExpirationMs = timestampMs + 60000
+
+      testState.update("v1")
+      assert(testState.get() === "v1")
+      assert(testState.getWithoutEnforcingTTL().get === "v1")
+
+      ttlValue = testState.getTTLValue()
+      assert(ttlValue.isDefined)
+      assert(ttlValue.get === ttlExpirationMs)
+      ttlStateValueIterator = testState.getValuesInTTLState()
+      assert(ttlStateValueIterator.hasNext)
+      assert(ttlStateValueIterator.next() === ttlExpirationMs)
+      assert(ttlStateValueIterator.isEmpty)
+
+      // increment batchProcessingTime, or watermark and ensure expired value 
is not returned
+      val nextBatchHandle = new StatefulProcessorHandleImpl(store, 
UUID.randomUUID(),
         Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]],
         TTLMode.ProcessingTimeTTL(), TimeoutMode.NoTimeouts(),
-        batchTimestampMs = Some(batchTimestampMs))
+        batchTimestampMs = Some(timestampMs))
 
-      val testState: ValueStateImplWithTTL[String] = 
handle.getValueState[String]("testState",
-        Encoders.STRING).asInstanceOf[ValueStateImplWithTTL[String]]
+      val nextBatchTestState: ValueStateImplWithTTL[String] = nextBatchHandle

Review Comment:
   Okay, on this part: we would not allow removing ttlConfig from a state 
variable, or adding it to a variable that did not previously have TTL. (This will 
be enforced as part of the state metadata — cc: @anishshri-db.) 
   
   The reason to disallow this is that the schema differs between 
ValueState and ValueStateWithTTL (the ttlExpiration column won't exist in the former), 
and mixing these two schemas in the StateStore for a single column family complicates things 
significantly. If a user wants to remove TTL, they can simply create a new 
state variable to store future state without enforcing TTL. 
   
   Let me know if this makes sense, @HeartSaVioR. 
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to