Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-23 Thread via GitHub


HeartSaVioR closed pull request #45991: [SPARK-47805][SS] Implementing TTL for MapState
URL: https://github.com/apache/spark/pull/45991


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-23 Thread via GitHub


HeartSaVioR commented on PR #45991:
URL: https://github.com/apache/spark/pull/45991#issuecomment-2071662156

   > [Run / Build modules: pyspark-sql, pyspark-resource, pyspark-testing](https://github.com/ericm-db/spark/actions/runs/8794474954/job/24137701938#logs)
   > failed 2 hours ago in 26m 30s
   
   The failure is mostly about connection refused errors, so it doesn't seem to be related to this change.





Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-23 Thread via GitHub


HeartSaVioR commented on PR #45991:
URL: https://github.com/apache/spark/pull/45991#issuecomment-2071662357

   Thanks! Merging to master.





Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


HeartSaVioR commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575597023


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)

Review Comment:
   Ah OK, you are passing this to the base tests.






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575592659


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+    with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _mapState = getHandle
+      .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+    key: String,
+    inputRows: Iterator[InputEvent],
+    timerValues: TimerValues,
+    expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+
+    for (row <- inputRows) {
+      val resultIter = processRow(row, _mapState)
+      resultIter.foreach { r =>
+        results = r :: results
+      }
+    }
+
+    results.iterator
+  }
+
+  def processRow(
+      row: InputEvent,
+      mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+    var results = List[OutputEvent]()
+    val key = row.key
+    val userKey = "key"
+    if (row.action == "get") {
+      if (mapState.containsKey(userKey)) {
+        results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = false, -1) :: results
+      }
+    } else if (row.action == "get_without_enforcing_ttl") {
+      val currState = mapState.getWithoutEnforcingTTL(userKey)
+      if (currState.isDefined) {
+        results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+      }
+    } else if (row.action == "get_ttl_value_from_state") {
+      val ttlValue = mapState.getTTLValue(userKey)
+      if (ttlValue.isDefined) {
+        val value = ttlValue.get._1
+        val ttlExpiration = ttlValue.get._2
+        results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: results
+      }
+    } else if (row.action == "put") {
+      mapState.updateValue(userKey, row.value)
+    } else if (row.action == "get_values_in_ttl_state") {
+      val ttlValues = mapState.getKeyValuesInTTLState()
+      ttlValues.foreach { v =>
+        results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v._2) :: results
+      }
+    }
+
+    results.iterator
+  }
+}
+
+case class MapInputEvent(
+    key: String,
+    userKey: String,
+    action: String,
+    value: Int)
+
+case class MapOutputEvent(
+    key: String,
+    userKey: String,
+    value: Int,
+    isTTLValue: Boolean,
+    ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+    with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _mapState = getHandle
+      .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+      .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[MapInputEvent],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+    var results = List[MapOutputEvent]()
+
+    for (row <- inputRows) {
+      val resultIter = processRow(row, _mapState)
+      resultIter.foreach { r =>
+        results = r :: results
+      }
+    }
+
+    results.iterator
+  }
+
+  def processRow(
+      row: MapInputEvent,
+      mapState: MapStateImplWithTTL[String, Int]):

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575586347


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)

Review Comment:
   We need this processor since it has different input/output types than the other processor in this file.






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


HeartSaVioR commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575569366


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)

Review Comment:
   Maybe it is redundant to have this processor separately?



##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+    with Logging {

Review Comment:
   nit: we don't use this, right?



##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+    with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _mapState = getHandle
+      .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


HeartSaVioR commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575549218


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: Encoder[K],
+    valEncoder: Encoder[V],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  createColumnFamily()
+
+  private def createColumnFamily(): Unit = {
+    store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+    iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    val retRow = store.get(encodedCompositeKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+
+      if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+        resState
+      } else {
+        null.asInstanceOf[V]
+      }
+    } else {
+      null.asInstanceOf[V]
+    }
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    StateStoreErrors.requireNonNullStateValue(value, stateName)
+
+    val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    store.put(encodedCompositeKey, encodedValue, stateName)
+
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+    upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated 
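
As a reading aid for the hunk above: `getValue` returns the stored value only when it has not expired relative to `batchTimestampMs`, and `containsKey` is defined in terms of `getValue`, so an expired entry looks absent to callers. Below is a minimal, self-contained Scala sketch of that read-time TTL check. It is not Spark's implementation: `TtlMapSketch`, its in-memory map, and the `Option` return type are hypothetical, and treating an entry as expired once the batch timestamp reaches its expiration time is an assumption.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of TTL-on-read semantics (not Spark code).
final class TtlMapSketch[K, V](ttlMs: Long) {
  // Each value is stored together with its absolute expiration timestamp.
  private val entries = mutable.Map.empty[K, (V, Long)]

  // Writing a value (re)sets its expiration to batch timestamp + TTL.
  def updateValue(key: K, value: V, batchTimestampMs: Long): Unit =
    entries.update(key, (value, batchTimestampMs + ttlMs))

  // Reads hide entries whose expiration is at or before the batch timestamp
  // (assumed boundary behaviour).
  def getValue(key: K, batchTimestampMs: Long): Option[V] =
    entries.get(key).collect {
      case (value, expirationMs) if expirationMs > batchTimestampMs => value
    }

  // Mirrors the idea of containsKey being derived from getValue.
  def containsKey(key: K, batchTimestampMs: Long): Boolean =
    getValue(key, batchTimestampMs).isDefined

  // Test-style accessor that ignores expiration, like the one the test
  // processor in this PR calls on the real state class.
  def getWithoutEnforcingTTL(key: K): Option[V] =
    entries.get(key).map(_._1)
}

object TtlMapSketchDemo {
  def main(args: Array[String]): Unit = {
    val state = new TtlMapSketch[String, Int](ttlMs = 60000L)
    state.updateValue("key", 1, batchTimestampMs = 0L)
    println(state.getValue("key", batchTimestampMs = 1000L))
    println(state.containsKey("key", batchTimestampMs = 61000L))
    println(state.getWithoutEnforcingTTL("key"))
  }
}
```

In this sketch the expired entry is only hidden from reads and never physically removed; cleanup of expired entries is outside what the sketch models.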

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575072868


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: Encoder[K],
+    valEncoder: Encoder[V],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  createColumnFamily()

Review Comment:
   Renamed back to initialize().






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575071946


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: Encoder[K],
+    valEncoder: Encoder[V],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  createColumnFamily()
+
+  private def createColumnFamily(): Unit = {
+    store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+    iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    val retRow = store.get(encodedCompositeKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+
+      if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+        resState
+      } else {
+        null.asInstanceOf[V]
+      }
+    } else {
+      null.asInstanceOf[V]
+    }
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    StateStoreErrors.requireNonNullStateValue(value, stateName)
+
+    val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    store.put(encodedCompositeKey, encodedValue, stateName)
+
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()

Review Comment:
   There's an existing method for this; I will call that method instead.



##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575025178


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: Encoder[K],
+    valEncoder: Encoder[V],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)

Review Comment:
   Makes sense - I can open a follow-up PR for this.






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575020539


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+    store: StateStore,
+    stateName: String,
+    keyExprEnc: ExpressionEncoder[Any],
+    userKeyEnc: Encoder[K],
+    valEncoder: Encoder[V],
+    ttlConfig: TTLConfig,
+    batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  createColumnFamily()
+
+  private def createColumnFamily(): Unit = {
+    store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+    iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    val retRow = store.get(encodedCompositeKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+
+      if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+        resState
+      } else {
+        null.asInstanceOf[V]
+      }
+    } else {
+      null.asInstanceOf[V]
+    }
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    StateStoreErrors.requireNonNullStateValue(value, stateName)
+
+    val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    store.put(encodedCompositeKey, encodedValue, stateName)
+
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+    upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575011750


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with 
state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  createColumnFamily()
+
+  private def createColumnFamily(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {

Review Comment:
   Good idea, moved.
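
For readers following the quoted `getValue` hunk, the read-time TTL check can be sketched in isolation. This is a minimal illustration and not the code in the PR: `ValueWithTtl`, `TtlReadSketch`, and the "expired once the batch timestamp reaches the expiration time" rule are assumptions made for the example, and it returns an `Option` where the quoted hunk returns `null.asInstanceOf[V]`.

```scala
// Hypothetical wrapper: the stored value plus its absolute expiration time.
final case class ValueWithTtl[V](value: V, expirationMs: Long)

object TtlReadSketch {
  /** Absolute expiration for a value written in the batch at `batchTimestampMs`. */
  def expirationFor(ttlDurationMs: Long, batchTimestampMs: Long): Long =
    batchTimestampMs + ttlDurationMs

  /** Assumed rule: a value is expired once the batch timestamp reaches its expiration. */
  def isExpired(expirationMs: Long, batchTimestampMs: Long): Boolean =
    batchTimestampMs >= expirationMs

  /** Returns the value only if a row exists and has not expired. */
  def read[V](row: Option[ValueWithTtl[V]], batchTimestampMs: Long): Option[V] =
    row.filterNot(r => isExpired(r.expirationMs, batchTimestampMs)).map(_.value)
}
```

The point of the sketch is that an expired row is hidden from the caller at read time even if it has not yet been physically removed from the store.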






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1575010427


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -99,6 +98,21 @@ abstract class SingleKeyTTLStateImpl(
   store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
     RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)
 
+  /**
+   * This function will be called when clear() on State Variables
+   * with ttl enabled is called. This function should clear any
+   * associated ttlState, since we are clearing the user state.
+   */
+  def clearTTLState(): Unit = {

Review Comment:
   You're right - I had missed this initially.






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


HeartSaVioR commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1574155684


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with 
state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)

Review Comment:
   (This could be part of a MINOR follow-up PR.)
   
   This refactoring (COMPOSITE_KEY_ROW_SCHEMA) does not seem to have been applied 
to MapStateImpl. Please revisit MapStateImpl and apply the refactoring we have 
missed so far.






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-22 Thread via GitHub


HeartSaVioR commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1574155684


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with 
state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @param batchTimestampMs - current batch processing timestamp.
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)

Review Comment:
   This refactoring (COMPOSITE_KEY_ROW_SCHEMA) does not seem to have been applied 
to MapStateImpl.
   
   Please revisit MapStateImpl and apply the refactoring we have missed so far. 
There is no need to include the change in this PR - a MINOR follow-up PR would be fine.



##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -99,6 +98,21 @@ abstract class SingleKeyTTLStateImpl(
   store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
     RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)
 
+  /**
+   * This function will be called when clear() on State Variables
+   * with ttl enabled is called. This function should clear any
+   * associated ttlState, since we are clearing the user state.
+   */
+  def clearTTLState(): Unit = {

Review Comment:
   If this is intended to remove all TTL state because we are removing the user 
state, why does it remove only expired entries? It sounds like we should remove 
all TTL entries - please correct me if I'm missing something.
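
The distinction raised in this comment can be illustrated with a small stand-alone sketch. It is not the `TTLState.scala` implementation: `TtlIndexSketch` is a hypothetical in-memory stand-in for the TTL column family, keyed by `(expirationMs, groupingKey)`, and the `<=` expiry rule is an assumption made for the example.

```scala
import scala.collection.mutable

// Hypothetical stand-in for the TTL index: ordered (expirationMs, groupingKey) entries.
final class TtlIndexSketch {
  private val entries = mutable.TreeSet.empty[(Long, String)]

  def upsert(expirationMs: Long, groupingKey: String): Unit =
    entries += ((expirationMs, groupingKey))

  /** Removes only the entries for `groupingKey` that expired at or before `nowMs`. */
  def clearExpired(groupingKey: String, nowMs: Long): Unit = {
    val stale = entries.filter { case (exp, k) => k == groupingKey && exp <= nowMs }
    entries --= stale
  }

  /** Removes every entry for `groupingKey`, expired or not. */
  def clearAll(groupingKey: String): Unit = {
    val owned = entries.filter { case (_, k) => k == groupingKey }
    entries --= owned
  }

  /** Expiration times still recorded for `groupingKey`, in ascending order. */
  def remaining(groupingKey: String): Seq[Long] =
    entries.toSeq.collect { case (exp, k) if k == groupingKey => exp }
}

object TtlIndexSketchDemo {
  def main(args: Array[String]): Unit = {
    val index = new TtlIndexSketch
    index.upsert(100L, "k1")
    index.upsert(500L, "k1")
    index.clearExpired("k1", nowMs = 200L)
    println(index.remaining("k1")) // the entry expiring at 500 is still present
    index.clearAll("k1")
    println(index.remaining("k1")) // nothing left for k1
  }
}
```

If `clear()` on the user state only ran the equivalent of `clearExpired`, the unexpired TTL entry would be left behind pointing at user state that no longer exists, which is the concern the comment describes.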



##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-19 Thread via GitHub


ericm-db commented on PR #45991:
URL: https://github.com/apache/spark/pull/45991#issuecomment-2067115942

   @HeartSaVioR PTAL, thanks!





Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1571187342


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = 
false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
+  }
+}
+
+results.iterator
+  }
+}
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+  key: String,
+  inputRows: Iterator[MapInputEvent],
+  timerValues: TimerValues,
+  expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+var results = List[MapOutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: MapInputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1571185285


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = 
false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
+  }
+}
+
+results.iterator
+  }
+}
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+  key: String,
+  inputRows: Iterator[MapInputEvent],
+  timerValues: TimerValues,
+  expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+var results = List[MapOutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: MapInputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1571071937


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1571061801


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+
+  
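
The TTL handling in the hunk quoted above (stamp an expiration on write, check it on read, skip expired rows on scan) can be modelled with a small self-contained sketch. This is only a toy in-memory model, not the Spark implementation: `ToyMapStateWithTTL`, `putWithExpiration` and `ToyMapStateDemo` are hypothetical names, reads return `Option` instead of `null`, and the expiry comparison (`batchTimestampMs >= expirationMs`) is an assumption about how `isExpired` behaves.

```scala
import scala.collection.mutable

// Toy in-memory model of a map state with TTL; not the Spark implementation.
final class ToyMapStateWithTTL[K, V](ttlMs: Long, batchTimestampMs: Long) {
  // Each entry stores its value together with an absolute expiration time.
  private val entries = mutable.LinkedHashMap.empty[K, (V, Long)]

  // Every write in this batch is stamped with the same expiration time.
  private val ttlExpirationMs: Long = batchTimestampMs + ttlMs

  // Assumption: an entry is expired once the batch time reaches its expiration.
  private def isExpired(expirationMs: Long): Boolean =
    batchTimestampMs >= expirationMs

  def updateValue(key: K, value: V): Unit =
    entries.update(key, (value, ttlExpirationMs))

  // Hypothetical helper: seed an entry with an explicit expiration time.
  def putWithExpiration(key: K, value: V, expirationMs: Long): Unit =
    entries.update(key, (value, expirationMs))

  // Returns None for missing or expired entries.
  def getValue(key: K): Option[V] =
    entries.get(key).collect { case (v, exp) if !isExpired(exp) => v }

  def containsKey(key: K): Boolean = getValue(key).isDefined

  // Skips expired entries wherever they occur in the scan.
  def iterator: Iterator[(K, V)] =
    entries.iterator.collect { case (k, (v, exp)) if !isExpired(exp) => (k, v) }

  def exists: Boolean = iterator.nonEmpty
}

object ToyMapStateDemo {
  // Entries visible at batch time 100000 ms with a 60 s TTL.
  def visibleEntries(): List[(String, Int)] = {
    val state = new ToyMapStateWithTTL[String, Int](ttlMs = 60000L, batchTimestampMs = 100000L)
    state.putWithExpiration("old", 1, expirationMs = 90000L) // already expired
    state.updateValue("fresh", 2)                            // expires at 160000
    state.iterator.toList
  }
}
```

In this sketch the scan filters every row rather than only dropping a leading run of expired rows, so an expired entry in the middle of the map is skipped as well.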

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1571060613


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results
+  }
+}
+
+results.iterator
+  }
+}
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+  key: String,
+  inputRows: Iterator[MapInputEvent],
+  timerValues: TimerValues,
+  expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+var results = List[MapOutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: MapInputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570093814


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results
+  }
+}
+
+results.iterator
+  }
+}
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+  key: String,
+  inputRows: Iterator[MapInputEvent],
+  timerValues: TimerValues,
+  expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+var results = List[MapOutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: MapInputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570092192


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: results
+  }
+}
+
+results.iterator
+  }
+}
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+
+  override def handleInputRows(
+  key: String,
+  inputRows: Iterator[MapInputEvent],
+  timerValues: TimerValues,
+  expiredTimerInfo: ExpiredTimerInfo): Iterator[MapOutputEvent] = {
+var results = List[MapOutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: MapInputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570090219


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -99,6 +98,16 @@ abstract class SingleKeyTTLStateImpl(
   store.createColFamilyIfAbsent(ttlColumnFamilyName, TTL_KEY_ROW_SCHEMA, TTL_VALUE_ROW_SCHEMA,
 RangeKeyScanStateEncoderSpec(TTL_KEY_ROW_SCHEMA, Seq(0)), isInternal = true)
 
+  def clearTTLState(): Unit = {

Review Comment:
   nit: let's also add a small function comment here?
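
   For illustration only, a function comment of the kind requested might look like the sketch below. The body of `clearTTLState` is not shown in the quoted hunk, so `ToyTtlIndex`, its `upsert` and `size` members, and the described intent are all hypothetical stand-ins rather than the code in `TTLState.scala`.

```scala
import scala.collection.mutable

/** Toy stand-in for a TTL index; hypothetical, not Spark's TTLState. */
final class ToyTtlIndex {
  // Maps an expiration timestamp (ms) to the state keys expiring at that time.
  private val index = mutable.TreeMap.empty[Long, List[String]]

  def upsert(expirationMs: Long, key: String): Unit =
    index.update(expirationMs, key :: index.getOrElse(expirationMs, Nil))

  // Number of distinct expiration timestamps currently tracked.
  def size: Int = index.size

  /**
   * Removes every entry from the TTL index.
   *
   * Assumed intent: called when the state variable itself is cleared, so that
   * no stale expiration records outlive the values they refer to.
   */
  def clearTTLState(): Unit = index.clear()
}
```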






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570085488


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with state
+ * variables (with ttl expiration support) used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, 
serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570083230


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, 
serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570080402


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,277 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with 
state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)
+ * @tparam K - type of key for map state variable
+ * @tparam V - type of value for map state variable
+ * @return - instance of MapState of type [K,V] that can be used to store 
state persistently
+ */
+class MapStateImplWithTTL[K, V](
+store: StateStore,
+stateName: String,
+keyExprEnc: ExpressionEncoder[Any],
+userKeyEnc: Encoder[K],
+valEncoder: Encoder[V],
+ttlConfig: TTLConfig,
+batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)

Review Comment:
   nit: maybe add a new line here?




Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-18 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1570079481


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,277 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Class that provides a concrete implementation for map state associated with 
state
+ * variables (with ttl expiration support) used in the streaming 
transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing 
state
+ * @param stateName  - name of the state variable
+ * @param keyExprEnc - Spark SQL encoder for key
+ * @param userKeyEnc  - Spark SQL encoder for the map key
+ * @param valEncoder - SQL encoder for state variable
+ * @param ttlConfig  - the ttl configuration (time to live duration etc.)

Review Comment:
   `batchTimestampMs` seems to be missing from the list of params here?
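
   For context, a minimal sketch of what the parameter feeds into and how its doc line might read. `TtlExpirationSketch` and the `@param` wording are illustrative assumptions, not the PR's code. The arithmetic is assumed too: expiration is the batch timestamp plus the TTL duration, and a value counts as expired once the batch timestamp reaches its expiration.

```scala
import java.time.Duration

/**
 * Hypothetical stand-in for a TTL-aware state class (not Spark's API).
 * @param ttlDuration      - time-to-live applied to each written value
 * @param batchTimestampMs - processing timestamp of the current batch, in ms (assumed wording)
 */
final class TtlExpirationSketch(ttlDuration: Duration, batchTimestampMs: Long) {
  // Assumed rule: a value written in this batch expires ttlDuration later.
  val ttlExpirationMs: Long = batchTimestampMs + ttlDuration.toMillis

  // Assumed rule: a stored value is expired once the batch time reaches its expiration.
  def isExpired(expirationMs: Long): Boolean = batchTimestampMs >= expirationMs
}
```

   Since every read compares a stored expiration against `batchTimestampMs`, documenting it alongside `ttlConfig` keeps the parameter list complete.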






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569241785


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,310 @@
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = 
false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
+  }
+}
+
+results.iterator
+  }
+}
+
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,

Review Comment:
   Hm, the parameters are 4 spaces in from the function's indent.
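
   As an illustration of that convention, here is a small sketch. It assumes the Scala style commonly followed in Spark: parameters of a multi-line declaration are indented 4 spaces from `def`, and the body is indented 2 spaces. `IndentSketch` and `describe` are made-up names.

```scala
object IndentSketch {
  // Parameters sit 4 spaces in from `def`; the body sits 2 spaces in.
  def describe(
      outputMode: String,
      timeMode: String): String = {
    s"outputMode=$outputMode, timeMode=$timeMode"
  }
}
```

   The extra indent on the parameters keeps them visually distinct from the first line of the body.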






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569223183


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,310 @@
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = 
false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
+  }
+}
+
+results.iterator
+  }
+}
+
+

Review Comment:
   nit: extra newline?



##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,310 @@
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import 

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569224035


##
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala:
##
@@ -0,0 +1,310 @@
+
+package org.apache.spark.sql.streaming
+
+import java.time.Duration
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.execution.streaming.{MapStateImplWithTTL, 
MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.util.StreamManualClock
+
+class MapStateSingleKeyTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, InputEvent, OutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,
+  timeMode: TimeMode): Unit = {
+_mapState = getHandle
+  .getMapState("mapState", Encoders.STRING, Encoders.scalaInt, ttlConfig)
+  .asInstanceOf[MapStateImplWithTTL[String, Int]]
+  }
+  override def handleInputRows(
+key: String,
+inputRows: Iterator[InputEvent],
+timerValues: TimerValues,
+expiredTimerInfo: ExpiredTimerInfo): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+
+for (row <- inputRows) {
+  val resultIter = processRow(row, _mapState)
+  resultIter.foreach { r =>
+results = r :: results
+  }
+}
+
+results.iterator
+  }
+
+  def processRow(
+  row: InputEvent,
+  mapState: MapStateImplWithTTL[String, Int]): Iterator[OutputEvent] = {
+var results = List[OutputEvent]()
+val key = row.key
+val userKey = "key"
+if (row.action == "get") {
+  if (mapState.containsKey(userKey)) {
+results = OutputEvent(key, mapState.getValue(userKey), isTTLValue = 
false, -1) :: results
+  }
+} else if (row.action == "get_without_enforcing_ttl") {
+  val currState = mapState.getWithoutEnforcingTTL(userKey)
+  if (currState.isDefined) {
+results = OutputEvent(key, currState.get, isTTLValue = false, -1) :: 
results
+  }
+} else if (row.action == "get_ttl_value_from_state") {
+  val ttlValue = mapState.getTTLValue(userKey)
+  if (ttlValue.isDefined) {
+val value = ttlValue.get._1
+val ttlExpiration = ttlValue.get._2
+results = OutputEvent(key, value, isTTLValue = true, ttlExpiration) :: 
results
+  }
+} else if (row.action == "put") {
+  mapState.updateValue(userKey, row.value)
+} else if (row.action == "get_values_in_ttl_state") {
+  val ttlValues = mapState.getValuesInTTLState()
+  ttlValues.foreach { v =>
+results = OutputEvent(key, -1, isTTLValue = true, ttlValue = v) :: 
results
+  }
+}
+
+results.iterator
+  }
+}
+
+
+case class MapInputEvent(
+key: String,
+userKey: String,
+action: String,
+value: Int)
+
+case class MapOutputEvent(
+key: String,
+userKey: String,
+value: Int,
+isTTLValue: Boolean,
+ttlValue: Long)
+
+
+class MapStateTTLProcessor(ttlConfig: TTLConfig)
+  extends StatefulProcessor[String, MapInputEvent, MapOutputEvent]
+with Logging {
+
+  @transient private var _mapState: MapStateImplWithTTL[String, Int] = _
+
+  override def init(
+  outputMode: OutputMode,

Review Comment:
   nit: not sure the indent is correct for this function?






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569213581


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala:
##
@@ -137,6 +137,7 @@ class ListStateImplWithTTL[S](
   /** Remove this state. */
   override def clear(): Unit = {
 store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+clearTTLState()

Review Comment:
   Yes, it was. @sahnib realized that when we clear state, we can also clear
the associated TTL state.
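
   A toy sketch of the idea, under stated assumptions: state lives in a primary map plus a secondary TTL index, and `clear()` removes entries from both so the index does not keep pointing at deleted state. `ToyListStateWithTTL`, `ttlIndex`, and `ttlEntryCount` are hypothetical names; only `clear()` and `clearTTLState()` mirror the method names in the diff, and none of this uses Spark's StateStore API.

```scala
import scala.collection.mutable

// Toy model with made-up names; it does not use Spark's StateStore API.
final class ToyListStateWithTTL[S](groupingKey: String) {
  // Primary state: grouping key -> values, each tagged with its expiration time.
  private val primary = mutable.Map.empty[String, List[(S, Long)]]
  // Secondary TTL index: (expirationMs, groupingKey), consulted when evicting expired state.
  private val ttlIndex = mutable.Set.empty[(Long, String)]

  def put(values: List[S], expirationMs: Long): Unit = {
    primary(groupingKey) = values.map(v => (v, expirationMs))
    ttlIndex += ((expirationMs, groupingKey))
  }

  def exists(): Boolean = primary.contains(groupingKey)

  def ttlEntryCount: Int = ttlIndex.count(_._2 == groupingKey)

  /** Remove this state, including its TTL index entries. */
  def clear(): Unit = {
    primary.remove(groupingKey)
    clearTTLState()
  }

  // Drop every index entry that belongs to this grouping key.
  private def clearTTLState(): Unit = {
    val stale = ttlIndex.filter(_._2 == groupingKey).toList
    stale.foreach(entry => ttlIndex.remove(entry))
  }
}
```

   Without the second step, the index would retain entries for state that no longer exists, and a later expiration pass would have to skip over them.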






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569214252


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA,
 VALUE_ROW_SCHEMA_WITH_TTL}
+import 
org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, 
StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, 
batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, 
stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, 
batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+
+  

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569209877


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -59,23 +75,6 @@ trait TTLState {
* @return number of values cleaned up.
*/
   def clearExpiredState(): Long
-
-  /**
-   * Clears the user state associated with this grouping key
-   * if it has expired. This function is called by Spark to perform
-   * cleanup at the end of transformWithState processing.
-   *
-   * Spark uses a secondary index to determine if the user state for
-   * this grouping key has expired. However, its possible that the user
-   * has updated the TTL and secondary index is out of date. Implementations
-   * must validate that the user State has actually expired before cleanup based
-   * on their own State data.
-   *
-   * @param groupingKey grouping key for which cleanup should be performed.
-   *
-   * @return how many state objects were cleaned up.
-   */
-  def clearIfExpired(groupingKey: Array[Byte]): Long

Review Comment:
   We removed it from `TTLState` because it only applies to `SingleKeyTTLStateImpl`.
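
   To illustrate the reply above, here is a minimal sketch of that split, using simplified in-memory stand-ins rather than the actual Spark classes. The names `TTLStateSketch`, `SingleKeyTTLSketch` and `InMemoryValueTTL`, the string grouping keys, and the expiry rule (`expirationMs <= batchTimestampMs`) are all assumptions made for illustration; only `clearExpiredState` and `clearIfExpired` come from the diff.

```scala
import scala.collection.mutable

// Hypothetical stand-in for the base trait: only the bulk cleanup stays here.
trait TTLStateSketch {
  /** Removes every expired entry and returns how many were cleaned up. */
  def clearExpiredState(): Long
}

// Hypothetical stand-in for the single-key variant: one value per grouping
// key, so a per-grouping-key cleanup hook is meaningful at this level only.
abstract class SingleKeyTTLSketch extends TTLStateSketch {
  /** Grouping keys that the secondary index lists as possibly expired. */
  protected def candidateKeys: List[String]

  /** Clears the state for one grouping key if it has actually expired. */
  def clearIfExpired(groupingKey: String): Long

  override def clearExpiredState(): Long =
    candidateKeys.map(clearIfExpired).sum
}

// Toy implementation backed by a map from grouping key to expiration time.
final class InMemoryValueTTL(batchTimestampMs: Long) extends SingleKeyTTLSketch {
  private val expirationByKey = mutable.Map.empty[String, Long]

  def put(groupingKey: String, expirationMs: Long): Unit =
    expirationByKey(groupingKey) = expirationMs

  def size: Int = expirationByKey.size

  // Copy the keys into a List so that removals do not disturb the traversal.
  override protected def candidateKeys: List[String] = expirationByKey.keys.toList

  override def clearIfExpired(groupingKey: String): Long =
    expirationByKey.get(groupingKey) match {
      case Some(expirationMs) if expirationMs <= batchTimestampMs =>
        expirationByKey.remove(groupingKey)
        1L
      case _ => 0L
    }
}
```

   A composite-key state (grouping key plus user key) would extend `TTLStateSketch` directly and expire individual user keys, which is why a method taking only the grouping key does not fit the shared trait.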






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569194799


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -59,23 +75,6 @@ trait TTLState {
* @return number of values cleaned up.
*/
   def clearExpiredState(): Long
-
-  /**
-   * Clears the user state associated with this grouping key
-   * if it has expired. This function is called by Spark to perform
-   * cleanup at the end of transformWithState processing.
-   *
-   * Spark uses a secondary index to determine if the user state for
-   * this grouping key has expired. However, its possible that the user
-   * has updated the TTL and secondary index is out of date. Implementations
-   * must validate that the user State has actually expired before cleanup based
-   * on their own State data.
-   *
-   * @param groupingKey grouping key for which cleanup should be performed.
-   *
-   * @return how many state objects were cleaned up.
-   */
-  def clearIfExpired(groupingKey: Array[Byte]): Long

Review Comment:
   Hmm - why do we remove this from the base trait?






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569195018


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala:
##
@@ -163,6 +172,115 @@ abstract class SingleKeyTTLStateImpl(
   }
 }
   }
+
+  /**
+   * Clears the user state associated with this grouping key
+   * if it has expired. This function is called by Spark to perform
+   * cleanup at the end of transformWithState processing.
+   *
+   * Spark uses a secondary index to determine if the user state for
+   * this grouping key has expired. However, its possible that the user
+   * has updated the TTL and secondary index is out of date. Implementations
+   * must validate that the user State has actually expired before cleanup based
+   * on their own State data.
+   *
+   * @param groupingKey grouping key for which cleanup should be performed.
+   *
+   * @return true if the state was cleared, false otherwise.
+   */
+  def clearIfExpired(groupingKey: Array[Byte]): Long
+}
+
+/**
+ * Manages the ttl information for user state keyed with a single key (grouping key).
+ */
+abstract class CompositeKeyTTLStateImpl(
+  stateName: String,

Review Comment:
   nit: indent 4 spaces






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569193194


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala:
##
@@ -192,6 +195,28 @@ class CompositeKeyStateEncoder[GK, K, V](
 compositeKeyRow
   }
 
+  def decodeUserKeyFromTTLRow(row: CompositeKeyTTLRow): K = {
+val bytes = row.userKey
+reusedKeyRow.pointTo(bytes, bytes.length)
+val userKey = userKeyRowToObjDeserializer.apply(reusedKeyRow)
+userKey
+  }
+
+  /**
+   * Grouping key and user key are encoded as a row of `schemaForCompositeKeyRow` schema.
+   * Grouping key will be encoded in `RocksDBStateEncoder` as the prefix column.
+   */
+  def encodeCompositeKey(
+groupingKeyByteArr: Array[Byte],

Review Comment:
   nit: indentation needs to be 4 spaces?






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569192544


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569182582


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+

Review Comment:
   nit: extra newline



##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,

Review Comment:
   The indent seems off; shouldn't args be indented 4 spaces?
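
   For reference, a small sketch of the layout this nit asks for: constructor parameters indented 4 spaces, with `extends` on its own line indented 2 spaces. `StoreSketch`, `TtlBaseSketch` and `MapStateWithTtlSketch` are hypothetical stand-ins so the snippet compiles on its own; they are not the classes in this PR.

```scala
// Hypothetical stand-in types, only here so the snippet is self-contained.
trait StoreSketch
abstract class TtlBaseSketch(stateName: String, store: StoreSketch, batchTimestampMs: Long)

// Parameters: 4-space indent, one per line.
// `extends` clause: next line, 2-space indent.
class MapStateWithTtlSketch[K, V](
    store: StoreSketch,
    stateName: String,
    batchTimestampMs: Long)
  extends TtlBaseSketch(stateName, store, batchTimestampMs) {

  def name: String = stateName
}
```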



##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](

Review Comment:
   Let's add a class-level comment?




Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569182254


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala:
##
@@ -137,6 +137,7 @@ class ListStateImplWithTTL[S](
   /** Remove this state. */
   override def clear(): Unit = {
 store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+clearTTLState()

Review Comment:
   Was this a miss in the `ListState` PR?
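
   A minimal in-memory sketch of why the added `clearTTLState()` call matters: if `clear()` removes only the primary state, the secondary TTL index keeps entries that point at state which no longer exists. `ListWithTtlSketch`, its fields and its string keys are hypothetical stand-ins, and the assumption that `clearTTLState()` drops the index entries for the current grouping key is inferred from its name, not confirmed by the diff shown here.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a list state with a secondary TTL index.
final class ListWithTtlSketch {
  // groupingKey -> list of (value, expirationMs)
  private val primary = mutable.Map.empty[String, List[(String, Long)]]
  // secondary index of (expirationMs, groupingKey), used to find expired state
  private val ttlIndex = mutable.Set.empty[(Long, String)]

  def append(groupingKey: String, value: String, expirationMs: Long): Unit = {
    primary(groupingKey) = primary.getOrElse(groupingKey, Nil) :+ (value -> expirationMs)
    ttlIndex += (expirationMs -> groupingKey)
  }

  /** Removes the state and also the index entries that refer to it. */
  def clear(groupingKey: String): Unit = {
    primary.remove(groupingKey)
    // Analogue of the one-line fix: without this, stale index entries remain.
    ttlIndex --= ttlIndex.filter(_._2 == groupingKey).toList
  }

  def exists(groupingKey: String): Boolean = primary.contains(groupingKey)
  def ttlIndexSize: Int = ttlIndex.size
}
```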






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-17 Thread via GitHub


anishshri-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1569180634


##
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala:
##
@@ -108,6 +108,28 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
   userKeyEnc: Encoder[K],
   valEncoder: Encoder[V]): MapState[K, V]
 
+  /**
+   * Function to create new or return existing map state variable of given type
+   * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
+   * from the state store. Any values in mapState which have expired after ttlDuration will not
+   * returned on get() and will be eventually removed from the state.
+   *
+   * The user must ensure to call this function only within the `init()` method of the
+   * StatefulProcessor.
+   *
+   * @param stateName  - name of the state variable
+   * @param valEncoder - SQL encoder for state variable

Review Comment:
   The comment for this param is missing?
   
   ```
userKeyEnc: Encoder[K],
   ```
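
   One way the doc comment could cover every parameter is sketched below. The wording of each `@param` line is a suggestion only, not necessarily what was merged. `EncoderSketch`, `MapStateSketch`, `TtlConfigSketch` and `HandleSketch` are hypothetical stand-ins for the Spark types, and the parameter list is assumed from the `MapStateImplWithTTL` constructor quoted elsewhere in this thread.

```scala
import java.time.Duration

// Hypothetical stand-ins so the snippet compiles without Spark on the classpath.
trait EncoderSketch[T]
trait MapStateSketch[K, V]
final case class TtlConfigSketch(ttlDuration: Duration)

trait HandleSketch {
  /**
   * Creates a new map state variable with ttl, or returns the existing one.
   *
   * @param stateName  - name of the state variable
   * @param userKeyEnc - encoder for the user key of the map (suggested wording)
   * @param valEncoder - encoder for the values stored in the map
   * @param ttlConfig  - time-to-live configuration for this state
   * @tparam K - type of the user key
   * @tparam V - type of the value
   * @return handle used to read and update the map state
   */
  def getMapState[K, V](
      stateName: String,
      userKeyEnc: EncoderSketch[K],
      valEncoder: EncoderSketch[V],
      ttlConfig: TtlConfigSketch): MapStateSketch[K, V]
}
```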






Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-16 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1567940592


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+
+  

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-16 Thread via GitHub


ericm-db commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1567908966


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+  PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+val retRow = store.get(encodedCompositeKey, stateName)
+
+if (retRow != null) {
+  val resState = stateTypesEncoder.decodeValue(retRow)
+
+  if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+resState
+  } else {
+null.asInstanceOf[V]
+  }
+} else {
+  null.asInstanceOf[V]
+}
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+StateStoreErrors.requireNonNullStateValue(key, stateName)
+StateStoreErrors.requireNonNullStateValue(value, stateName)
+val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+store.put(encodedCompositeKey, encodedValue, stateName)
+val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+new NextIterator[(K, V)] {
+  override protected def getNext(): (K, V) = {
+val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+  stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+}
+if (iter.hasNext) {
+  val currentRowPair = iter.next()
+  val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+  val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+  (key, value)
+} else {
+  finished = true
+  null.asInstanceOf[(K, V)]
+}
+  }
+
+  override protected def close(): Unit = {}
+}
+  }
+
+  

Re: [PR] [SPARK-47805][SS] Implementing TTL for MapState [spark]

2024-04-16 Thread via GitHub


sahnib commented on code in PR #45991:
URL: https://github.com/apache/spark/pull/45991#discussion_r1567707784


##
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala:
##
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
+import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.{MapState, TTLConfig}
+import org.apache.spark.util.NextIterator
+
+
+class MapStateImplWithTTL[K, V](
+  store: StateStore,
+  stateName: String,
+  keyExprEnc: ExpressionEncoder[Any],
+  userKeyEnc: Encoder[K],
+  valEncoder: Encoder[V],
+  ttlConfig: TTLConfig,
+  batchTimestampMs: Long) extends CompositeKeyTTLStateImpl(stateName, store, batchTimestampMs)
+  with MapState[K, V] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+  private val stateTypesEncoder = new CompositeKeyStateEncoder(
+    keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName, hasTtl = true)
+
+  private val ttlExpirationMs =
+    StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
+
+  initialize()
+
+  private def initialize(): Unit = {
+    store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
+      PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
+  }
+
+  /** Whether state exists or not. */
+  override def exists(): Boolean = {
+    iterator().nonEmpty
+  }
+
+  /** Get the state value if it exists */
+  override def getValue(key: K): V = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    val retRow = store.get(encodedCompositeKey, stateName)
+
+    if (retRow != null) {
+      val resState = stateTypesEncoder.decodeValue(retRow)
+
+      if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
+        resState
+      } else {
+        null.asInstanceOf[V]
+      }
+    } else {
+      null.asInstanceOf[V]
+    }
+  }
+
+  /** Check if the user key is contained in the map */
+  override def containsKey(key: K): Boolean = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    getValue(key) != null
+  }
+
+  /** Update value for given user key */
+  override def updateValue(key: K, value: V): Unit = {
+    StateStoreErrors.requireNonNullStateValue(key, stateName)
+    StateStoreErrors.requireNonNullStateValue(value, stateName)
+    val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs)
+    val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
+    store.put(encodedCompositeKey, encodedValue, stateName)
+    val serializedGroupingKey = stateTypesEncoder.serializeGroupingKey()
+    val serializedUserKey = stateTypesEncoder.serializeUserKey(key)
+    upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey, serializedUserKey)
+  }
+
+  /** Get the map associated with grouping key */
+  override def iterator(): Iterator[(K, V)] = {
+    val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+    val unsafeRowPairIterator = store.prefixScan(encodedGroupingKey, stateName)
+    new NextIterator[(K, V)] {
+      override protected def getNext(): (K, V) = {
+        val iter = unsafeRowPairIterator.dropWhile { rowPair =>
+          stateTypesEncoder.isExpired(rowPair.value, batchTimestampMs)
+        }
+        if (iter.hasNext) {
+          val currentRowPair = iter.next()
+          val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+          val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+          (key, value)
+        } else {
+          finished = true
+          null.asInstanceOf[(K, V)]
+        }
+      }
+
+      override protected def close(): Unit = {}
+    }
+  }
+
+  /**