anishshri-db commented on code in PR #53911:
URL: https://github.com/apache/spark/pull/53911#discussion_r2795020987


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {
+  val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"
+
+  val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L
+
+  def finalKeySchema(keySchema: StructType): StructType = {
+    StructType(keySchema.fields)
+      .add(name = INTERNAL_TIMESTAMP_COLUMN_NAME, dataType = LongType, 
nullable = false)
+  }
+
+  def getByteBufferForBigEndianLong(): ByteBuffer = {
+    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
+  }
+
+  def encodeTimestamp(buff: ByteBuffer, timestamp: Long): Array[Byte] = {
+    // Flip the sign bit to ensure correct lexicographical ordering, even for 
negative timestamps.
+    // We should flip the sign bit back when decoding the timestamp.
+    val signFlippedTimestamp = timestamp ^ 
TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+    buff.putLong(0, signFlippedTimestamp)
+    buff.array()
+  }
+
+  def decodeTimestamp(buff: ByteBuffer, keyBytes: Array[Byte], startPos: Int): 
Long = {
+    buff.put(0, keyBytes, startPos, 8)
+    val signFlippedTimestamp = buff.getLong(0)
+    // Flip the sign bit back to get the original timestamp.
+    signFlippedTimestamp ^ TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+  }
+}
+
+/**
+ * FIXME: doc...

Review Comment:
   Is this still pending?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {

Review Comment:
   Can we add some high-level comments here?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {
+  val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"
+
+  val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L
+
+  def finalKeySchema(keySchema: StructType): StructType = {
+    StructType(keySchema.fields)
+      .add(name = INTERNAL_TIMESTAMP_COLUMN_NAME, dataType = LongType, 
nullable = false)
+  }
+
+  def getByteBufferForBigEndianLong(): ByteBuffer = {
+    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
+  }
+
+  def encodeTimestamp(buff: ByteBuffer, timestamp: Long): Array[Byte] = {
+    // Flip the sign bit to ensure correct lexicographical ordering, even for 
negative timestamps.
+    // We should flip the sign bit back when decoding the timestamp.
+    val signFlippedTimestamp = timestamp ^ 
TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+    buff.putLong(0, signFlippedTimestamp)
+    buff.array()
+  }
+
+  def decodeTimestamp(buff: ByteBuffer, keyBytes: Array[Byte], startPos: Int): 
Long = {
+    buff.put(0, keyBytes, startPos, 8)
+    val signFlippedTimestamp = buff.getLong(0)
+    // Flip the sign bit back to get the original timestamp.
+    signFlippedTimestamp ^ TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+  }
+}
+
+/**
+ * FIXME: doc...
+ */
+class TimestampAsPrefixKeyStateEncoder(
+    dataEncoder: RocksDBDataEncoder,
+    keySchema: StructType,
+    useColumnFamilies: Boolean = false)
+  extends RocksDBKeyStateEncoder with Logging {
+
+  import TimestampKeyStateEncoder._
+  import org.apache.spark.sql.catalyst.types.DataTypeUtils
+
+  // keySchema includes the event time column as the last field, hence we 
remove it to project key.
+  private val keySchemaWithoutTimestampWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.dropRight(1)
+  }
+
+  private val keyWithoutTimestampProjection: UnsafeProjection = {
+    val refs = keySchemaWithoutTimestampWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val keySchemaWithoutTimestampAttrs = DataTypeUtils.toAttributes(
+    StructType(keySchema.dropRight(1)))
+  private val keyWithTimestampProjection: UnsafeProjection = {
+    val refs = keySchema.zipWithIndex.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(
+      refs :+ Literal(0L), // placeholder for timestamp
+      keySchemaWithoutTimestampAttrs)
+  }
+
+  private def extractTimestamp(key: UnsafeRow): Long = {
+    key.getLong(key.numFields - 1)
+  }
+
+  override def supportPrefixKeyScan: Boolean = false
+
+  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+    throw new IllegalStateException("This encoder doesn't support key without 
event time!")
+  }
+
+  // NOTE: We reuse the ByteBuffer to avoid allocating a new one for every 
encoding/decoding,
+  // which means the encoder is not thread-safe. Built-in operators do not 
access the encoder in
+  // multiple threads, but if we are concerned about thread-safety in the 
future, we can maintain
+  // the thread-local of ByteBuffer to retain the reusability of the instance 
while avoiding
+  // thread-safety issue. We do not use position - we always put/get at offset 
0.
+  private val buffForBigEndianLong = getByteBufferForBigEndianLong()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefix = dataEncoder.encodeKey(keyWithoutTimestampProjection(row))
+    val timestamp = extractTimestamp(row)
+
+    val byteArray = new Array[Byte](prefix.length + 8)
+    Platform.copyMemory(
+      encodeTimestamp(buffForBigEndianLong, timestamp), 
Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET, 8)
+    Platform.copyMemory(prefix, Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET + 8, prefix.length)
+
+    byteArray
+  }
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val timestamp = decodeTimestamp(buffForBigEndianLong, keyBytes, 0)
+
+    val rowBytesLength = keyBytes.length - 8
+    val rowBytes = new Array[Byte](rowBytesLength)
+    Platform.copyMemory(
+      keyBytes, Platform.BYTE_ARRAY_OFFSET + 8,
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      rowBytesLength
+    )
+    val row = dataEncoder.decodeToUnsafeRow(rowBytes, keySchema.length)
+
+    val rowWithTimestamp = keyWithTimestampProjection(row)
+    rowWithTimestamp.setLong(keySchema.length - 1, timestamp)
+    rowWithTimestamp
+  }
+
+  // TODO: Revisit this to support delete range if needed.
+  override def supportsDeleteRange: Boolean = false
+}
+
+/**
+ * FIXME: doc...

Review Comment:
   Is this still pending?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {
+  val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"
+
+  val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L
+
+  def finalKeySchema(keySchema: StructType): StructType = {
+    StructType(keySchema.fields)
+      .add(name = INTERNAL_TIMESTAMP_COLUMN_NAME, dataType = LongType, 
nullable = false)
+  }
+
+  def getByteBufferForBigEndianLong(): ByteBuffer = {
+    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
+  }
+
+  def encodeTimestamp(buff: ByteBuffer, timestamp: Long): Array[Byte] = {
+    // Flip the sign bit to ensure correct lexicographical ordering, even for 
negative timestamps.
+    // We should flip the sign bit back when decoding the timestamp.
+    val signFlippedTimestamp = timestamp ^ 
TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+    buff.putLong(0, signFlippedTimestamp)
+    buff.array()
+  }
+
+  def decodeTimestamp(buff: ByteBuffer, keyBytes: Array[Byte], startPos: Int): 
Long = {
+    buff.put(0, keyBytes, startPos, 8)
+    val signFlippedTimestamp = buff.getLong(0)
+    // Flip the sign bit back to get the original timestamp.
+    signFlippedTimestamp ^ TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+  }
+}
+
+/**
+ * FIXME: doc...
+ */
+class TimestampAsPrefixKeyStateEncoder(
+    dataEncoder: RocksDBDataEncoder,
+    keySchema: StructType,
+    useColumnFamilies: Boolean = false)
+  extends RocksDBKeyStateEncoder with Logging {
+
+  import TimestampKeyStateEncoder._
+  import org.apache.spark.sql.catalyst.types.DataTypeUtils
+
+  // keySchema includes the event time column as the last field, hence we 
remove it to project key.
+  private val keySchemaWithoutTimestampWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.dropRight(1)
+  }
+
+  private val keyWithoutTimestampProjection: UnsafeProjection = {
+    val refs = keySchemaWithoutTimestampWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val keySchemaWithoutTimestampAttrs = DataTypeUtils.toAttributes(
+    StructType(keySchema.dropRight(1)))
+  private val keyWithTimestampProjection: UnsafeProjection = {
+    val refs = keySchema.zipWithIndex.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(
+      refs :+ Literal(0L), // placeholder for timestamp
+      keySchemaWithoutTimestampAttrs)
+  }
+
+  private def extractTimestamp(key: UnsafeRow): Long = {
+    key.getLong(key.numFields - 1)
+  }
+
+  override def supportPrefixKeyScan: Boolean = false
+
+  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+    throw new IllegalStateException("This encoder doesn't support key without 
event time!")
+  }
+
+  // NOTE: We reuse the ByteBuffer to avoid allocating a new one for every 
encoding/decoding,
+  // which means the encoder is not thread-safe. Built-in operators do not 
access the encoder in
+  // multiple threads, but if we are concerned about thread-safety in the 
future, we can maintain
+  // the thread-local of ByteBuffer to retain the reusability of the instance 
while avoiding
+  // thread-safety issue. We do not use position - we always put/get at offset 
0.
+  private val buffForBigEndianLong = getByteBufferForBigEndianLong()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefix = dataEncoder.encodeKey(keyWithoutTimestampProjection(row))
+    val timestamp = extractTimestamp(row)
+
+    val byteArray = new Array[Byte](prefix.length + 8)
+    Platform.copyMemory(
+      encodeTimestamp(buffForBigEndianLong, timestamp), 
Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET, 8)
+    Platform.copyMemory(prefix, Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET + 8, prefix.length)
+
+    byteArray
+  }
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val timestamp = decodeTimestamp(buffForBigEndianLong, keyBytes, 0)
+
+    val rowBytesLength = keyBytes.length - 8
+    val rowBytes = new Array[Byte](rowBytesLength)
+    Platform.copyMemory(
+      keyBytes, Platform.BYTE_ARRAY_OFFSET + 8,
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      rowBytesLength
+    )
+    val row = dataEncoder.decodeToUnsafeRow(rowBytes, keySchema.length)
+
+    val rowWithTimestamp = keyWithTimestampProjection(row)
+    rowWithTimestamp.setLong(keySchema.length - 1, timestamp)
+    rowWithTimestamp
+  }
+
+  // TODO: Revisit this to support delete range if needed.

Review Comment:
   Could we file a SPARK JIRA ticket for this, then?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {
+  val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"
+
+  val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L
+
+  def finalKeySchema(keySchema: StructType): StructType = {

Review Comment:
   Does it need to be public? If so, can we rename the function?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -1713,6 +1715,206 @@ class NoPrefixKeyStateEncoder(
   }
 }
 
+object TimestampKeyStateEncoder {
+  val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"
+
+  val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L
+
+  def finalKeySchema(keySchema: StructType): StructType = {
+    StructType(keySchema.fields)
+      .add(name = INTERNAL_TIMESTAMP_COLUMN_NAME, dataType = LongType, 
nullable = false)
+  }
+
+  def getByteBufferForBigEndianLong(): ByteBuffer = {
+    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
+  }
+
+  def encodeTimestamp(buff: ByteBuffer, timestamp: Long): Array[Byte] = {
+    // Flip the sign bit to ensure correct lexicographical ordering, even for 
negative timestamps.
+    // We should flip the sign bit back when decoding the timestamp.
+    val signFlippedTimestamp = timestamp ^ 
TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+    buff.putLong(0, signFlippedTimestamp)
+    buff.array()
+  }
+
+  def decodeTimestamp(buff: ByteBuffer, keyBytes: Array[Byte], startPos: Int): 
Long = {
+    buff.put(0, keyBytes, startPos, 8)
+    val signFlippedTimestamp = buff.getLong(0)
+    // Flip the sign bit back to get the original timestamp.
+    signFlippedTimestamp ^ TimestampKeyStateEncoder.SIGN_MASK_FOR_LONG
+  }
+}
+
+/**
+ * FIXME: doc...
+ */
+class TimestampAsPrefixKeyStateEncoder(
+    dataEncoder: RocksDBDataEncoder,
+    keySchema: StructType,
+    useColumnFamilies: Boolean = false)
+  extends RocksDBKeyStateEncoder with Logging {
+
+  import TimestampKeyStateEncoder._
+  import org.apache.spark.sql.catalyst.types.DataTypeUtils
+
+  // keySchema includes the event time column as the last field, hence we 
remove it to project key.
+  private val keySchemaWithoutTimestampWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.dropRight(1)
+  }
+
+  private val keyWithoutTimestampProjection: UnsafeProjection = {
+    val refs = keySchemaWithoutTimestampWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val keySchemaWithoutTimestampAttrs = DataTypeUtils.toAttributes(
+    StructType(keySchema.dropRight(1)))
+  private val keyWithTimestampProjection: UnsafeProjection = {
+    val refs = keySchema.zipWithIndex.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(
+      refs :+ Literal(0L), // placeholder for timestamp
+      keySchemaWithoutTimestampAttrs)
+  }
+
+  private def extractTimestamp(key: UnsafeRow): Long = {
+    key.getLong(key.numFields - 1)
+  }
+
+  override def supportPrefixKeyScan: Boolean = false
+
+  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+    throw new IllegalStateException("This encoder doesn't support key without 
event time!")
+  }
+
+  // NOTE: We reuse the ByteBuffer to avoid allocating a new one for every 
encoding/decoding,
+  // which means the encoder is not thread-safe. Built-in operators do not 
access the encoder in
+  // multiple threads, but if we are concerned about thread-safety in the 
future, we can maintain
+  // the thread-local of ByteBuffer to retain the reusability of the instance 
while avoiding
+  // thread-safety issue. We do not use position - we always put/get at offset 
0.
+  private val buffForBigEndianLong = getByteBufferForBigEndianLong()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefix = dataEncoder.encodeKey(keyWithoutTimestampProjection(row))
+    val timestamp = extractTimestamp(row)
+
+    val byteArray = new Array[Byte](prefix.length + 8)
+    Platform.copyMemory(
+      encodeTimestamp(buffForBigEndianLong, timestamp), 
Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET, 8)
+    Platform.copyMemory(prefix, Platform.BYTE_ARRAY_OFFSET,
+      byteArray, Platform.BYTE_ARRAY_OFFSET + 8, prefix.length)
+
+    byteArray
+  }
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val timestamp = decodeTimestamp(buffForBigEndianLong, keyBytes, 0)
+
+    val rowBytesLength = keyBytes.length - 8
+    val rowBytes = new Array[Byte](rowBytesLength)
+    Platform.copyMemory(
+      keyBytes, Platform.BYTE_ARRAY_OFFSET + 8,
+      rowBytes, Platform.BYTE_ARRAY_OFFSET,
+      rowBytesLength
+    )
+    val row = dataEncoder.decodeToUnsafeRow(rowBytes, keySchema.length)
+
+    val rowWithTimestamp = keyWithTimestampProjection(row)
+    rowWithTimestamp.setLong(keySchema.length - 1, timestamp)
+    rowWithTimestamp
+  }
+
+  // TODO: Revisit this to support delete range if needed.
+  override def supportsDeleteRange: Boolean = false
+}
+
+/**
+ * FIXME: doc...
+ */
+class TimestampAsPostfixKeyStateEncoder(
+    dataEncoder: RocksDBDataEncoder,
+    keySchema: StructType,
+    useColumnFamilies: Boolean = false)
+  extends RocksDBKeyStateEncoder with Logging {
+
+  import TimestampKeyStateEncoder._
+  import org.apache.spark.sql.catalyst.types.DataTypeUtils
+
+  // keySchema includes the event time column as the last field, hence we 
remove it to project key.
+  private val keySchemaWithoutTimestampWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.dropRight(1)
+  }
+
+  private val keyWithoutTimestampProjection: UnsafeProjection = {
+    val refs = keySchemaWithoutTimestampWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val keySchemaWithoutTimestampAttrs = DataTypeUtils.toAttributes(
+    StructType(keySchema.dropRight(1)))

Review Comment:
   The function implementations in both encoders look similar. Is there any
common code we can share between them?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -846,6 +846,8 @@ class AvroStateEncoder(
           }
         }
         StructType(remainingSchema)
+      case _ =>
+        throw unsupportedOperationForKeyStateEncoder("createAvroEnc")

Review Comment:
   Can we improve the passed argument/error message?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to