HeartSaVioR commented on code in PR #53930:
URL: https://github.com/apache/spark/pull/53930#discussion_r2797149557
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala:
##########
@@ -792,6 +789,29 @@ case class StreamingSymmetricHashJoinExec(
joinStateManager.get(key)
}
+ // FIXME: doc!
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
/**
 * State of one side of a streaming symmetric hash join: the input rows buffered per join key,
 * each stored together with a flag recording whether it has been matched.
 *
 * Optional capabilities are provided by mix-in traits such as [[SupportsEvictByCondition]],
 * [[SupportsEvictByTimestamp]] and [[SupportsIndexedKeys]].
 */
trait SymmetricHashJoinStateManager {
  import SymmetricHashJoinStateManager._

  /** Buffers `value` under `key` together with its `matched` flag. */
  def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit

  /** Returns the values buffered under `key`. */
  def get(key: UnsafeRow): Iterator[UnsafeRow]

  /**
   * Returns the joined rows produced from the values buffered under `key`.
   *
   * @param key join key to look up
   * @param generateJoinedRow builds a joined row from a buffered value
   * @param predicate join condition; only joined rows satisfying it are returned
   * @param excludeRowsAlreadyMatched when true, values already flagged as matched are skipped
   * @return lazily evaluated joined rows. Implementations may flag the values that produced a
   *         result as matched and persist those flags only once the iterator is exhausted
   *         (SymmetricHashJoinStateManagerV4 does), so callers should consume it fully.
   */
  def getJoinedRows(
      key: UnsafeRow,
      generateJoinedRow: InternalRow => JoinedRow,
      predicate: JoinedRow => Boolean,
      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]

  /** Iterates over all buffered entries; the returned pair object may be reused per element. */
  def iterator: Iterator[KeyToValuePair]

  /** Commits the state changes made so far. */
  def commit(): Unit

  /** Aborts the state changes if they were not committed and releases held resources. */
  def abortIfNeeded(): Unit

  /** Metrics of the underlying state store(s). */
  def metrics: StateStoreMetrics

  /** Returns the checkpoint information of the underlying state store(s). */
  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
}
+
+trait SupportsIndexedKeys {
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -913,6 +1473,30 @@ class SymmetricHashJoinStateManagerV1(
}
override def metrics: StateStoreMetrics = {
+ // FIXME: purposed for benchmarking
+ // /*
+ val keyToNumValuesMetrics = keyToNumValues.metrics
+ val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics
+ def newDesc(desc: String): String =
s"${joinSide.toString.toUpperCase(Locale.ROOT)}: $desc"
+
+ val mergedCustomMetrics = (keyToNumValuesMetrics.customMetrics.toSeq ++
+ keyWithIndexToValueMetrics.customMetrics.toSeq)
+ .groupBy(_._1)
+ .map { case (metric, metrics) =>
+ val mergedValue = metrics.map(_._2).sum
+ (metric.withNewDesc(desc = newDesc(metric.desc)), mergedValue)
+ }
+
+ StateStoreMetrics(
+ keyWithIndexToValueMetrics.numKeys, // represent each buffered row
only once
+ keyToNumValuesMetrics.memoryUsedBytes +
keyWithIndexToValueMetrics.memoryUsedBytes,
+ mergedCustomMetrics,
+ // We want to collect instance metrics from both state stores
+ keyWithIndexToValueMetrics.instanceMetrics ++
keyToNumValuesMetrics.instanceMetrics
+ )
+ // */
+
+ /*
Review Comment:
Self review: revert it
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
/**
 * State of one side of a streaming symmetric hash join: the input rows buffered per join key,
 * each stored together with a flag recording whether it has been matched.
 *
 * Optional capabilities are provided by mix-in traits such as [[SupportsEvictByCondition]],
 * [[SupportsEvictByTimestamp]] and [[SupportsIndexedKeys]].
 */
trait SymmetricHashJoinStateManager {
  import SymmetricHashJoinStateManager._

  /** Buffers `value` under `key` together with its `matched` flag. */
  def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit

  /** Returns the values buffered under `key`. */
  def get(key: UnsafeRow): Iterator[UnsafeRow]

  /**
   * Returns the joined rows produced from the values buffered under `key`.
   *
   * @param key join key to look up
   * @param generateJoinedRow builds a joined row from a buffered value
   * @param predicate join condition; only joined rows satisfying it are returned
   * @param excludeRowsAlreadyMatched when true, values already flagged as matched are skipped
   * @return lazily evaluated joined rows. Implementations may flag the values that produced a
   *         result as matched and persist those flags only once the iterator is exhausted
   *         (SymmetricHashJoinStateManagerV4 does), so callers should consume it fully.
   */
  def getJoinedRows(
      key: UnsafeRow,
      generateJoinedRow: InternalRow => JoinedRow,
      predicate: JoinedRow => Boolean,
      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]

  /** Iterates over all buffered entries; the returned pair object may be reused per element. */
  def iterator: Iterator[KeyToValuePair]

  /** Commits the state changes made so far. */
  def commit(): Unit

  /** Aborts the state changes if they were not committed and releases held resources. */
  def abortIfNeeded(): Unit

  /** Metrics of the underlying state store(s). */
  def metrics: StateStoreMetrics

  /** Returns the checkpoint information of the underlying state store(s). */
  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
}
+
/**
 * Mix-in for state managers that address buffered values by (key, index) — presumably the
 * layout with keyToNumValues / keyWithIndexToValue stores used by
 * SymmetricHashJoinStateManagerV1.
 * NOTE(review): the method semantics below are inferred from their names — confirm.
 */
trait SupportsIndexedKeys {
  /** Returns the internal row form of `currentKey` as used by the key-with-index store. */
  def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow

  /** Test-only hook: overrides the number of values recorded for `key`. */
  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
}
+
/**
 * Mix-in for state managers that evict buffered entries selected by a predicate on their key
 * or on their value.
 *
 * The `evictAndReturn*` variants return the evicted entries instead of a count.
 * NOTE(review): whether counts are in keys or values, and whether the returned iterators must
 * be fully consumed for eviction to complete, is up to the implementation — confirm.
 */
trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
  import SymmetricHashJoinStateManager._

  /** Evicts the entries whose key satisfies `removalCondition`; returns the evicted count. */
  def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the entries whose key satisfies `removalCondition` and returns them. */
  def evictAndReturnByKeyCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]

  /** Evicts the values satisfying `removalCondition`; returns the evicted count. */
  def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the values satisfying `removalCondition` and returns them. */
  def evictAndReturnByValueCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
}
+
/**
 * Mix-in for state managers that keep buffered values organized by timestamp, so they can be
 * evicted by a timestamp bound rather than by a per-key or per-value predicate.
 */
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
  import SymmetricHashJoinStateManager._

  /**
   * Evicts the buffered values selected by the timestamp bound `endTimestamp`.
   * NOTE(review): whether the bound is inclusive is up to the implementation — confirm.
   *
   * @return the number of evicted values (SymmetricHashJoinStateManagerV4 counts values,
   *         not keys)
   */
  def evictByTimestamp(endTimestamp: Long): Long

  /**
   * Same as [[evictByTimestamp]] but returns the evicted entries. Eviction may happen lazily
   * while the iterator is consumed (it does in SymmetricHashJoinStateManagerV4), so callers
   * should consume it fully.
   */
  def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
}
+
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with
Logging {
+
+ import SymmetricHashJoinStateManager._
+
// Schema of the join key: one field per join key expression, named field0, field1, ...
protected val keySchema: StructType = StructType(
  for ((keyExpr, ordinal) <- joinKeys.zipWithIndex)
    yield StructField(s"field$ordinal", keyExpr.dataType, keyExpr.nullable))
protected val keyAttributes = toAttributes(keySchema)

// Position of the event time column among the input value attributes, if there is one.
private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
  inputValueAttributes,
  // NOTE: multiple event time columns are not accepted here. The operator itself keeps
  // backward compatibility for that case (the setting lives in SQLConf), but passing it down
  // to this class would involve too many layers.
  allowMultipleEventTimeColumns = false)
+
// Source of bucket numbers for rows when the input has no event time column.
private val random = new scala.util.Random(System.currentTimeMillis())
private val bucketSizeForNoEventTime = 1024

/**
 * Extracts the timestamp under which a buffered value row is stored.
 *
 * The extraction strategy depends only on the schema, so it is resolved once here instead of
 * re-matching the column index and re-checking the column type for every row.
 */
private val extractEventTimeFn: UnsafeRow => Long = eventTimeColIdxOpt match {
  case Some(idx) if inputValueAttributes(idx).dataType.isInstanceOf[StructType] =>
    // NOTE: We assume this is a window struct, as in WatermarkSupport.watermarkExpression,
    // and read its second field (presumably the window end).
    (row: UnsafeRow) => row.getStruct(idx, 2).getLong(1)

  case Some(idx) =>
    (row: UnsafeRow) => row.getLong(idx)

  case None =>
    // No event time available: bucket rows randomly, so the rows of a key end up under at
    // most `bucketSizeForNoEventTime` distinct timestamps.
    (_: UnsafeRow) => random.nextInt(bucketSizeForNoEventTime).toLong
}
+
// Position of the join key carrying event time watermark metadata, if any.
private val eventTimeColIdxOptInKey: Option[Int] = {
  joinKeys.zipWithIndex.collectFirst {
    case (ne: NamedExpression, index)
      if ne.metadata.contains(EventTimeWatermark.delayKey) => index
  }
}

/**
 * Extracts the event time from a join key row, or returns None when no join key carries
 * event time. Resolved once from the key schema instead of per row.
 */
private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = eventTimeColIdxOptInKey match {
  case Some(idx) if keyAttributes(idx).dataType.isInstanceOf[StructType] =>
    // NOTE: We assume this is a window struct, as in WatermarkSupport.watermarkExpression.
    (row: UnsafeRow) => Some(row.getStruct(idx, 2).getLong(1))

  case Some(idx) =>
    (row: UnsafeRow) => Some(row.getLong(idx))

  case None =>
    (_: UnsafeRow) => None
}
+
// Placeholder single-column schema for the store's default column family; the column
// families actually used are registered separately with their own schemas.
private val dummySchema = StructType(Seq(StructField("dummy", NullType, nullable = true)))

// NOTE: the three fields below must stay declared before `stateStore`, since getStateStore()
// reads them while `stateStore` is being initialized.
// NOTE(review): the constructor parameters keyToNumValuesStateStoreCkptId,
// keyWithIndexToValueStateStoreCkptId and snapshotOptions are not consulted; the checkpoint
// id and snapshot options are fixed to None here. Confirm this is intended.
private val stateStoreCkptId: Option[String] = None
private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
private var stateStoreProvider: StateStoreProvider = _

private val stateStore: StateStore = getStateStore(
  dummySchema,
  dummySchema,
  useVirtualColumnFamilies = true,
  keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(dummySchema),
  useMultipleValuesPerKey = false)
+
/**
 * Loads the state store backing this manager.
 *
 * When the state store coordinator is used, the store is obtained through
 * `joinStoreGenerator`. Otherwise this manager creates and owns the provider (closed in
 * abortIfNeeded()). It then either replays state from a snapshot, when snapshot options are
 * set, or loads the version given by `stateInfo`.
 */
private def getStateStore(
    keySchema: StructType,
    valueSchema: StructType,
    useVirtualColumnFamilies: Boolean,
    keyStateEncoderSpec: KeyStateEncoderSpec,
    useMultipleValuesPerKey: Boolean): StateStore = {
  val providerId =
    StateStoreProviderId(stateInfo.get, partitionId, StateStoreId.DEFAULT_STORE_NAME)

  val loaded: StateStore = if (!useStateStoreCoordinator) {
    // This class will manage the state store provider by itself.
    stateStoreProvider = StateStoreProvider.createAndInit(
      providerId, keySchema, valueSchema, keyStateEncoderSpec,
      useColumnFamilies = useVirtualColumnFamilies,
      storeConf, hadoopConf, useMultipleValuesPerKey = useMultipleValuesPerKey,
      stateSchemaProvider = None)

    handlerSnapshotOptions match {
      case Some(opts) =>
        stateStoreProvider match {
          case replayable: SupportsFineGrainedReplay =>
            replayable.replayStateFromSnapshot(
              opts.snapshotVersion,
              opts.endVersion,
              readOnly = true,
              opts.startStateStoreCkptId,
              opts.endStateStoreCkptId)
          case unsupported =>
            throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
              unsupported.getClass.toString)
        }
      case None =>
        stateStoreProvider.getStore(stateInfo.get.storeVersion, stateStoreCkptId)
    }
  } else {
    assert(handlerSnapshotOptions.isEmpty, "Should not use state store coordinator " +
      "when reading state as data source.")
    joinStoreGenerator.getStore(
      providerId, keySchema, valueSchema, keyStateEncoderSpec,
      stateInfo.get.storeVersion, stateStoreCkptId, None, useVirtualColumnFamilies,
      useMultipleValuesPerKey, storeConf, hadoopConf)
  }

  logInfo(log"Loaded store ${MDC(STATE_STORE_ID, loaded.id)}")
  loaded
}
+
// Primary store: (key, timestamp) -> list of value rows, each with its matched flag.
private val keyWithTsToValues: KeyWithTsToValuesStore = new KeyWithTsToValuesStore()

// Secondary index of (timestamp, key), scanned to find the entries to evict by timestamp.
private val tsWithKey: TsWithKeyTypeStore = new TsWithKeyTypeStore()
+
/**
 * Buffers `value` under `key`, storing it under the timestamp extracted from the value, and
 * records the (timestamp, key) pair in the secondary index used for eviction.
 */
override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit = {
  val ts = extractEventTimeFn(value)
  // Blind merge: the new value is added without reading the values already stored.
  keyWithTsToValues.append(key, ts, value, matched)
  tsWithKey.add(ts, key)
}
+
override def getJoinedRows(
    key: UnsafeRow,
    generateJoinedRow: InternalRow => JoinedRow,
    predicate: JoinedRow => Boolean,
    excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
  // TODO: We could improve this method to get the scope of timestamp and scan keys
  // more efficiently. For now, we just get all values for the key.

  /**
   * Lazily joins the values stored under (`key`, `ts`). When a value produces a joined row
   * that passes `predicate`, its matched flag is set in `pairs`. If any flag changed, the
   * whole list is written back to the store once the iterator is exhausted.
   */
  def joinValuesAt(ts: Long, pairs: Array[ValueAndMatchPair]): Iterator[JoinedRow] =
    new NextIterator[JoinedRow] {
      private var pos = 0
      private var matchedFlagChanged = false

      override protected def getNext(): JoinedRow = {
        var found: JoinedRow = null
        while (found == null && pos < pairs.length) {
          val pair = pairs(pos)
          val skipAlreadyMatched = excludeRowsAlreadyMatched && pair.matched
          if (!skipAlreadyMatched) {
            val candidate = generateJoinedRow(pair.value)
            if (predicate(candidate)) {
              if (!pair.matched) {
                // Remember the new matched flag so it can be persisted on close.
                pairs(pos) = pair.copy(matched = true)
                matchedFlagChanged = true
              }
              found = candidate
            }
          }
          pos += 1
        }

        if (found == null) {
          assert(pos == pairs.length)
          finished = true
        }
        found
      }

      override protected def close(): Unit = {
        if (matchedFlagChanged) {
          keyWithTsToValues.put(key, ts, pairs.toSeq.map(p => (p.value, p.matched)))
        }
      }
    }

  val joinedRows = extractEventTimeFnFromKey(key) match {
    case Some(ts) =>
      // The key carries event time: only the values stored under that timestamp are read.
      joinValuesAt(ts, keyWithTsToValues.get(key, ts).toArray)
    case None =>
      keyWithTsToValues.getValues(key).flatMap { group =>
        joinValuesAt(group.timestamp, group.values.toArray)
      }
  }
  joinedRows.filter(_ != null)
}
+
/** Iterates over all buffered values; a single KeyToValuePair instance is reused. */
override def iterator: Iterator[KeyToValuePair] = {
  val pair = KeyToValuePair()
  keyWithTsToValues.iterator().map(entry => pair.withNew(entry.key, entry.value, entry.matched))
}
+
/**
 * Removes, from both stores, every (key, timestamp) entry that the secondary index reports
 * for `endTimestamp`.
 *
 * @return the number of evicted values (not keys)
 */
override def evictByTimestamp(endTimestamp: Long): Long =
  tsWithKey.scanEvictedKeys(endTimestamp).foldLeft(0L) { (evictedSoFar, evicted) =>
    val (key, ts, count) = (evicted.key, evicted.timestamp, evicted.numValues)
    keyWithTsToValues.remove(key, ts)
    tsWithKey.remove(key, ts)
    evictedSoFar + count
  }
+
/**
 * Evicts every (key, timestamp) entry that the secondary index reports for `endTimestamp`
 * and returns the evicted values.
 *
 * Eviction is lazy: each entry is removed from both stores only when the returned iterator
 * reaches it, so callers must consume the iterator fully. A single KeyToValuePair instance is
 * reused across elements.
 */
override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
  val reusableKeyToValuePair = KeyToValuePair()

  tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
    val key = evicted.key
    val timestamp = evicted.timestamp
    // Materialize the values before removing the entry. `get` is backed by a state store
    // iterator, and draining it after the removal would only work if the store happened to
    // fetch the values eagerly.
    val values = keyWithTsToValues.get(key, timestamp).toArray

    // Remove from both primary and secondary stores
    keyWithTsToValues.remove(key, timestamp)
    tsWithKey.remove(key, timestamp)

    values.iterator.map { value =>
      reusableKeyToValuePair.withNew(key, value)
    }
  }
}
+
/** Commits the backing state store, then logs its metrics at debug level. */
override def commit(): Unit = {
  stateStore.commit()
  logDebug(s"Committed, metrics = ${stateStore.metrics}")
}
+
/**
 * Aborts the state store if it has not been committed. Then, regardless of the commit status,
 * closes the provider if this manager created it itself (i.e. when the state store coordinator
 * is not used).
 */
override def abortIfNeeded(): Unit = {
  if (!stateStore.hasCommitted) {
    logInfo(log"Aborted store ${MDC(STATE_STORE_ID, stateStore.id)}")
    stateStore.abort()
  }
  // A self-managed provider is not closed by anyone else, so close it here.
  Option(stateStoreProvider).foreach(_.close())
}
+
// Release state store resources when the task completes (listener runs on success or failure).
Option(TaskContext.get()).foreach(ctx => ctx.addTaskCompletionListener[Unit](_ => abortIfNeeded()))
+
/**
 * Mutable holder for one timestamp group returned by KeyWithTsToValuesStore.getValues: the
 * timestamp and the values (with matched flags) stored under it. getValues reuses a single
 * instance across the elements of its iterator.
 */
class GetValuesResult(
    var timestamp: Long = -1,
    var values: Seq[ValueAndMatchPair] = Seq.empty) {

  /** Overwrites both fields in place and returns this same instance. */
  def withNew(newTimestamp: Long, newValues: Seq[ValueAndMatchPair]): GetValuesResult = {
    timestamp = newTimestamp
    values = newValues
    this
  }
}
+
/**
 * Primary store of this manager: maps (join key, timestamp) to the list of value rows buffered
 * under them, each with its matched flag.
 *
 * The timestamp is attached to the key as a postfix (TimestampAsPostfixKeyStateEncoderSpec),
 * and getValues() relies on a prefix scan over the join key to visit all of its timestamps.
 */
private class KeyWithTsToValuesStore {

  private val valueRowConverter = StreamingSymmetricHashJoinValueRowConverter.create(
    inputValueAttributes, stateFormatVersion = 4)

  // Set up virtual column family name in the store if it is being used
  private val colFamilyName = getStateStoreName(joinSide, KeyWithTsToValuesType)

  private val keySchemaWithTimestamp =
    TimestampKeyStateEncoder.keySchemaWithTimestamp(keySchema)
  private val detachTimestampProjection: UnsafeProjection =
    TimestampKeyStateEncoder.getDetachTimestampProjection(keySchemaWithTimestamp)
  private val attachTimestampProjection: UnsafeProjection =
    TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)

  // Create the column family for this join side's KeyWithTsToValuesStore in the store.
  stateStore.createColFamilyIfAbsent(
    colFamilyName,
    keySchema,
    valueRowConverter.valueAttributes.toStructType,
    TimestampAsPostfixKeyStateEncoderSpec(keySchemaWithTimestamp),
    useMultipleValuesPerKey = true
  )

  /** Builds the store key for (`key`, `timestamp`). */
  private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
    TimestampKeyStateEncoder.attachTimestamp(
      attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
  }

  /** Appends one value with its matched flag to the list under (`key`, `timestamp`). */
  def append(key: UnsafeRow, timestamp: Long, value: UnsafeRow, matched: Boolean): Unit = {
    val valueWithMatched = valueRowConverter.convertToValueRow(value, matched)
    stateStore.merge(createKeyRow(key, timestamp), valueWithMatched, colFamilyName)
  }

  /** Replaces the whole list under (`key`, `timestamp`) with `valuesWithMatched`. */
  def put(
      key: UnsafeRow,
      timestamp: Long,
      valuesWithMatched: Seq[(UnsafeRow, Boolean)]): Unit = {
    // The converter may reuse its result row between calls, so each converted row is copied
    // before being collected alongside the others.
    val valuesToPut = valuesWithMatched.map { case (value, matched) =>
      valueRowConverter.convertToValueRow(value, matched).copy()
    }.toArray
    stateStore.putList(createKeyRow(key, timestamp), valuesToPut, colFamilyName)
  }

  /** Returns the values (with matched flags) stored under (`key`, `timestamp`). */
  def get(key: UnsafeRow, timestamp: Long): Iterator[ValueAndMatchPair] = {
    stateStore.valuesIterator(createKeyRow(key, timestamp), colFamilyName).map { valueRow =>
      valueRowConverter.convertValue(valueRow)
    }
  }

  // NOTE: We do not have a case where we only remove a part of values. Even if that is needed
  // we handle it via put() with writing a new array.
  def remove(key: UnsafeRow, timestamp: Long): Unit = {
    stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
  }

  /**
   * Returns the values stored under `key`, grouped by timestamp: one element per distinct
   * timestamp, in scan order. This assumes the prefix scan visits entries with the same
   * timestamp consecutively. A single GetValuesResult instance is reused across elements.
   *
   * NOTE: This assumes we consume the whole iterator to trigger completion; the underlying
   * scan iterator is closed only once this iterator is exhausted.
   */
  def getValues(key: UnsafeRow): Iterator[GetValuesResult] = {
    val reusableGetValuesResult = new GetValuesResult()

    new NextIterator[GetValuesResult] {
      private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)

      // Group being accumulated: values seen so far for timestamp `currentTs`. An explicit
      // flag is used instead of a sentinel timestamp, since every Long (including -1) is a
      // valid timestamp.
      private var hasPendingGroup = false
      private var currentTs = 0L
      private val valueAndMatchPairs =
        scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()

      // Hands out the pending group and resets the accumulator.
      // NOTE: relies on `toSeq` copying the buffer (Scala 2.13 semantics), since the buffer
      // is cleared right afterwards.
      private def emitPendingGroup(): GetValuesResult = {
        val result = reusableGetValuesResult.withNew(currentTs, valueAndMatchPairs.toSeq)
        valueAndMatchPairs.clear()
        hasPendingGroup = false
        result
      }

      override protected def getNext(): GetValuesResult = {
        var result: GetValuesResult = null
        while (result == null && iter.hasNext) {
          val unsafeRowPair = iter.next()
          val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)

          if (hasPendingGroup && ts != currentTs) {
            assert(valueAndMatchPairs.nonEmpty,
              "timestamp has changed but no values collected from previous timestamp! " +
                s"This should not happen. currentTs: $currentTs, new ts: $ts")
            // Timestamp changed: the previous group is complete.
            result = emitPendingGroup()
          }

          // The row just read always belongs to the group of its own timestamp.
          currentTs = ts
          hasPendingGroup = true
          valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
        }

        if (result == null) {
          if (hasPendingGroup) {
            // Scan exhausted: return the last group.
            result = emitPendingGroup()
          } else {
            finished = true
          }
        }
        result
      }

      override protected def close(): Unit = iter.close()
    }
  }

  /**
   * Iterates over all entries of this column family, with the timestamp detached from the key.
   * The returned pair object and its key row (produced by an UnsafeProjection) are reused
   * across elements.
   */
  def iterator(): Iterator[KeyAndTsToValuePair] = {
    val iter = stateStore.iteratorWithMultiValues(colFamilyName)
    val reusableKeyAndTsToValuePair = KeyAndTsToValuePair()
    iter.map { kv =>
      val keyRow = detachTimestampProjection(kv.key)
      val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key)
      val value = valueRowConverter.convertValue(kv.value)

      reusableKeyAndTsToValuePair.withNew(keyRow, ts, value)
    }
  }
}
+
+ private class TsWithKeyTypeStore {
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -913,6 +1473,30 @@ class SymmetricHashJoinStateManagerV1(
}
override def metrics: StateStoreMetrics = {
+ // FIXME: purposed for benchmarking
Review Comment:
Self review: revert it
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
/**
 * State of one side of a streaming symmetric hash join: the input rows buffered per join key,
 * each stored together with a flag recording whether it has been matched.
 *
 * Optional capabilities are provided by mix-in traits such as [[SupportsEvictByCondition]],
 * [[SupportsEvictByTimestamp]] and [[SupportsIndexedKeys]].
 */
trait SymmetricHashJoinStateManager {
  import SymmetricHashJoinStateManager._

  /** Buffers `value` under `key` together with its `matched` flag. */
  def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit

  /** Returns the values buffered under `key`. */
  def get(key: UnsafeRow): Iterator[UnsafeRow]

  /**
   * Returns the joined rows produced from the values buffered under `key`.
   *
   * @param key join key to look up
   * @param generateJoinedRow builds a joined row from a buffered value
   * @param predicate join condition; only joined rows satisfying it are returned
   * @param excludeRowsAlreadyMatched when true, values already flagged as matched are skipped
   * @return lazily evaluated joined rows. Implementations may flag the values that produced a
   *         result as matched and persist those flags only once the iterator is exhausted
   *         (SymmetricHashJoinStateManagerV4 does), so callers should consume it fully.
   */
  def getJoinedRows(
      key: UnsafeRow,
      generateJoinedRow: InternalRow => JoinedRow,
      predicate: JoinedRow => Boolean,
      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]

  /** Iterates over all buffered entries; the returned pair object may be reused per element. */
  def iterator: Iterator[KeyToValuePair]

  /** Commits the state changes made so far. */
  def commit(): Unit

  /** Aborts the state changes if they were not committed and releases held resources. */
  def abortIfNeeded(): Unit

  /** Metrics of the underlying state store(s). */
  def metrics: StateStoreMetrics

  /** Returns the checkpoint information of the underlying state store(s). */
  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
}
+
/**
 * Mix-in for state managers that address buffered values by (key, index) — presumably the
 * layout with keyToNumValues / keyWithIndexToValue stores used by
 * SymmetricHashJoinStateManagerV1.
 * NOTE(review): the method semantics below are inferred from their names — confirm.
 */
trait SupportsIndexedKeys {
  /** Returns the internal row form of `currentKey` as used by the key-with-index store. */
  def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow

  /** Test-only hook: overrides the number of values recorded for `key`. */
  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
}
+
/**
 * Mix-in for state managers that evict buffered entries selected by a predicate on their key
 * or on their value.
 *
 * The `evictAndReturn*` variants return the evicted entries instead of a count.
 * NOTE(review): whether counts are in keys or values, and whether the returned iterators must
 * be fully consumed for eviction to complete, is up to the implementation — confirm.
 */
trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
  import SymmetricHashJoinStateManager._

  /** Evicts the entries whose key satisfies `removalCondition`; returns the evicted count. */
  def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the entries whose key satisfies `removalCondition` and returns them. */
  def evictAndReturnByKeyCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]

  /** Evicts the values satisfying `removalCondition`; returns the evicted count. */
  def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the values satisfying `removalCondition` and returns them. */
  def evictAndReturnByValueCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
}
+
/**
 * Mix-in for state managers that keep buffered values organized by timestamp, so they can be
 * evicted by a timestamp bound rather than by a per-key or per-value predicate.
 */
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
  import SymmetricHashJoinStateManager._

  /**
   * Evicts the buffered values selected by the timestamp bound `endTimestamp`.
   * NOTE(review): whether the bound is inclusive is up to the implementation — confirm.
   *
   * @return the number of evicted values (SymmetricHashJoinStateManagerV4 counts values,
   *         not keys)
   */
  def evictByTimestamp(endTimestamp: Long): Long

  /**
   * Same as [[evictByTimestamp]] but returns the evicted entries. Eviction may happen lazily
   * while the iterator is consumed (it does in SymmetricHashJoinStateManagerV4), so callers
   * should consume it fully.
   */
  def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
}
+
+class SymmetricHashJoinStateManagerV4(
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
/**
 * State of one side of a streaming symmetric hash join: the input rows buffered per join key,
 * each stored together with a flag recording whether it has been matched.
 *
 * Optional capabilities are provided by mix-in traits such as [[SupportsEvictByCondition]],
 * [[SupportsEvictByTimestamp]] and [[SupportsIndexedKeys]].
 */
trait SymmetricHashJoinStateManager {
  import SymmetricHashJoinStateManager._

  /** Buffers `value` under `key` together with its `matched` flag. */
  def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit

  /** Returns the values buffered under `key`. */
  def get(key: UnsafeRow): Iterator[UnsafeRow]

  /**
   * Returns the joined rows produced from the values buffered under `key`.
   *
   * @param key join key to look up
   * @param generateJoinedRow builds a joined row from a buffered value
   * @param predicate join condition; only joined rows satisfying it are returned
   * @param excludeRowsAlreadyMatched when true, values already flagged as matched are skipped
   * @return lazily evaluated joined rows. Implementations may flag the values that produced a
   *         result as matched and persist those flags only once the iterator is exhausted
   *         (SymmetricHashJoinStateManagerV4 does), so callers should consume it fully.
   */
  def getJoinedRows(
      key: UnsafeRow,
      generateJoinedRow: InternalRow => JoinedRow,
      predicate: JoinedRow => Boolean,
      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]

  /** Iterates over all buffered entries; the returned pair object may be reused per element. */
  def iterator: Iterator[KeyToValuePair]

  /** Commits the state changes made so far. */
  def commit(): Unit

  /** Aborts the state changes if they were not committed and releases held resources. */
  def abortIfNeeded(): Unit

  /** Metrics of the underlying state store(s). */
  def metrics: StateStoreMetrics

  /** Returns the checkpoint information of the underlying state store(s). */
  def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
}
+
/**
 * Mix-in for state managers that address buffered values by (key, index) — presumably the
 * layout with keyToNumValues / keyWithIndexToValue stores used by
 * SymmetricHashJoinStateManagerV1.
 * NOTE(review): the method semantics below are inferred from their names — confirm.
 */
trait SupportsIndexedKeys {
  /** Returns the internal row form of `currentKey` as used by the key-with-index store. */
  def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow

  /** Test-only hook: overrides the number of values recorded for `key`. */
  protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
}
+
/**
 * Mix-in for state managers that evict buffered entries selected by a predicate on their key
 * or on their value.
 *
 * The `evictAndReturn*` variants return the evicted entries instead of a count.
 * NOTE(review): whether counts are in keys or values, and whether the returned iterators must
 * be fully consumed for eviction to complete, is up to the implementation — confirm.
 */
trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
  import SymmetricHashJoinStateManager._

  /** Evicts the entries whose key satisfies `removalCondition`; returns the evicted count. */
  def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the entries whose key satisfies `removalCondition` and returns them. */
  def evictAndReturnByKeyCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]

  /** Evicts the values satisfying `removalCondition`; returns the evicted count. */
  def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long

  /** Evicts the values satisfying `removalCondition` and returns them. */
  def evictAndReturnByValueCondition(
      removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
}
+
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.operators.stateful.join
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.ValueAndMatchPair
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * Converter between the value row stored in state store and the (actual value, match) pair.
+ */
+trait StreamingSymmetricHashJoinValueRowConverter {
+ /** Defines the schema of the value row (the value side of K-V in state store). */
+ def valueAttributes: Seq[Attribute]
+
+ /**
+ * Convert the value row to (actual value, match) pair.
+ *
+ * NOTE: implementations should ensure the result row is NOT reused during execution, so
+ * that the caller can safely read the value at any time.
+ */
+ def convertValue(value: UnsafeRow): ValueAndMatchPair
+
+ /**
+ * Build the value row from (actual value, match) pair. This is expected to be called just
+ * before storing to the state store.
+ *
+ * NOTE: depending on the implementation, the result row "may" be reused during execution
+ * (to avoid initialization of object), so the caller should ensure that the logic isn't
+ * affected by such behavior. Call copy() against the result row if needed.
+ */
+ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
+}
+
+class StreamingSymmetricHashJoinValueRowConverterFormatV1(
+ inputValueAttributes: Seq[Attribute]) extends StreamingSymmetricHashJoinValueRowConverter {
+ override val valueAttributes: Seq[Attribute] = inputValueAttributes
+
+ override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
+ if (value != null) ValueAndMatchPair(value, false) else null
+ }
+
+ override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = value
+}
+
+class StreamingSymmetricHashJoinValueRowConverterFormatV2(
+ inputValueAttributes: Seq[Attribute]) extends StreamingSymmetricHashJoinValueRowConverter {
+ private val valueWithMatchedExprs = inputValueAttributes :+ Literal(true)
+ private val indexOrdinalInValueWithMatchedRow = inputValueAttributes.size
+
+ private val valueWithMatchedRowGenerator = UnsafeProjection.create(valueWithMatchedExprs,
+ inputValueAttributes)
+
+ override val valueAttributes: Seq[Attribute] = inputValueAttributes :+
+ AttributeReference("matched", BooleanType)()
+
+ // Projection to generate the value row (without the matched flag) from (value + matched) row
+ private val valueRowGenerator = UnsafeProjection.create(
+ inputValueAttributes, valueAttributes)
+
+ override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
+ if (value != null) {
+ ValueAndMatchPair(valueRowGenerator(value).copy(),
+ value.getBoolean(indexOrdinalInValueWithMatchedRow))
+ } else {
+ null
+ }
+ }
+
+ override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = {
+ val row = valueWithMatchedRowGenerator(value)
+ row.setBoolean(indexOrdinalInValueWithMatchedRow, matched)
+ row
+ }
+}
+
+object StreamingSymmetricHashJoinValueRowConverter {
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.operators.stateful.join
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.ValueAndMatchPair
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * Converter between the value row stored in state store and the (actual value, match) pair.
+ */
+trait StreamingSymmetricHashJoinValueRowConverter {
+ /** Defines the schema of the value row (the value side of K-V in state store). */
+ def valueAttributes: Seq[Attribute]
+
+ /**
+ * Convert the value row to (actual value, match) pair.
+ *
+ * NOTE: implementations should ensure the result row is NOT reused during execution, so
+ * that the caller can safely read the value at any time.
+ */
+ def convertValue(value: UnsafeRow): ValueAndMatchPair
+
+ /**
+ * Build the value row from (actual value, match) pair. This is expected to be called just
+ * before storing to the state store.
+ *
+ * NOTE: depending on the implementation, the result row "may" be reused during execution
+ * (to avoid initialization of object), so the caller should ensure that the logic isn't
+ * affected by such behavior. Call copy() against the result row if needed.
+ */
+ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
+}
+
+class StreamingSymmetricHashJoinValueRowConverterFormatV1(
+ inputValueAttributes: Seq[Attribute]) extends StreamingSymmetricHashJoinValueRowConverter {
+ override val valueAttributes: Seq[Attribute] = inputValueAttributes
+
+ override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
+ if (value != null) ValueAndMatchPair(value, false) else null
+ }
+
+ override def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow = value
+}
+
+class StreamingSymmetricHashJoinValueRowConverterFormatV2(
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+trait SupportsIndexedKeys {
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
+}
+
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+ inputValueAttributes,
+ // NOTE: This does not accept multiple event time columns. This differs from the operator,
+ // which still allows them for backward compatibility; supporting that here would require
+ // passing the information (which lives in SQLConf) through too many layers.
+ allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
Review Comment:
Self review: add a code comment here; it would probably be worth explaining this in more detail.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+trait SymmetricHashJoinStateManager {
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+trait SupportsIndexedKeys {
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues: Long): Unit
+}
+
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinValueRowConverter.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.operators.stateful.join
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Literal, UnsafeProjection, UnsafeRow}
+import
org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager.ValueAndMatchPair
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * Converter between the value row stored in state store and the (actual
value, match) pair.
+ */
+trait StreamingSymmetricHashJoinValueRowConverter {
+ /** Defines the schema of the value row (the value side of K-V in state
store). */
+ def valueAttributes: Seq[Attribute]
+
+ /**
+ * Convert the value row to (actual value, match) pair.
+ *
+ * NOTE: implementations should ensure the result row is NOT reused during
execution, so
+ * that caller can safely read the value in any time.
+ */
+ def convertValue(value: UnsafeRow): ValueAndMatchPair
+
+ /**
+ * Build the value row from (actual value, match) pair. This is expected to
be called just
+ * before storing to the state store.
+ *
+ * NOTE: depending on the implementation, the result row "may" be reused
during execution
+ * (to avoid initialization of object), so the caller should ensure that the
logic doesn't
+ * affect by such behavior. Call copy() against the result row if needed.
+ */
+ def convertToValueRow(value: UnsafeRow, matched: Boolean): UnsafeRow
+}
+
+class StreamingSymmetricHashJoinValueRowConverterFormatV1(
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -1277,7 +1881,8 @@ object SymmetricHashJoinStateManager {
storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) {
KeyWithIndexToValueType
} else {
- throw new IllegalArgumentException(s"Unknown join store name: $storeName")
+ // TODO: Add support of KeyWithTsToValuesType and TsWithKeyType
Review Comment:
Self review: should this TODO reference a JIRA ticket?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -1289,15 +1894,18 @@ object SymmetricHashJoinStateManager {
colFamilyName: String,
stateKeySchema: StructType,
stateFormatVersion: Int): StatePartitionKeyExtractor = {
- assert(stateFormatVersion <= 3, "State format version must be less than or equal to 3")
- val name = if (stateFormatVersion == 3) colFamilyName else storeName
+ assert(stateFormatVersion <= 4, "State format version must be less than or equal to 4")
+ val name = if (stateFormatVersion >= 3) colFamilyName else storeName
if (getStoreType(name) == KeyWithIndexToValueType) {
// For KeyWithIndex, the index is added to the join (i.e. partition) key.
// Drop the last field (index) to get the partition key
new DropLastNFieldsStatePartitionKeyExtractor(stateKeySchema, numLastColsToDrop = 1)
- } else {
+ } else if (getStoreType(name) == KeyToNumValuesType) {
// State key is the partition key
new NoopStatePartitionKeyExtractor(stateKeySchema)
+ } else {
+ // TODO: Add support of KeyWithTsToValuesType and TsWithKeyType
Review Comment:
Self review: should this TODO reference a JIRA ticket?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -1331,6 +1939,35 @@ object SymmetricHashJoinStateManager {
this
}
}
+
+ case class KeyAndTsToValuePair(
Review Comment:
Self review: add code comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -926,6 +1510,7 @@ class SymmetricHashJoinStateManagerV1(
// We want to collect instance metrics from both state stores
keyWithIndexToValueMetrics.instanceMetrics ++
keyToNumValuesMetrics.instanceMetrics
)
+ */
Review Comment:
Self review: revert this change (the stray `*/` added here).
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+/**
+ * State manager holding the buffered input rows of one side of a
+ * stream-stream symmetric hash join (used by StreamingSymmetricHashJoinExec;
+ * implementations such as V4 are constructed for a single JoinSide).
+ * Values are stored per join key together with a per-value "matched" flag.
+ *
+ * Optional capabilities (eviction strategies, indexed-key access) are
+ * exposed through the separate Supports* traits in this file.
+ */
+trait SymmetricHashJoinStateManager {
+ // Brings the shared helper types (e.g. KeyToValuePair,
+ // JoinerStateStoreCkptInfo) into scope.
+ import SymmetricHashJoinStateManager._
+
+ /** Stores `value` under `key` together with its `matched` flag. */
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ /** Returns the values stored under `key`. */
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ /**
+ * Returns joined rows for the values stored under `key`.
+ *
+ * Each stored value is turned into a JoinedRow via `generateJoinedRow`
+ * (presumably pairing it with the current input row of the other side)
+ * and is kept only if `predicate` holds. Values that satisfy `predicate`
+ * get their matched flag set to true. When `excludeRowsAlreadyMatched` is
+ * true, values already flagged as matched are skipped.
+ *
+ * NOTE(review): V4 writes the updated flags back to the state store in the
+ * iterator's close(), so the result presumably has to be fully consumed
+ * for the flags to be persisted. Confirm with the callers.
+ */
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ /**
+ * Iterates over every (key, value, matched) entry in the state.
+ * Implementations may reuse one KeyToValuePair instance across elements
+ * (V4 does), so copy an element before retaining it.
+ */
+ def iterator: Iterator[KeyToValuePair]
+
+ /** Commits the changes made to the underlying state store. */
+ def commit(): Unit
+
+ /**
+ * Aborts the underlying state store if it has not been committed, and
+ * releases resources owned by this manager (in V4, a self-managed
+ * StateStoreProvider). V4 also registers this as a task-completion hook.
+ */
+ def abortIfNeeded(): Unit
+
+ /** Metrics reported for the state held by this manager. */
+ def metrics: StateStoreMetrics
+
+ /** Returns the latest checkpoint info of the underlying state store(s). */
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+/**
+ * Capability for state managers that address stored values by
+ * (key, index), as the trait and method names suggest.
+ * NOTE(review): confirm against the implementing class(es).
+ */
+trait SupportsIndexedKeys {
+ /**
+ * Returns an InternalRow form of `currentKey` in the key-with-index layout
+ * (presumably the key columns plus the value index; TODO confirm).
+ */
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ /**
+ * Test-only hook (per its name; visibility limited to `streaming`) that
+ * presumably overwrites the number of values recorded for `key` with
+ * `numValues`.
+ */
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues:
Long): Unit
+}
+
+/**
+ * Capability: evict state entries selected by a predicate on either the key
+ * or the value. The self-type means it can only be mixed into a
+ * SymmetricHashJoinStateManager.
+ */
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ // Brings KeyToValuePair into scope.
+ import SymmetricHashJoinStateManager._
+
+ /**
+ * Removes every entry whose key satisfies `removalCondition` and returns a
+ * count (presumably the number of removed values; TODO confirm).
+ */
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ /**
+ * Like [[evictByKeyCondition]], but returns the evicted pairs.
+ * NOTE(review): implementations may evict lazily while the iterator is
+ * consumed, so drain it fully.
+ */
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ /**
+ * Removes every entry whose value satisfies `removalCondition` and returns
+ * a count (presumably the number of removed values; TODO confirm).
+ */
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ /**
+ * Like [[evictByValueCondition]], but returns the evicted pairs.
+ * NOTE(review): implementations may evict lazily while the iterator is
+ * consumed, so drain it fully.
+ */
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+/**
+ * Capability: evict state entries by the timestamp they were stored under.
+ * In SymmetricHashJoinStateManagerV4 that timestamp is the value's event
+ * time, or a random bucket number in [0, 1024) when the input has no
+ * event-time column. The self-type means it can only be mixed into a
+ * SymmetricHashJoinStateManager.
+ */
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ // Brings KeyToValuePair into scope.
+ import SymmetricHashJoinStateManager._
+
+ /**
+ * Evicts all entries whose timestamp is up to `endTimestamp` and returns
+ * the number of evicted values (V4 sums the per-entry value counts).
+ * NOTE(review): whether the bound is inclusive is decided by the
+ * implementation's scan. TODO confirm.
+ */
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ /**
+ * Like [[evictByTimestamp]], but returns the evicted pairs. In V4 the
+ * removal happens while the returned iterator is consumed, and a single
+ * KeyToValuePair instance is reused across elements (copy to retain).
+ */
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with
Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i",
k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+ inputValueAttributes,
+ // NOTE: This does not accept multiple event time columns. This is not the
same with the
+ // operator which we offer the backward compatibility, but it involves too
many layers to
+ // pass the information. The information is in SQLConf.
+ allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+ private val extractEventTimeFn: UnsafeRow => Long = { row =>
+ eventTimeColIdxOpt match {
+ case Some(idx) =>
+ val attr = inputValueAttributes(idx)
+
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+
+ case _ =>
+ // Need a strategy about bucketing when event time is not available
+ // - first attempt: random bucketing
+ random.nextInt(bucketSizeForNoEventTime)
+ }
+ }
+
+ private val eventTimeColIdxOptInKey: Option[Int] = {
+ joinKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index)
+ if ne.metadata.contains(EventTimeWatermark.delayKey) => index
+ }
+ }
+
+ private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+ eventTimeColIdxOptInKey.map { idx =>
+ val attr = keyAttributes(idx)
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+ }
+ }
+
+ private val dummySchema = StructType(
+ Seq(StructField("dummy", NullType, nullable = true))
+ )
+
+ private val stateStoreCkptId: Option[String] = None
+ private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
+ private var stateStoreProvider: StateStoreProvider = _
+
+ // We will use the dummy schema for the default CF since we will register CF
separately.
+ private val stateStore = getStateStore(
+ dummySchema, dummySchema, useVirtualColumnFamilies = true,
+ NoPrefixKeyStateEncoderSpec(dummySchema), useMultipleValuesPerKey = false
+ )
+
+ private def getStateStore(
+ keySchema: StructType,
+ valueSchema: StructType,
+ useVirtualColumnFamilies: Boolean,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useMultipleValuesPerKey: Boolean): StateStore = {
+ val storeName = StateStoreId.DEFAULT_STORE_NAME
+ val storeProviderId = StateStoreProviderId(stateInfo.get, partitionId,
storeName)
+ val store = if (useStateStoreCoordinator) {
+ assert(handlerSnapshotOptions.isEmpty, "Should not use state store
coordinator " +
+ "when reading state as data source.")
+ joinStoreGenerator.getStore(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ stateInfo.get.storeVersion, stateStoreCkptId, None,
useVirtualColumnFamilies,
+ useMultipleValuesPerKey, storeConf, hadoopConf)
+ } else {
+ // This class will manage the state store provider by itself.
+ stateStoreProvider = StateStoreProvider.createAndInit(
+ storeProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+ useColumnFamilies = useVirtualColumnFamilies,
+ storeConf, hadoopConf, useMultipleValuesPerKey =
useMultipleValuesPerKey,
+ stateSchemaProvider = None)
+ if (handlerSnapshotOptions.isDefined) {
+ if (!stateStoreProvider.isInstanceOf[SupportsFineGrainedReplay]) {
+ throw
StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
+ stateStoreProvider.getClass.toString)
+ }
+ val opts = handlerSnapshotOptions.get
+ stateStoreProvider.asInstanceOf[SupportsFineGrainedReplay]
+ .replayStateFromSnapshot(
+ opts.snapshotVersion,
+ opts.endVersion,
+ readOnly = true,
+ opts.startStateStoreCkptId,
+ opts.endStateStoreCkptId)
+ } else {
+ stateStoreProvider.getStore(stateInfo.get.storeVersion,
stateStoreCkptId)
+ }
+ }
+ logInfo(log"Loaded store ${MDC(STATE_STORE_ID, store.id)}")
+ store
+ }
+
+ private val keyWithTsToValues = new KeyWithTsToValuesStore
+
+ private val tsWithKey = new TsWithKeyTypeStore
+
+ override def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean):
Unit = {
+ val eventTime = extractEventTimeFn(value)
+ // We always do blind merge for appending new value.
+ keyWithTsToValues.append(key, eventTime, value, matched)
+ tsWithKey.add(eventTime, key)
+ }
+
+ override def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean): Iterator[JoinedRow] = {
+ // TODO: We could improve this method to get the scope of timestamp and
scan keys
+ // more efficiently. For now, we just get all values for the key.
+
+ def getJoinedRowsFromTsAndValues(
+ ts: Long,
+ valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
+ new NextIterator[JoinedRow] {
+ private var currentIndex = 0
+
+ private var shouldUpdateValuesIntoStateStore = false
+
+ override protected def getNext(): JoinedRow = {
+ var ret: JoinedRow = null
+ while (ret == null && currentIndex < valuesAndMatched.length) {
+ val vmp = valuesAndMatched(currentIndex)
+
+ if (excludeRowsAlreadyMatched && vmp.matched) {
+ // Skip this one
+ } else {
+ val joinedRow = generateJoinedRow(vmp.value)
+ if (predicate(joinedRow)) {
+ if (!vmp.matched) {
+ // Update the array to contain the value having matched =
true
+ valuesAndMatched(currentIndex) = vmp.copy(matched = true)
+ // Need to update matched flag
+ shouldUpdateValuesIntoStateStore = true
+ }
+
+ ret = joinedRow
+ } else {
+ // skip this one
+ }
+ }
+
+ currentIndex += 1
+ }
+
+ if (ret == null) {
+ assert(currentIndex == valuesAndMatched.length)
+ finished = true
+ null
+ } else {
+ ret
+ }
+ }
+
+ override protected def close(): Unit = {
+ if (shouldUpdateValuesIntoStateStore) {
+ // Update back to the state store
+ val updatedValuesWithMatched = valuesAndMatched.map { vmp =>
+ (vmp.value, vmp.matched)
+ }.toSeq
+ keyWithTsToValues.put(key, ts, updatedValuesWithMatched)
+ }
+ }
+ }
+ }
+
+ val ret = extractEventTimeFnFromKey(key) match {
+ case Some(ts) =>
+ val valuesAndMatchedIter = keyWithTsToValues.get(key, ts)
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
+
+ case _ =>
+ keyWithTsToValues.getValues(key).flatMap { result =>
+ val ts = result.timestamp
+ val valuesAndMatched = result.values.toArray
+ getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
+ }
+ }
+ ret.filter(_ != null)
+ }
+
+ override def iterator: Iterator[KeyToValuePair] = {
+ val reusableKeyToValuePair = KeyToValuePair()
+ keyWithTsToValues.iterator().map { kv =>
+ reusableKeyToValuePair.withNew(kv.key, kv.value, kv.matched)
+ }
+ }
+
+ override def evictByTimestamp(endTimestamp: Long): Long = {
+ var removed = 0L
+ tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
+ val key = evicted.key
+ val timestamp = evicted.timestamp
+ val numValues = evicted.numValues
+
+ // Remove from both primary and secondary stores
+ keyWithTsToValues.remove(key, timestamp)
+ tsWithKey.remove(key, timestamp)
+
+ removed += numValues
+ }
+ removed
+ }
+
+ override def evictAndReturnByTimestamp(endTimestamp: Long):
Iterator[KeyToValuePair] = {
+ val reusableKeyToValuePair = KeyToValuePair()
+
+ tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
+ val key = evicted.key
+ val timestamp = evicted.timestamp
+ val values = keyWithTsToValues.get(key, timestamp)
+
+ // Remove from both primary and secondary stores
+ keyWithTsToValues.remove(key, timestamp)
+ tsWithKey.remove(key, timestamp)
+
+ values.map { value =>
+ reusableKeyToValuePair.withNew(key, value)
+ }
+ }
+ }
+
+ override def commit(): Unit = {
+ stateStore.commit()
+ logDebug("Committed, metrics = " + stateStore.metrics)
+ }
+
+ override def abortIfNeeded(): Unit = {
+ if (!stateStore.hasCommitted) {
+ logInfo(log"Aborted store ${MDC(STATE_STORE_ID, stateStore.id)}")
+ stateStore.abort()
+ }
+ // If this class manages a state store provider by itself, it should take
care of closing
+ // provider instance as well.
+ if (stateStoreProvider != null) {
+ stateStoreProvider.close()
+ }
+ }
+
+ // Clean up any state store resources if necessary at the end of the task
+ Option(TaskContext.get()).foreach { _.addTaskCompletionListener[Unit] { _ =>
abortIfNeeded() } }
+
+ class GetValuesResult(var timestamp: Long = -1, var values:
Seq[ValueAndMatchPair] = Seq.empty) {
+ def withNew(newTimestamp: Long, newValues: Seq[ValueAndMatchPair]):
GetValuesResult = {
+ this.timestamp = newTimestamp
+ this.values = newValues
+ this
+ }
+ }
+
+ private class KeyWithTsToValuesStore {
Review Comment:
Self-review: add a code comment.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]