Github user zsxwing commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r140055717
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.Predicate
+import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]].
+ * The interface of this class is basically that of a multi-map:
+ * - Get: Returns an iterator of multiple values for a given key
+ * - Append: Appends a new value to the given key
+ * - Remove data by predicate: Drops any state using a predicate condition on keys or values
+ *
+ * @param joinSide              Defines the join side
+ * @param inputValueAttributes  Attributes of the input row which will be stored as value
+ * @param joinKeys              Expressions to generate rows that will be used to key the value rows
+ * @param stateInfo             Information about how to retrieve the correct version of state
+ * @param storeConf             Configuration for the state store
+ * @param hadoopConf            Hadoop configuration for reading state data from storage
+ *
+ * Internally, the key -> multiple-values mapping is stored in two [[StateStore]]s.
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains the mapping key -> number of values
+ * - Store 2 ([[KeyWithIndexToValueStore]]) maintains the mapping (key, index) -> value
+ * - Put: update count in KeyToNumValuesStore,
+ * insert new (key, count) -> value in KeyWithIndexToValueStore
+ * - Get: read count from KeyToNumValuesStore,
+ * read each of the n values in KeyWithIndexToValueStore
+ * - Remove state by predicate on keys:
+ *     scan all keys in KeyToNumValuesStore to find keys that match the predicate,
+ *     delete the key from KeyToNumValuesStore and delete its values in KeyWithIndexToValueStore
+ * - Remove state by condition on values:
+ *     scan all [(key, index) -> value] entries in KeyWithIndexToValueStore to find values that
+ *     match the predicate, delete the corresponding (key, indexToDelete) from
+ *     KeyWithIndexToValueStore by overwriting it with the value of (key, maxIndex) and removing
+ *     [(key, maxIndex)], and decrement the corresponding number of values in KeyToNumValuesStore
+ */
+class SymmetricHashJoinStateManager(
+ val joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration) extends Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ // Clean up any state store resources if necessary at the end of the task
+ Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } }
+
+ /*
+ =====================================================
+ Public methods
+ =====================================================
+ */
+
+ /** Get all the values of a key */
+ def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+ val numValues = keyToNumValues.get(key)
+ keyWithIndexToValue.getAll(key, numValues)
+ }
+
+ /** Append a new value to the key */
+ def append(key: UnsafeRow, value: UnsafeRow): Unit = {
+ val numExistingValues = keyToNumValues.get(key)
+ keyWithIndexToValue.put(key, numExistingValues, value)
+ keyToNumValues.put(key, numExistingValues + 1)
+ }
+
+ /**
+ * Remove using a predicate on keys. See class docs for more context and implementation details.
+ */
+ def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
+ val allKeyToNumValues = keyToNumValues.iterator
+
+ while (allKeyToNumValues.hasNext) {
+ val keyToNumValue = allKeyToNumValues.next
+ if (condition(keyToNumValue.key)) {
+ keyToNumValues.remove(keyToNumValue.key)
+ keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue)
+ }
+ }
+ }
+
+ /**
+ * Remove using a predicate on values. See class docs for more context and implementation details.
+ */
+ def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
+ val allKeyToNumValues = keyToNumValues.iterator
+ var numValues: Long = 0L
--- End diff --
nit: could you declare them inside the while loop? Then it's easy to tell
that they will not be used across keys.
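
For illustration, a minimal sketch of the scoping this suggests. The rest of the
loop body is not part of the quoted diff, so it is elided as a placeholder comment:

    def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
      val allKeyToNumValues = keyToNumValues.iterator

      while (allKeyToNumValues.hasNext) {
        val keyToNumValue = allKeyToNumValues.next
        // Declared per key, so it is obvious the counter never carries over across keys.
        var numValues: Long = 0L
        // ... per-key value removal logic (not shown in the quoted diff) ...
      }
    }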