Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r140062031
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.Predicate
+import
org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo,
StreamingSymmetricHashJoinExec}
+import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Helper class to manage state required by a single side of
[[StreamingSymmetricHashJoinExec]].
+ * The interface of this class is basically that of a multi-map:
+ * - Get: Returns an iterator of multiple values for given key
+ * - Append: Append a new value to the given key
+ * - Remove Data by predicate: Drop any state using a predicate condition
on keys or values
+ *
+ * @param joinSide Defines the join side
+ * @param inputValueAttributes Attributes of the input row which will be
stored as value
+ * @param joinKeys Expressions to find the join key that will be
used to key the value rows
+ * @param stateInfo Information about how to retrieve the
correct version of state
+ * @param storeConf Configuration for the state store.
+ * @param hadoopConf Hadoop configuration for reading state data
from storage
+ *
+ * Internally, the key -> multiple values is stored in two [[StateStore]]s.
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key ->
number of values
+ * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between
(key, index) -> value
+ * - Put: update count in KeyToNumValuesStore,
+ * insert new (key, count) -> value in KeyWithIndexToValueStore
+ * - Get: read count from KeyToNumValuesStore,
+ * read each of the n values in KeyWithIndexToValueStore
+ * - Remove state by predicate on keys:
+ * scan all keys in KeyToNumValuesStore to find keys that do
match the predicate,
+ * delete the key from KeyToNumValuesStore, delete values in
KeyWithIndexToValueStore
+ * - Remove state by condition on values:
+ * scan all [(key, index) -> value] in KeyWithIndexToValueStore
to find values that match
+ * the predicate, delete corresponding (key, indexToDelete) from
KeyWithIndexToValueStore
+ * by overwriting with the value of (key, maxIndex), and removing
(key, maxIndex),
+ * decrement corresponding num values in KeyToNumValuesStore
+ */
+class SymmetricHashJoinStateManager(
+ val joinSide: JoinSide,
+ val inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration) extends Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ // Clean up any state store resources if necessary at the end of the task
+ Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ =>
abortIfNeeded() } }
+
+ /*
+ =====================================================
+ Public methods
+ =====================================================
+ */
+
+ /** Get all the values of a key */
+ def get(key: UnsafeRow): Iterator[UnsafeRow] = {
+ val numValues = keyToNumValues.get(key)
+ keyWithIndexToValue.getAll(key, numValues)
+ }
+
+ /** Append a new value to the key */
+ def append(key: UnsafeRow, value: UnsafeRow): Unit = {
+ val numExistingValues = keyToNumValues.get(key)
+ keyWithIndexToValue.put(key, numExistingValues, value)
+ keyToNumValues.put(key, numExistingValues + 1)
+ }
+
+ /**
+ * Remove using a predicate on keys. See class docs for more context and
implementation details.
+ */
+ def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
+ val allKeyToNumValues = keyToNumValues.iterator
+
+ while (allKeyToNumValues.hasNext) {
+ val keyToNumValue = allKeyToNumValues.next
+ if (condition(keyToNumValue.key)) {
+ keyToNumValues.remove(keyToNumValue.key)
+ keyWithIndexToValue.removeAllValues(keyToNumValue.key,
keyToNumValue.numValue)
+ }
+ }
+ }
+
+ /**
+ * Remove using a predicate on values. See class docs for more context
and implementation details.
+ */
+ def removeByPredicateOnValues(condition: UnsafeRow => Boolean): Unit = {
+ val allKeyToNumValues = keyToNumValues.iterator
+
+ var numValues: Long = 0L
+ var index: Long = 0L
+ var valueRemoved = false
+ var valueForIndex: UnsafeRow = null
+
+
+ while (allKeyToNumValues.hasNext) {
+ val keyToNumValue = allKeyToNumValues.next
+ val key = keyToNumValue.key
+
+ numValues = keyToNumValue.numValue
+ index = 0L
+ valueRemoved = false
+ valueForIndex = null
+
+ while (index < numValues) {
+ if (valueForIndex == null) {
+ valueForIndex = keyWithIndexToValue.get(key, index)
+ }
+ if (condition(valueForIndex)) {
+ if (numValues > 1) {
+ val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues -
1)
+ keyWithIndexToValue.put(key, index, valueAtMaxIndex)
+ keyWithIndexToValue.remove(key, numValues - 1)
+ valueForIndex = valueAtMaxIndex
+ } else {
+ keyWithIndexToValue.remove(key, 0)
+ valueForIndex = null
+ }
+ numValues -= 1
+ valueRemoved = true
+ } else {
+ valueForIndex = null
+ index += 1
+ }
+ }
+ if (valueRemoved) {
+ if (numValues >= 1) {
+ keyToNumValues.put(key, numValues)
+ } else {
+ keyToNumValues.remove(key)
+ }
+ }
+ }
+ }
+
+ def iterator(): Iterator[UnsafeRowPair] = {
+ val pair = new UnsafeRowPair()
+ keyWithIndexToValue.iterator.map { x =>
+ pair.withRows(x.key, x.value)
+ }
+ }
+
+ /** Commit all the changes to all the state stores */
+ def commit(): Unit = {
+ keyToNumValues.commit()
+ keyWithIndexToValue.commit()
+ }
+
+ /** Abort any changes to the state stores if needed */
+ def abortIfNeeded(): Unit = {
+ keyWithIndexToValue.abortIfNeeded()
+ keyWithIndexToValue.abortIfNeeded()
--- End diff --
good catch!
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]