HeartSaVioR commented on code in PR #53930:
URL: https://github.com/apache/spark/pull/53930#discussion_r2831065511
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+trait SupportsIndexedKeys {
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues:
Long): Unit
+}
+
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with
Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i",
k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+ inputValueAttributes,
+ // NOTE: This does not accept multiple event time columns. This is not the
same with the
+ // operator which we offer the backward compatibility, but it involves too
many layers to
+ // pass the information. The information is in SQLConf.
+ allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+ private val extractEventTimeFn: UnsafeRow => Long = { row =>
+ eventTimeColIdxOpt match {
+ case Some(idx) =>
+ val attr = inputValueAttributes(idx)
+
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+
+ case _ =>
+ // Need a strategy about bucketing when event time is not available
+ // - first attempt: random bucketing
+ random.nextInt(bucketSizeForNoEventTime)
+ }
+ }
+
+ private val eventTimeColIdxOptInKey: Option[Int] = {
+ joinKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index)
+ if ne.metadata.contains(EventTimeWatermark.delayKey) => index
+ }
+ }
+
+ private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+ eventTimeColIdxOptInKey.map { idx =>
+ val attr = keyAttributes(idx)
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+ }
+ }
+
+ private val dummySchema = StructType(
+ Seq(StructField("dummy", NullType, nullable = true))
+ )
+
+ private val stateStoreCkptId: Option[String] = None
Review Comment:
Ah, nice catch. I meant to defer wiring up `stateStoreCkptId` (currently hard-coded to
`None`) to the integration work, but forgot to leave a TODO comment. Let me add one...
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala:
##########
@@ -27,16 +27,613 @@ import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX,
STATE_STORE_ID}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, SafeProjection,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, Expression, JoinedRow, Literal, NamedExpression,
SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOpStateStoreCheckpointInfo
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay}
-import org.apache.spark.sql.types.{BooleanType, LongType, StructField,
StructType}
+import
org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor,
KeyStateEncoderSpec, NoopStatePartitionKeyExtractor,
NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast,
StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema,
StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics,
StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay,
TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec,
TimestampKeyStateEncoder}
+import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType,
StructField, StructType}
import org.apache.spark.util.NextIterator
+trait SymmetricHashJoinStateManager {
+ import SymmetricHashJoinStateManager._
+
+ def append(key: UnsafeRow, value: UnsafeRow, matched: Boolean): Unit
+
+ def get(key: UnsafeRow): Iterator[UnsafeRow]
+
+ def getJoinedRows(
+ key: UnsafeRow,
+ generateJoinedRow: InternalRow => JoinedRow,
+ predicate: JoinedRow => Boolean,
+ excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow]
+
+ def iterator: Iterator[KeyToValuePair]
+
+ def commit(): Unit
+
+ def abortIfNeeded(): Unit
+
+ def metrics: StateStoreMetrics
+
+ def getLatestCheckpointInfo(): JoinerStateStoreCkptInfo
+}
+
+trait SupportsIndexedKeys {
+ def getInternalRowOfKeyWithIndex(currentKey: UnsafeRow): InternalRow
+
+ protected[streaming] def updateNumValuesTestOnly(key: UnsafeRow, numValues:
Long): Unit
+}
+
+trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByKeyCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByKeyCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+
+ def evictByValueCondition(removalCondition: UnsafeRow => Boolean): Long
+
+ def evictAndReturnByValueCondition(
+ removalCondition: UnsafeRow => Boolean): Iterator[KeyToValuePair]
+}
+
+trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
+ import SymmetricHashJoinStateManager._
+
+ def evictByTimestamp(endTimestamp: Long): Long
+
+ def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+}
+
+class SymmetricHashJoinStateManagerV4(
+ joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ keyToNumValuesStateStoreCkptId: Option[String],
+ keyWithIndexToValueStateStoreCkptId: Option[String],
+ stateFormatVersion: Int,
+ skippedNullValueCount: Option[SQLMetric] = None,
+ useStateStoreCoordinator: Boolean = true,
+ snapshotOptions: Option[SnapshotOptions] = None,
+ joinStoreGenerator: JoinStateManagerStoreGenerator)
+ extends SymmetricHashJoinStateManager with SupportsEvictByTimestamp with
Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ protected val keySchema = StructType(
+ joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i",
k.dataType, k.nullable) })
+ protected val keyAttributes = toAttributes(keySchema)
+ private val eventTimeColIdxOpt = WatermarkSupport.findEventTimeColumnIndex(
+ inputValueAttributes,
+ // NOTE: This does not accept multiple event time columns. This is not the
same with the
+ // operator which we offer the backward compatibility, but it involves too
many layers to
+ // pass the information. The information is in SQLConf.
+ allowMultipleEventTimeColumns = false)
+
+ private val random = new scala.util.Random(System.currentTimeMillis())
+ private val bucketSizeForNoEventTime = 1024
+ private val extractEventTimeFn: UnsafeRow => Long = { row =>
+ eventTimeColIdxOpt match {
+ case Some(idx) =>
+ val attr = inputValueAttributes(idx)
+
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+
+ case _ =>
+ // Need a strategy about bucketing when event time is not available
+ // - first attempt: random bucketing
+ random.nextInt(bucketSizeForNoEventTime)
+ }
+ }
+
+ private val eventTimeColIdxOptInKey: Option[Int] = {
+ joinKeys.zipWithIndex.collectFirst {
+ case (ne: NamedExpression, index)
+ if ne.metadata.contains(EventTimeWatermark.delayKey) => index
+ }
+ }
+
+ private val extractEventTimeFnFromKey: UnsafeRow => Option[Long] = { row =>
+ eventTimeColIdxOptInKey.map { idx =>
+ val attr = keyAttributes(idx)
+ if (attr.dataType.isInstanceOf[StructType]) {
+ // NOTE: We assume this is window struct, as same as
WatermarkSupport.watermarkExpression
+ row.getStruct(idx, 2).getLong(1)
+ } else {
+ row.getLong(idx)
+ }
+ }
+ }
+
+ private val dummySchema = StructType(
+ Seq(StructField("dummy", NullType, nullable = true))
+ )
+
+ private val stateStoreCkptId: Option[String] = None
+ private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None
Review Comment:
Same here: `handlerSnapshotOptions` is also hard-coded to `None` for now; I'll add a TODO for it as well.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]