sahnib commented on code in PR #45051:
URL: https://github.com/apache/spark/pull/45051#discussion_r1503581091


##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala:
##########
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.io.Serializable
+
+import org.apache.spark.annotation.{Evolving, Experimental}
+
+/**
+ * Class used to provide access to an expired timer's expiry time and timeout mode.
+ * These values are only relevant if the ExpiredTimerInfo is valid.
+ */
+@Experimental
+@Evolving
+private[sql] trait ExpiredTimerInfo extends Serializable {
+  /**
+   * Check if provided ExpiredTimerInfo is valid.
+   */
+  def isValid(): Boolean
+
+  /**
+   * Get the expired timer's expiry time as milliseconds in epoch time.
+   */
+  def getExpiryTimeInMs(): Long
+
+  /**
+   * Get the expired timer's timeout mode.
+   */
+  def getTimeoutMode(): TimeoutMode

Review Comment:
   Would this ever be different from the timeout mode provided to the `transformWithState` API? If not, do we need this here?



##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala:
##########
@@ -48,16 +48,19 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
    * @param inputRows - iterator of input rows associated with grouping key
   * @param timerValues - instance of TimerValues that provides access to current
   *                    processing/event time if available
+   * @param expiredTimerInfo - instance of ExpiredTimerInfo that provides access to
+   *                         expired timer if applicable
    * @return - Zero or more output rows
    */
   def handleInputRows(
       key: K,
       inputRows: Iterator[I],
-      timerValues: TimerValues): Iterator[O]
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[O]
 
   /**
    * Function called as the last method that allows for users to perform
    * any cleanup or teardown operations.
    */
-  def close (): Unit
+  def close (): Unit = {}

Review Comment:
   Nice idea to provide a default implementation here. 
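   
   With the no-op default, a processor that has nothing to tear down can skip `close()` entirely. A rough sketch of what user code could look like (the class name is made up, and any other abstract members of `StatefulProcessor` are elided):
   
   ```scala
   // Hypothetical processor: emits a count of input rows per grouping key.
   // close() is intentionally not overridden; the default no-op is inherited.
   class RowCountProcessor extends StatefulProcessor[String, String, Long] {
     override def handleInputRows(
         key: String,
         inputRows: Iterator[String],
         timerValues: TimerValues,
         expiredTimerInfo: ExpiredTimerInfo): Iterator[Long] = {
       Iterator.single(inputRows.size.toLong)
     }
   }
   ```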



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,42 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    verify(timeoutMode == ProcessingTime || timeoutMode == EventTime,
+      "Cannot register timers with incorrect TimeoutMode")
+    verify(currState == INITIALIZED || currState == DATA_PROCESSED,
+      s"Cannot register timers with " +
+        s"expiryTimestampMs=$expiryTimestampMs in current state=$currState")

Review Comment:
   We should use the NERF framework for these user errors. 
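   
   For reference, a sketch of the NERF-style shape I have in mind — the helper method names and error classes here are hypothetical, just following the pattern of the existing `StateStoreErrors` object:
   
   ```scala
   // Hypothetical error helpers; names are illustrative only, and each would
   // map to an entry in error-classes.json rather than a free-form message.
   if (timeoutMode != ProcessingTime && timeoutMode != EventTime) {
     throw StateStoreErrors.cannotRegisterTimerWithInvalidTimeoutMode(timeoutMode.toString)
   }
   if (currState != INITIALIZED && currState != DATA_PROCESSED) {
     throw StateStoreErrors.cannotRegisterTimerInCurrentState(
       expiryTimestampMs.toString, currState.toString)
   }
   ```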



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -121,6 +123,42 @@ class StatefulProcessorHandleImpl(
 
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
+  private def getTimerState[T](): TimerStateImpl[T] = {
+    new TimerStateImpl[T](store, timeoutMode, keyEncoder)
+  }
+
+  private val timerState = getTimerState[Boolean]()
+
+  override def registerTimer(expiryTimestampMs: Long): Unit = {
+    verify(timeoutMode == ProcessingTime || timeoutMode == EventTime,

Review Comment:
   Shouldn't this be the same as the timeoutMode in the transformWithState API? 



##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala:
##########
@@ -51,6 +51,25 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
   /** Function to return queryInfo for currently running task */
   def getQueryInfo(): QueryInfo
 
+  /**
+   * Function to register a processing/event time based timer for given implicit key

Review Comment:
   [nit] `implicit key` -> `implicit grouping key`. 



##########
sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala:
##########
@@ -48,16 +48,19 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
    * @param inputRows - iterator of input rows associated with grouping key
   * @param timerValues - instance of TimerValues that provides access to current
   *                    processing/event time if available
+   * @param expiredTimerInfo - instance of ExpiredTimerInfo that provides access to
+   *                         expired timer if applicable
    * @return - Zero or more output rows
    */
   def handleInputRows(
       key: K,
       inputRows: Iterator[I],
-      timerValues: TimerValues): Iterator[O]
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[O]

Review Comment:
   Should we wrap this in an `Option`, instead of having `isValid` inside `ExpiredTimerInfo`?
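   
   Concretely, the signature could become something like this (sketch only):
   
   ```scala
   def handleInputRows(
       key: K,
       inputRows: Iterator[I],
       timerValues: TimerValues,
       expiredTimerInfo: Option[ExpiredTimerInfo]): Iterator[O]
   ```
   
   Callers could then pattern-match on `Some(info)` instead of calling `isValid()`, at the cost of an extra allocation per invocation.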



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala:
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.io.Serializable
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.TimeoutMode
+import org.apache.spark.sql.types._
+import org.apache.spark.util.NextIterator
+
+/**
+ * Singleton utility object used primarily while interacting with TimerState
+ */
+object TimerStateUtils {
+  case class TimestampWithKey(
+      key: Any,
+      expiryTimestampMs: Long) extends Serializable
+
+  val PROC_TIMERS_STATE_NAME = "_procTimers"
+  val EVENT_TIMERS_STATE_NAME = "_eventTimers"
+  val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
+  val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+}
+
+/**
+ * Class that provides the implementation for storing timers
+ * used within the `transformWithState` operator.
+ * @param store - state store to be used for storing timer data
+ * @param timeoutMode - mode of timeout (event time or processing time)
+ * @param keyExprEnc - encoder for key expression
+ * @tparam S - type of timer value
+ */
+class TimerStateImpl[S](
+    store: StateStore,
+    timeoutMode: TimeoutMode,
+    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
+
+  private val EMPTY_ROW =
+    UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null))
+
+  private val schemaForPrefixKey: StructType = new StructType()
+    .add("key", BinaryType)
+
+  private val schemaForKeyRow: StructType = new StructType()
+    .add("key", BinaryType)
+    .add("expiryTimestampMs", LongType, nullable = false)
+
+  private val keySchemaForSecIndex: StructType = new StructType()
+    .add("expiryTimestampMs", BinaryType, nullable = false)
+    .add("key", BinaryType)
+
+  private val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val prefixKeyEncoder = UnsafeProjection.create(schemaForPrefixKey)
+
+  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
+
+  private val secIndexKeyEncoder = UnsafeProjection.create(keySchemaForSecIndex)
+
+  val timerCfName = if (timeoutMode == TimeoutMode.ProcessingTime) {
+    TimerStateUtils.PROC_TIMERS_STATE_NAME
+  } else {
+    TimerStateUtils.EVENT_TIMERS_STATE_NAME
+  }
+
+  val keyToTsCFName = timerCfName + TimerStateUtils.KEY_TO_TIMESTAMP_CF
+  store.createColFamilyIfAbsent(keyToTsCFName,
+    schemaForKeyRow, numColsPrefixKey = 1,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  val tsToKeyCFName = timerCfName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
+  store.createColFamilyIfAbsent(tsToKeyCFName,
+    keySchemaForSecIndex, numColsPrefixKey = 0,
+    schemaForValueRow, useMultipleValuesPerKey = false,
+    isInternal = true)
+
+  private def encodeKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (keyOption.isEmpty) {
+      throw StateStoreErrors.implicitKeyNotFound(keyToTsCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = keyEncoder(InternalRow(keyByteArr, expiryTimestampMs))
+    keyRow
+  }
+
+  private def encodeSecIndexKey(expiryTimestampMs: Long): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (keyOption.isEmpty) {
+      throw StateStoreErrors.implicitKeyNotFound(tsToKeyCFName)
+    }
+
+    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
+    val bbuf = ByteBuffer.allocate(8)
+    bbuf.order(ByteOrder.BIG_ENDIAN)
+    bbuf.putLong(expiryTimestampMs)

Review Comment:
   I can see why we are specifying endianness here (as the expiry timestamp is encoded as binary). I have a few questions:
   
   1. Why are we choosing to encode the expiryTimestampMs as Binary instead of Long?
   2. Pardon my lack of knowledge about Spark's UnsafeRow here. AFAIK, RocksDB orders keys lexicographically by default. Is the encoding from UnsafeRow going to preserve ordering for the timestamp? We rely on the ordering of keys to find expired timers in `getExpiredTimers`.
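   
   For what it's worth, here is a quick self-contained check of the raw-byte ordering property for non-negative timestamps (big-endian bytes compare lexicographically, i.e. unsigned, in the same order as the longs themselves). Whether the surrounding UnsafeRow encoding preserves that property is exactly my open question above:
   
   ```scala
   import java.nio.{ByteBuffer, ByteOrder}
   
   // Encode a timestamp the same way this PR does: 8 big-endian bytes.
   def encode(ts: Long): Array[Byte] =
     ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putLong(ts).array()
   
   // Unsigned lexicographic comparison, mirroring RocksDB's default comparator.
   def lexCompare(a: Array[Byte], b: Array[Byte]): Int =
     a.zip(b).map { case (x, y) => (x & 0xff) - (y & 0xff) }
       .find(_ != 0).getOrElse(a.length - b.length)
   
   assert(lexCompare(encode(1000L), encode(65536L)) < 0)     // agrees with 1000 < 65536
   assert(lexCompare(encode(0L), encode(Long.MaxValue)) < 0)
   // Caveat: a negative long would sort *after* all non-negative ones under
   // unsigned comparison, so this only holds for non-negative timestamps.
   ```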
   



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -69,8 +70,20 @@ case class TransformWithStateExec(
 
   override def shortName: String = "transformWithStateExec"
 
-  // TODO: update this to run no-data batches when timer support is added
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false
+  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+    timeoutMode match {
+      // TODO: check if we can return true only if actual timers are registered

Review Comment:
   Would this be as simple as checking whether there are any items in the timer column family?
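   
   If the store exposes an iterator scoped to a single column family (I'm assuming an API shape like `store.iterator(colFamilyName)` here, so treat this as a sketch), it could be as cheap as:
   
   ```scala
   // Sketch: "any timers registered?" == "is the timer column family non-empty?"
   // The per-column-family iterator call is an assumption about the store API.
   private def hasRegisteredTimers: Boolean =
     store.iterator(tsToKeyCFName).hasNext
   ```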



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

