eason-yuchen-liu commented on code in PR #53930:
URL: https://github.com/apache/spark/pull/53930#discussion_r2843292229
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,676 @@ import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, NamedExpression, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType}
+import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
 import org.apache.spark.util.NextIterator
+/**
+ * Base trait of the state manager for the stream-stream symmetric hash join operator.
+ *
+ * This defines the basic APIs for the state manager, except the methods for eviction, which are
+ * defined in separate traits - see [[SupportsEvictByCondition]] and [[SupportsEvictByTimestamp]].
+ *
+ * Implementation classes are expected to inherit those traits as needed, depending on the
+ * eviction strategy they support.
+ */
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+/**
+ * This trait specifically helps the older state manager implementations (v1-v3) work with
+ * existing tests which look up the state store by key with index.
+ */
+trait SupportsIndexedKeys {
+  def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by condition, i.e.
+ * implementations which have to perform a full scan for eviction.
+ */
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by timestamp, i.e.
+ * implementations which maintain the state with event time and can efficiently scan the keys
+ * with event time smaller than the given timestamp for eviction.
+ */
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+/**
+ * Version 4 of the stream-stream join state manager implementation, designed to optimize
+ * eviction with the watermark. Previous versions require a full scan to find the keys to evict,
+ * while this version only scans the keys with event time smaller than the watermark.
+ *
+ * In this implementation, we no longer build a logical array of values; instead, we store
+ * (key, timestamp) -> values in the primary store, and maintain a secondary index of
+ * (timestamp, key) to scan the keys to evict for each watermark. To retrieve the values for a
+ * key, we perform a prefix scan with the key to get all the (key, timestamp) -> values.
+ *
+ * Refer to [[KeyWithTsToValuesStore]] and [[TsWithKeyTypeStore]] for more details.
+ */
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+  extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with Logging {
+
+ import SymmetricHashJoinStateManager._
+
+  protected val keySchema = StructType(
+    joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+  protected val keyAttributes = toAttributes(keySchema)
+  private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+    inputValueAttributes,
+    // NOTE: This does not accept multiple event time columns. This differs from the operator,
+    // for which we offer backward compatibility, but passing that information down would involve
+    // too many layers. The information is in SQLConf.
+    allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+  private val extractEventTimeFn: UnsafeRow => Long = { row =>
+    eventTimeColIdxOpt match {
+      case Some(idx) =>
+        val attr = inputValueAttributes(idx)
+
+        if (attr.dataType.isInstanceOf[StructType]) {
+          // NOTE: We assume this is a window struct, the same as in
+          // WatermarkSupport.watermarkExpression.
+          row.getStruct(idx, 2).getLong(1)
+        } else {
+          row.getLong(idx)
+        }
+
+      case _ =>
+        // When the event time column is not available, we use a random bucketing strategy to
+        // decide where the new value will be stored. There is a trade-off between the bucket
+        // size and the number of values in each bucket; we can make the bucket size configurable
+        // if this magic number turns out not to work well.
+        random.nextInt(bucketSizeForNoEventTime)
+    }
+  }
+
+  private val eventTimeColIdxOptInKey: Option[Int] = {
+    joinKeys.zipWithIndex.collectFirst {
+      case (ne: NamedExpression, index) if ne.metadata.contains(EventTimeWatermark.delayKey) =>
+        index
+    }
+  }
+
+  private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+    eventTimeColIdxOptInKey.map { idx =>
+      val attr = keyAttributes(idx)
+      if (attr.dataType.isInstanceOf[StructType]) {
+        // NOTE: We assume this is a window struct, the same as in
+        // WatermarkSupport.watermarkExpression.
+        row.getStruct(idx, 2).getLong(1)
+      } else {
+        row.getLong(idx)
+      }
+    }
+  }
+
+ private val dummySchema = StructType(
+ Seq(StructField("dummy", NullType, nullable = true))
+ )
+
+  // TODO: [SPARK-55628] The two fields below need to be handled properly during integration
+  // with the operator.
+ private val stateStoreCkptId: Option[String] = None
+ private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
+
+ private var stateStoreProvider: StateStoreProvider = _
+
+  // We will use the dummy schema for the default CF since we will register CF separately.
+ private val stateStore = getStateStore(
+ dummySchema, dummySchema, useVirtualColumnFamilies = true,
+ NoPrefixKeyStateEncoderSpec(dummySchema), useMultipleValuesPerKey = false
+ )
+
+  private def getStateStore(
+      keySchema: StructType,
+      valueSchema: StructType,
+      useVirtualColumnFamilies: Boolean,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useMultipleValuesPerKey: Boolean): StateStore = {
+    val storeName = StateStoreId.DEFAULT_STORE_NAME
+    val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId, storeName)
+    val store = if (useStateStoreCoordinator) {
+      assert(handlerSnapshotOptions.isEmpty, "Should not use state store coordinator " +
+        "when reading state as data source.")
+      joinStoreGenerator.getStore(
+        storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+        stateInfo.get.storeVersion, stateStoreCkptId, None, useVirtualColumnFamilies,
+        useMultipleValuesPerKey, storeConf, hadoopConf)
+    } else {
+      // This class will manage the state store provider by itself.
+      stateStoreProvider = StateStoreProvider.createAndInit(
+        storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+        useColumnFamilies = useVirtualColumnFamilies,
+        storeConf, hadoopConf, useMultipleValuesPerKey = useMultipleValuesPerKey,
+        stateSchemaProvider = None)
+      if (handlerSnapshotOptions.isDefined) {
+        if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
+          throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+            stateStoreProvider.getClass.toString)
+        }
+        val opts = handlerSnapshotOptions.get
+        stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay]
+          .replayStateFromSnapshot(
+            opts.snapshotVersion,
+            opts.endVersion,
+            readOnly = true,
+            opts.startStateStoreCkptId,
+            opts.endStateStoreCkptId)
+      } else {
+        stateStoreProvider.getStore(stateInfo.get.storeVersion, stateStoreCkptId)
+      }
+    }
+    logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}")
+    store
+  }
+
+ private val keyWithTsToValues = new KeyWithTsToValuesStore
+
+ private val tsWithKey = new TsWithKeyTypeStore
+
+  override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = {
+    val eventTime = extractEventTimeFn(value)
+    // We always do a blind merge when appending a new value.
+    keyWithTsToValues.append(key, eventTime, value, matched)
+    tsWithKey.add(eventTime, key)
+  }
+
+  override def getJoinedRows(
+      key: UnsafeRow,
+      generateJoinedRow: InternalRow => JoinedRow,
+      predicate: JoinedRow => Boolean,
+      excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
+    // TODO: [SPARK-55147] We could improve this method to narrow the scope of the timestamp and
+    // scan keys more efficiently. For now, we just get all values for the key.
+    def getJoinedRowsFromTsAndValues(
+        ts: Long,
+        valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
+      new NextIterator[JoinedRow] {
+        private var currentIndex = 0
+
+        private var shouldUpdateValuesIntoStateStore = false
+
+        override protected def getNext(): JoinedRow = {
+          var ret: JoinedRow = null
+          while (ret == null && currentIndex < valuesAndMatched.length) {
+            val vmp = valuesAndMatched(currentIndex)
+
+            if (excludeRowsAlreadyMatched && vmp.matched) {
+              // Skip this one.
+            } else {
+              val joinedRow = generateJoinedRow(vmp.value)
+              if (predicate(joinedRow)) {
+                if (!vmp.matched) {
+                  // Update the array to contain the value having matched = true.
+                  valuesAndMatched(currentIndex) = vmp.copy(matched = true)
+                  // Need to update the matched flag in the state store.
+                  shouldUpdateValuesIntoStateStore = true
+                }
+
+                ret = joinedRow
+              } else {
+                // Skip this one.
+              }
+            }
+
+            currentIndex += 1
+          }
+
+          if (ret == null) {
+            assert(currentIndex == valuesAndMatched.length)
+            finished = true
+            null
+          } else {
+            ret
+          }
+        }
+
Review Comment:
Let's also add a comment here about consuming the entire iterator to trigger cleanup. Or is it always guaranteed that this iterator will be consumed?
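If it is not guaranteed, a minimal caller-side sketch (a hypothetical helper; it relies on Spark's NextIterator.closeIfNeeded(), which invokes close() at most once) could enforce it:

    import org.apache.spark.util.NextIterator

    // Hypothetical guard: drain the iterator and make sure close() runs even if
    // consumption stops early or throws, so the matched-flag write-back happens.
    def consumeFully[T](iter: NextIterator[T])(f: T => Unit): Unit = {
      try {
        while (iter.hasNext) f(iter.next())
      } finally {
        iter.closeIfNeeded() // idempotent; triggers close() exactly once
      }
    }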
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
+ override protected def close(): Unit = {
+ if (shouldUpdateValuesIntoStateStore) {
+ // Update back to the state store
+ val updatedValuesWithMatched = valuesAndMatched.map { vmp =>
+ (vmp.value, vmp.matched)
+ }.toSeq
+ keyWithTsToValues.put(key, ts, updatedValuesWithMatched)
+ }
+ }
+ }
+ }
+
+ val ret = extractEventTimeFnFromKey(key) match {
+ case Some(ts) =>
+ val valuesAndMatchedIter = keyWithTsToValues.get(key, ts)
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
+
+ case _ =>
+ keyWithTsToValues.getValues(key).flatMap { result =>
+ val ts = result.timestamp
+ val valuesAndMatched = result.values.toArray
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
+ }
+ }
+ ret.filter(_ != null)
+ }
+
+ override def iterator: Iterator[KeyToValuePair] = {
+ val reusableKeyToValuePair = KeyToValuePair()
+ keyWithTsToValues.iterator().map { kv =>
+ reusableKeyToValuePair.withNew(kv.key, kv.value, kv.matched)
+ }
+ }
+
+ override def evictByTimestamp(endTimestamp: Long): Long = {
+ var removed = 0L
+ tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
+ val key = evicted.key
+ val timestamp = evicted.timestamp
+ val numValues = evicted.numValues
+
+ // Remove from both primary and secondary stores
+ keyWithTsToValues.remove(key, timestamp)
+ tsWithKey.remove(key, timestamp)
+
+ removed += numValues
+ }
+ removed
+ }
+
+  override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
+ val reusableKeyToValuePair = KeyToValuePair()
+
+ tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
+ val key = evicted.key
+ val timestamp = evicted.timestamp
+ val values = keyWithTsToValues.get(key, timestamp)
+
+ // Remove from both primary and secondary stores
+ keyWithTsToValues.remove(key, timestamp)
+ tsWithKey.remove(key, timestamp)
+
+ values.map { value =>
+ reusableKeyToValuePair.withNew(key, value)
+ }
+ }
+ }
+
+ override def commit(): Unit = {
+ stateStore.commit()
+ logDebug("Committed, metrics = " + stateStore.metrics)
+ }
+
+ override def abortIfNeeded(): Unit = {
+ if (!stateStore.hasCommitted) {
+ logInfo(log"Aborted store ${MDC(STATE_STORE_ID, stateStore.id)}")
+ stateStore.abort()
+ }
+    // If this class manages a state store provider by itself, it should take care of closing
+    // the provider instance as well.
+ if (stateStoreProvider != null) {
+ stateStoreProvider.close()
+ }
+ }
+
+ // Clean up any state store resources if necessary at the end of the task
+  Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ => abortIfNeeded() } }
+
+  class GetValuesResult(var timestamp: Long = -1, var values: Seq[ValueAndMatchPair] = Seq.empty) {
+    def withNew(newTimestamp: Long, newValues: Seq[ValueAndMatchPair]): GetValuesResult = {
+      this.timestamp = newTimestamp
+      this.values = newValues
+      this
+    }
+  }
+
+  /**
+   * The primary store for the key-value pairs.
+   *
+   * The state format of the primary store is the following:
+   * [key][timestamp (event time)] -> [(value, matched), (value, matched), ...]
+   *
+   * The values are bucketed by event time to facilitate efficient eviction by watermark; the
+   * secondary index provides the way to scan the key + timestamp pairs for eviction, and it is
+   * easy to perform retrieval/removal of the values based on key + timestamp pairs. There is no
+   * case where we evict only part of the values for the same key + timestamp.
+   *
+   * The matched flag is used to indicate whether the value has been matched with any row from
+   * the other side.
+   */
+  private class KeyWithTsToValuesStore {
+    private val valueRowConverter = StreamingSymmetricHashJoinValueRowConverter.create(
+      inputValueAttributes, stateFormatVersion = 4)
+
+    // Set up the virtual column family name in the store if it is being used.
+    private val colFamilyName = getStateStoreName(joinSide, KeyWithTsToValuesType)
+
+    private val keySchemaWithTimestamp = TimestampKeyStateEncoder.keySchemaWithTimestamp(keySchema)
+    private val detachTimestampProjection: UnsafeProjection =
+      TimestampKeyStateEncoder.getDetachTimestampProjection(keySchemaWithTimestamp)
+    private val attachTimestampProjection: UnsafeProjection =
+      TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)
+
+    // Create the specific column family in the store for this join side's KeyWithTsToValuesStore.
+    stateStore.createColFamilyIfAbsent(
+      colFamilyName,
+      keySchema,
+      valueRowConverter.valueAttributes.toStructType,
+      TimestampAsPostfixKeyStateEncoderSpec(keySchemaWithTimestamp),
+      useMultipleValuesPerKey = true
+    )
+
+    private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
+      TimestampKeyStateEncoder.attachTimestamp(
+        attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
+    }
+
+    def append(key: UnsafeRow, timestamp: Long, value: UnsafeRow, matched: Boolean): Unit = {
+      val valueWithMatched = valueRowConverter.convertToValueRow(value, matched)
+      stateStore.merge(createKeyRow(key, timestamp), valueWithMatched, colFamilyName)
+    }
+
+    def put(
+        key: UnsafeRow,
+        timestamp: Long,
+        valuesWithMatched: Seq[(UnsafeRow, Boolean)]): Unit = {
+      val valuesToPut = valuesWithMatched.map { case (value, matched) =>
+        valueRowConverter.convertToValueRow(value, matched)
+      }.toArray
+      stateStore.putList(createKeyRow(key, timestamp), valuesToPut, colFamilyName)
+    }
+
+    def get(key: UnsafeRow, timestamp: Long): Iterator[ValueAndMatchPair] = {
+      stateStore.valuesIterator(createKeyRow(key, timestamp), colFamilyName).map { valueRow =>
+        valueRowConverter.convertValue(valueRow)
+      }
+    }
+
+    // NOTE: We do not have a case where we only remove a part of the values. Even if that were
+    // needed, we would handle it via put() by writing a new array.
+    def remove(key: UnsafeRow, timestamp: Long): Unit = {
+      stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
+    }
+
+    // NOTE: This assumes we consume the whole iterator to trigger completion.
+    def getValues(key: UnsafeRow): Iterator[GetValuesResult] = {
+      val reusableGetValuesResult = new GetValuesResult()
+
+      new NextIterator[GetValuesResult] {
+        private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
+
+        private var currentTs = -1L
+        private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
+
+        @tailrec
+        override protected def getNext(): GetValuesResult = {
+          if (iter.hasNext) {
+            val unsafeRowPair = iter.next()
+
+            val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
+
+            if (currentTs == -1L) {
+              // First time
+              currentTs = ts
+            }
+
+            if (currentTs != ts) {
+              assert(valueAndMatchPairs.nonEmpty,
+                "timestamp has changed but no values collected from previous timestamp! " +
+                  s"This should not happen. currentTs: $currentTs, new ts: $ts")
+
+              // Return the previous batch
+              val result = reusableGetValuesResult.withNew(currentTs, valueAndMatchPairs.toSeq)
+
+              // Reset for the new timestamp
+              currentTs = ts
+              valueAndMatchPairs.clear()
+
+              // Add the current value
+              val value = valueRowConverter.convertValue(unsafeRowPair.value)
+              valueAndMatchPairs += value
+              result
+            } else {
+              // Same timestamp, accumulate values
+              val value = valueRowConverter.convertValue(unsafeRowPair.value)
+              valueAndMatchPairs += value
+
+              // Continue to the next entry
+              getNext()
+            }
+          } else {
+            if (currentTs != -1L) {
+              assert(valueAndMatchPairs.nonEmpty)
+
+              // Return the last batch
+              val result = reusableGetValuesResult.withNew(currentTs, valueAndMatchPairs.toSeq)
+
+              // Mark as finished
+              currentTs = -1L
+              valueAndMatchPairs.clear()
+              result
+            } else {
+              finished = true
+              null
+            }
+          }
+        }
+
+        override protected def close(): Unit = iter.close()
+      }
+    }
+
Review Comment:
Can we document the reuse of the UnsafeRow, i.e. that the caller should consume the result immediately?
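For instance, a caller that needs to buffer results would have to copy first; a sketch, assuming ValueAndMatchPair is a case class holding an UnsafeRow:

    // The single GetValuesResult instance is mutated on every next() call, so
    // copy whatever must outlive the current iteration before advancing.
    val snapshot = keyWithTsToValues.getValues(key).map { result =>
      val copiedValues = result.values.map(v => v.copy(value = v.value.copy()))
      (result.timestamp, copiedValues) // now independent of the reused object
    }.toList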
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+  private val extractEventTimeFn: UnsafeRow => Long = { row =>
+    eventTimeColIdxOpt match {
+      case Some(idx) =>
+        val attr = inputValueAttributes(idx)
+
+        if (attr.dataType.isInstanceOf[StructType]) {
+          // NOTE: We assume this is a window struct, the same as in
+          // WatermarkSupport.watermarkExpression.
+          row.getStruct(idx, 2).getLong(1)
+        } else {
+          row.getLong(idx)
+        }
+
+      case _ =>
+        // When the event time column is not available, we use a random bucketing strategy to
+        // decide where the new value will be stored. There is a trade-off between the bucket
+        // size and the number of values in each bucket; we can make the bucket size configurable
+        // if this magic number turns out not to work well.
+        random.nextInt(bucketSizeForNoEventTime)
Review Comment:
IIUC `bucketCountForNoEventTime` would describe the purpose better?
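The value is passed to Random.nextInt as an exclusive upper bound, so it is the count of distinct bucket ids rather than the size of any bucket; a sketch with the suggested rename (hypothetical):

    val bucketCountForNoEventTime = 1024
    val bucketId = random.nextInt(bucketCountForNoEventTime) // uniform in [0, 1024)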
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
+  override def getJoinedRows(
+      key: UnsafeRow,
+      generateJoinedRow: InternalRow => JoinedRow,
+      predicate: JoinedRow => Boolean,
+      excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
+    // TODO: [SPARK-55147] We could improve this method to narrow the scope of the timestamp and
+    // scan keys more efficiently. For now, we just get all values for the key.
+    def getJoinedRowsFromTsAndValues(
+        ts: Long,
+        valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
+      new NextIterator[JoinedRow] {
+        private var currentIndex = 0
+
+        private var shouldUpdateValuesIntoStateStore = false
+
+        override protected def getNext(): JoinedRow = {
+          var ret: JoinedRow = null
+          while (ret == null && currentIndex < valuesAndMatched.length) {
+            val vmp = valuesAndMatched(currentIndex)
+
+            if (excludeRowsAlreadyMatched && vmp.matched) {
+              // Skip this one.
+            } else {
+              val joinedRow = generateJoinedRow(vmp.value)
+              if (predicate(joinedRow)) {
+                if (!vmp.matched) {
+                  // Update the array to contain the value having matched = true.
+                  valuesAndMatched(currentIndex) = vmp.copy(matched = true)
+                  // Need to update the matched flag in the state store.
+                  shouldUpdateValuesIntoStateStore = true
+                }
+
+                ret = joinedRow
+              } else {
+                // Skip this one.
+              }
+            }
+
+            currentIndex += 1
+          }
+
+          if (ret == null) {
+            assert(currentIndex == valuesAndMatched.length)
+            finished = true
+            null
+          } else {
+            ret
+          }
+        }
+
+        override protected def close(): Unit = {
+          if (shouldUpdateValuesIntoStateStore) {
+            // Update back to the state store.
+            val updatedValuesWithMatched = valuesAndMatched.map { vmp =>
+              (vmp.value, vmp.matched)
+            }.toSeq
+            keyWithTsToValues.put(key, ts, updatedValuesWithMatched)
+          }
+        }
+      }
+    }
+
+    val ret = extractEventTimeFnFromKey(key) match {
+      case Some(ts) =>
+        val valuesAndMatchedIter = keyWithTsToValues.get(key, ts)
+        getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
+
+      case _ =>
+        keyWithTsToValues.getValues(key).flatMap { result =>
+          val ts = result.timestamp
+          val valuesAndMatched = result.values.toArray
+          getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
+        }
+    }
+    ret.filter(_ != null)
+  }
+
Review Comment:
Add a comment about references to the same UnsafeRow.
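For instance, since JoinedRow wraps its two input rows by reference, a caller that buffers results without copying can observe later mutations of the underlying UnsafeRows. A sketch, where stateManager and otherSideRow are hypothetical placeholders:

    // Rows produced by getJoinedRows may share UnsafeRow buffers that the
    // iterator reuses, so copy() each row before buffering it.
    val buffered = stateManager
      .getJoinedRows(key, row => new JoinedRow(row, otherSideRow), _ => true)
      .map(_.copy())
      .toArray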
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,676 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, NamedExpression, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField, StructType}
+import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
import org.apache.spark.util.NextIterator
+/**
+ * Base trait of the state manager for stream-stream symmetric hash join operator.
+ *
+ * This defines the basic APIs for the state manager, except the methods for eviction which are
+ * defined in separate traits - see [[SupportsEvictByCondition]] and [[SupportsEvictByTimestamp]].
+ *
+ * Implementation classes are expected to inherit those traits as needed, depending on the
+ * eviction strategy they support.
+ */
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
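As a rough illustration of how an operator could drive this API within one micro-batch (a hypothetical sketch written against the trait's signatures above, not code from this diff; joined-row emission is elided):

    // Append incoming rows on this side, probe the other side's state for
    // join candidates, then commit both managers at the end of the batch.
    def processBatchSketch(
        thisSide: SymmetricHashJoinStateManager,
        otherSide: SymmetricHashJoinStateManager,
        input: Iterator[(UnsafeRow, UnsafeRow)]): Unit = {
      input.foreach { case (key, value) =>
        thisSide.append(key, value, matched = false)
        otherSide.get(key) // probe for matches; emission of joined rows elided
      }
      thisSide.commit()
      otherSide.commit()
    }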
+
+/**
+ * This trait is specific to helping the old versions of the state manager implementation
+ * (v1-v3) work with existing tests which look up the state store by key with index.
+ */
+trait SupportsIndexedKeys {
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by condition,
+ * i.e. implementations which have to perform a full scan for eviction.
+ */
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+/**
+ * This trait is for state manager implementations that support eviction by timestamp. This is
+ * for the state manager implementations which maintain the state with event time and can
+ * efficiently scan the keys with event time smaller than the given timestamp for eviction.
+ */
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
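To make the trade-off between the two eviction traits concrete, here is a hedged cost sketch (names and signatures are hypothetical, not from this diff): eviction by condition has to visit every key, while eviction by timestamp can walk a timestamp-ordered index and stop at the watermark boundary.

    // Eviction by condition: O(total keys) per micro-batch.
    def evictByConditionSketch(
        allKeys: Iterator[UnsafeRow],
        shouldEvict: UnsafeRow => Boolean): Long =
      allKeys.count(shouldEvict).toLong

    // Eviction by timestamp: entries arrive ordered by timestamp, so the
    // scan stops at the first entry at or past the watermark, O(evicted keys).
    def evictByTimestampSketch(
        tsOrderedEntries: Iterator[(Long, UnsafeRow)],
        endTimestamp: Long): Long =
      tsOrderedEntries.takeWhile(_._1 < endTimestamp).size.toLong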
+
+/**
+ * Version 4 of the stream-stream join state manager implementation, which is designed to
+ * optimize eviction with the watermark. Previous versions require a full scan to find the keys
+ * to evict, while this version only scans the keys with event time smaller than the watermark.
+ *
+ * In this implementation, we no longer build a logical array of values; instead, we store
+ * (key, timestamp) -> values in the primary store, and maintain a secondary index of
+ * (timestamp, key) to scan the keys to evict for each watermark. To retrieve the values for a
+ * key, we perform a prefix scan with the key to get all the (key, timestamp) -> values.
+ *
+ * Refer to [[KeyWithTsToValuesStore]] and [[TsWithKeyTypeStore]] for more details.
+ */
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+ inputValueAttributes,
+ // NOTE: This does not accept multiple event time columns. This is not the same as the
+ // operator for which we offer backward compatibility, but it involves too many layers to
+ // pass the information through. The information is in SQLConf.
+ allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+ private val extractEventTimeFn: UnsafeRow => Long = { row =>
+ eventTimeColIdxOpt match {
+ case Some(idx) =>
+ val attr = inputValueAttributes(idx)
+
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is the window struct, the same as WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+
+ case _ =>
+ // When the event time column is not available, we use a random bucketing strategy to decide
+ // where the new value will be stored. There is a trade-off between the bucket size and the
+ // number of values in each bucket; we can make the bucket size configurable if the current
+ // default turns out not to work well.
+ random.nextInt(bucketSizeForNoEventTime)
+ }
+ }
+
+ private val eventTimeColIdxOptInKey: Option[Int] = {
+ joinKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index) if ne.metadata.contains(EventTimeWatermark.delayKey) => index
+ }
+ }
+
+ private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+ eventTimeColIdxOptInKey.map { idx =>
+ val attr = keyAttributes(idx)
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is the window struct, the same as WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+ }
+ }
+
+ private val dummySchema = StructType(
+ Seq(StructField("dummy", NullType, nullable = true))
+ )
+
+ // TODO: [SPARK-55628] The two fields below need to be handled properly during integration
+ // with the operator.
+ private val stateStoreCkptId: Option[String] = None
+ private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
+
+ private var stateStoreProvider: StateStoreProvider = _
+
+ // We will use the dummy schema for the default CF since we will register CFs separately.
+ private val stateStore = getStateStore(
+ dummySchema, dummySchema, useVirtualColumnFamilies = true,
+ NoPrefixKeyStateEncoderSpec(dummySchema), useMultipleValuesPerKey = false
+ )
+
+ private def getStateStore(
+ keySchema: StructType,
+ valueSchema: StructType,
+ useVirtualColumnFamilies: Boolean,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useMultipleValuesPerKey: Boolean): StateStore = {
+ val storeName = StateStoreId.DEFAULT_STORE_NAME
+ val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId, storeName)
+ val store = if (useStateStoreCoordinator) {
+ assert(handlerSnapshotOptions.isEmpty, "Should not use state store coordinator " +
+ "when reading state as data source.")
+ joinStoreGenerator.getStore(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ stateInfo.get.storeVersion, stateStoreCkptId, None, useVirtualColumnFamilies,
+ useMultipleValuesPerKey, storeConf, hadoopConf)
+ } else {
+ // This class will manage the state store provider by itself.
+ stateStoreProvider = StateStoreProvider.createAndInit(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ useColumnFamilies = useVirtualColumnFamilies,
+ storeConf, hadoopConf, useMultipleValuesPerKey = useMultipleValuesPerKey,
+ stateSchemaProvider = None)
+ if (handlerSnapshotOptions.isDefined) {
+ if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
+ throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+ stateStoreProvider.getClass.toString)
+ }
+ val opts = handlerSnapshotOptions.get
+ stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay]
+ .replayStateFromSnapshot(
+ opts.snapshotVersion,
+ opts.endVersion,
+ readOnly = true,
+ opts.startStateStoreCkptId,
+ opts.endStateStoreCkptId)
+ } else {
+ stateStoreProvider.getStore(stateInfo.get.storeVersion, stateStoreCkptId)
+ }
+ }
+ logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}")
+ store
+ }
+
+ private val keyWithTsToValues = new KeyWithTsToValuesStore
+
+ private val tsWithKey = new TsWithKeyTypeStore
+
+ override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = {
+ val eventTime = extractEventTimeFn(value)
+ // We always do blind merge for appending new value.
+ keyWithTsToValues.append(key, eventTime, value, matched)
+ tsWithKey.add(eventTime, key)
+ }
+
+ override def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
+ // TODO: [SPARK-55147] We could improve this method to get the scope of timestamp and scan keys
+ // more efficiently. For now, we just get all values for the key.
+ def getJoinedRowsFromTsAndValues(
+ ts: Long,
+ valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
+ new NextIterator[JoinedRow] {
+ private var currentIndex = 0
+
+ private var shouldUpdateValuesIntoStateStore = false
+
+ override protected def getNext(): JoinedRow = {
+ var ret: JoinedRow = null
+ while (ret == null && currentIndex < valuesAndMatched.length) {
+ val vmp = valuesAndMatched(currentIndex)
+
+ if (excludeRowsAlreadyMatched && vmp.matched) {
+ // Skip this one
+ } else {
+ val joinedRow = generateJoinedRow(vmp.value)
+ if (predicate(joinedRow)) {
+ if (!vmp.matched) {
+ // Update the array to contain the value having matched = true
+ valuesAndMatched(currentIndex) = vmp.copy(matched = true)
+ // Need to update matched flag
+ shouldUpdateValuesIntoStateStore = true
+ }
+
+ ret = joinedRow
+ } else {
+ // skip this one
+ }
+ }
+
+ currentIndex += 1
+ }
+
+ if (ret == null) {
+ assert(currentIndex == valuesAndMatched.length)
+ finished = true
+ null
+ } else {
+ ret
+ }
+ }
+
+ override protected def close(): Unit = {
+ if (shouldUpdateValuesIntoStateStore) {
+ // Update back to the state store
+ val updatedValuesWithMatched = valuesAndMatched.map { vmp =>
+ (vmp.value, vmp.matched)
+ }.toSeq
+ keyWithTsToValues.put(key, ts, updatedValuesWithMatched)
+ }
+ }
+ }
+ }
+
+ val ret = extractEventTimeFnFromKey(key) match {
+ case Some(ts) =>
+ val valuesAndMatchedIter = keyWithTsToValues.get(key, ts)
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
+
+ case _ =>
+ keyWithTsToValues.getValues(key).flatMap { result =>
+ val ts = result.timestamp
+ val valuesAndMatched = result.values.toArray
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
+ }
+ }
+ ret.filter(_ != null)
Review Comment:
why do we need this?
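For readers following the question above: NextIterator (org.apache.spark.util.NextIterator, already used in this file) discards whatever getNext() returns once finished is set to true, so the trailing null sentinel never reaches callers through hasNext/next. A minimal sketch of that contract, assuming it matches the upstream class:

    import org.apache.spark.util.NextIterator

    // Adapts an array into Spark's NextIterator: setting finished = true
    // marks exhaustion, and the null returned alongside it is never surfaced.
    def fromArray[T >: Null](arr: Array[T]): Iterator[T] = new NextIterator[T] {
      private var i = 0
      override protected def getNext(): T = {
        if (i < arr.length) { val v = arr(i); i += 1; v }
        else { finished = true; null }
      }
      override protected def close(): Unit = () // release resources here
    }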
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]