Github user brkyvz commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r140017131
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal,
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.Predicate
+import
org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo,
StreamingSymmetricHashJoinExec}
+import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Helper class to manage state required by a single side of
[[StreamingSymmetricHashJoinExec]].
+ * The interface of this class is basically that of a multi-map:
+ * - Get: Returns an iterator of multiple values for given key
+ * - Append: Append a new value to the given key
+ * - Remove Data by predicate: Drop any state using a predicate condition
on keys or values
+ *
+ * @param joinSide Defines the join side
+ * @param inputValueAttributes Attributes of the input row which will be
stored as value
+ * @param joinKeys Expressions to generate rows that will be
used to key the value rows
+ * @param stateInfo Information about how to retrieve the
correct version of state
+ * @param storeConf Configuration for the state store.
+ * @param hadoopConf Hadoop configuration for reading state
data from storage
+ *
+ * Internally, the key -> multiple values is stored in two [[StateStore]]s.
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key ->
number of values
+ * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between
(key, index) -> value
+ * - Put: update count in KeyToNumValuesStore,
+ * insert new (key, count) -> value in KeyWithIndexToValueStore
+ * - Get: read count from KeyToNumValuesStore,
+ * read each of the n values in KeyWithIndexToValueStore
+ * - Remove state by predicate on keys:
+ *       scan all keys in KeyToNumValuesStore to find keys that match the predicate,
+ *       delete the key from KeyToNumValuesStore, delete values in KeyWithIndexToValueStore
+ * - Remove state by condition on values:
+ * scan all [(key, index) -> value] in KeyWithIndexToValueStore
to find values that match
+ *       the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore
+ *       by overwriting with the value of (key, maxIndex), removing [(key, maxIndex)], and
+ *       decrementing the corresponding num values in KeyToNumValuesStore
+ */
+class SymmetricHashJoinStateManager(
+ val joinSide: JoinSide,
+ inputValueAttributes: Seq[Attribute],
+ joinKeys: Seq[Expression],
+ stateInfo: Option[StatefulOperatorStateInfo],
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration) extends Logging {
+
+ import SymmetricHashJoinStateManager._
+
+ // Clean up any state store resources if necessary at the end of the task
+ Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ =>
abortIfNeeded() } }
+
+ /*
+ =====================================================
+ Public methods
+ =====================================================
+ */
+
+  /** Return an iterator over every value currently buffered for the given key. */
+  def get(key: UnsafeRow): Iterator[UnsafeRow] =
+    keyWithIndexToValue.getAll(key, keyToNumValues.get(key))
+
+  /**
+   * Store an additional value under the given key. The value is written at the
+   * next free index, then the per-key count is bumped to cover it.
+   */
+  def append(key: UnsafeRow, value: UnsafeRow): Unit = {
+    val count = keyToNumValues.get(key)
+    keyWithIndexToValue.put(key, count, value)
+    keyToNumValues.put(key, count + 1)
+  }
+
+  /**
+   * Remove all state for keys matching the given predicate. See class docs for more
+   * context and implementation details.
+   */
+  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
+    keyToNumValues.iterator.foreach { keyAndCount =>
+      if (condition(keyAndCount.key)) {
+        keyToNumValues.remove(keyAndCount.key)
+        keyWithIndexToValue.removeAllValues(keyAndCount.key, keyAndCount.numValue)
+      }
+    }
+  }
+
+  /**
+   * Remove using a predicate on values. See class docs for more context
+   * and implementation details.
+   *
+   * Deletion uses a swap-with-last scheme: when the value at `index` matches,
+   * the value at the current max index is copied over it and the max-index slot
+   * is removed, keeping indices dense without shifting every later slot.
+   */
+  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
+    val allKeyToNumValues = keyToNumValues.iterator
+    // Loop-carried state, hoisted out of the per-key loop to avoid reallocation.
+    var numValues: Long = 0L
+    var index: Long = 0L
+    var valueRemoved = false
+    var valueForIndex: UnsafeRow = null
+
+    while (allKeyToNumValues.hasNext) {
+      val keyToNumValue = allKeyToNumValues.next
+      val key = keyToNumValue.key
+
+      numValues = keyToNumValue.numValue
+      index = 0L
+      valueRemoved = false
+      valueForIndex = null
+
+      while (index < numValues) {
+        // `valueForIndex` caches the row at `index`; null means the cache is
+        // stale and the row must be (re-)read from the store.
+        if (valueForIndex == null) {
+          valueForIndex = keyWithIndexToValue.get(key, index)
+        }
+        if (condition(valueForIndex)) {
+          if (numValues > 1) {
+            // Overwrite the matching slot with the last value, drop the last
+            // slot, and re-check the moved value at the same index next pass.
+            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1)
+            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
+            keyWithIndexToValue.remove(key, numValues - 1)
+            valueForIndex = valueAtMaxIndex
+          } else {
+            // Last remaining value for this key.
+            keyWithIndexToValue.remove(key, 0)
+            valueForIndex = null
+          }
+          numValues -= 1
+          valueRemoved = true
+        } else {
+          valueForIndex = null
+          index += 1
+        }
+      }
+      // Sync the count store only if at least one value was actually deleted.
+      if (valueRemoved) {
+        if (numValues >= 1) {
+          keyToNumValues.put(key, numValues)
+        } else {
+          keyToNumValues.remove(key)
+        }
+      }
+    }
+  }
+
+  /** Iterate over all buffered (key, value) pairs on this side, reusing one pair object. */
+  def iterator(): Iterator[UnsafeRowPair] = {
+    val reusedPair = new UnsafeRowPair()
+    keyWithIndexToValue.iterator.map { kv => reusedPair.withRows(kv.key, kv.value) }
+  }
+
+  /** Commit all the changes to both underlying state stores. */
+  def commit(): Unit = {
+    keyToNumValues.commit()
+    keyWithIndexToValue.commit()
+  }
+
+  /**
+   * Abort any uncommitted changes in both underlying state stores. Registered as a
+   * task-completion listener so stores are released even when the task fails.
+   */
+  def abortIfNeeded(): Unit = {
+    // Fix: the original called keyWithIndexToValue.abortIfNeeded() twice and never
+    // aborted keyToNumValues, leaking that store's uncommitted state on task failure.
+    keyToNumValues.abortIfNeeded()
+    keyWithIndexToValue.abortIfNeeded()
+  }
+
+  /** Combined metrics of both state stores, with descriptions tagged by join side. */
+  def metrics: StateStoreMetrics = {
+    val countStoreMetrics = keyToNumValues.metrics
+    val valueStoreMetrics = keyWithIndexToValue.metrics
+    def tagged(desc: String): String = s"${joinSide.toString.toUpperCase}: $desc"
+
+    StateStoreMetrics(
+      // Each buffered row appears exactly once in the value store, so its key
+      // count represents each buffered row only once.
+      valueStoreMetrics.numKeys,
+      countStoreMetrics.memoryUsedBytes + valueStoreMetrics.memoryUsedBytes,
+      valueStoreMetrics.customMetrics.map {
+        case (m @ StateStoreCustomSizeMetric(_, desc), value) =>
+          m.copy(desc = tagged(desc)) -> value
+        case (m @ StateStoreCustomTimingMetric(_, desc), value) =>
+          m.copy(desc = tagged(desc)) -> value
+      }
+    )
+  }
+
+ /*
+ =====================================================
+ Private methods and inner classes
+ =====================================================
+ */
+
+  // Schema of the join-key rows: one generated field per join-key expression.
+  private val keySchema = StructType(
+    joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
+  private val keyAttributes = keySchema.toAttributes
+  // Store 1: key -> number of values buffered for that key.
+  private val keyToNumValues = new KeyToNumValuesStore()
+  // Store 2: (key, index) -> buffered value row.
+  private val keyWithIndexToValue = new KeyWithIndexToValueStore()
+
+  /** Helper class for invoking common functionalities of a state store. */
+  private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging {
+
+    /** StateStore that subclasses of this class operate on. */
+    protected def stateStore: StateStore
+
+    /** Commit all pending changes in the underlying store. */
+    def commit(): Unit = {
+      stateStore.commit()
+      logDebug("Committed, metrics = " + stateStore.metrics)
+    }
+
+    /** Abort the store's uncommitted changes unless it has already committed. */
+    def abortIfNeeded(): Unit = {
+      if (!stateStore.hasCommitted) {
+        logInfo(s"Aborted store ${stateStore.id}")
+        stateStore.abort()
+      }
+    }
+
+    /** Metrics of the underlying store. */
+    def metrics: StateStoreMetrics = stateStore.metrics
+
+    /** Get the StateStore with the given schema */
+    protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
+      // The store name combines join side and store type, so the two stores of a
+      // partition resolve to distinct providers.
+      val storeProviderId = StateStoreProviderId(
+        stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType))
+      val store = StateStore.get(
+        storeProviderId, keySchema, valueSchema, None,
+        stateInfo.get.storeVersion, storeConf, hadoopConf)
+      logInfo(s"Loaded store ${store.id}")
+      store
+    }
+  }
+
+  /**
+   * Helper class for representing data returned by [[KeyToNumValuesStore]].
+   * Designed for object reuse.
+   */
+  private case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) {
+    /** Repopulate this instance in place and return it, avoiding a fresh allocation. */
+    def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = {
+      this.key = newKey
+      this.numValue = newNumValues
+      this
+    }
+  }
+
+
+  /** A wrapper around a [[StateStore]] that stores [key -> number of values]. */
+  private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) {
+    private val longValueSchema = new StructType().add("value", "long")
+    private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
+    // Single reusable value row; mutated in place on every put().
+    private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema))
+    protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema)
+
+    /** Number of values stored for `key`, or 0 when the key is absent. */
+    def get(key: UnsafeRow): Long =
+      Option(stateStore.get(key)).map(_.getLong(0)).getOrElse(0L)
+
+    /** Record that `key` now has `numValues` values; `numValues` must be positive. */
+    def put(key: UnsafeRow, numValues: Long): Unit = {
+      require(numValues > 0)
+      valueRow.setLong(0, numValues)
+      stateStore.put(key, valueRow)
+    }
+
+    /** Drop the count entry for `key` entirely. */
+    def remove(key: UnsafeRow): Unit = {
+      stateStore.remove(key)
+    }
+
+    /** Iterate over all (key, count) entries, reusing one wrapper object. */
+    def iterator: Iterator[KeyAndNumValues] = {
+      val reused = new KeyAndNumValues()
+      stateStore.getRange(None, None).map { pair =>
+        reused.withNew(pair.key, pair.value.getLong(0))
+      }
+    }
+  }
+
+  /**
+   * Helper class for representing data returned by [[KeyWithIndexToValueStore]].
+   * Designed for object reuse.
+   */
+  private case class KeyWithIndexAndValue(
+      var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) {
+    /** Repopulate this instance in place and return it, avoiding a fresh allocation. */
+    def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = {
+      this.key = newKey
+      this.valueIndex = newIndex
+      this.value = newValue
+      this
+    }
+  }
+
+ /** A wrapper around a [[StateStore]] that stores [(key, index) ->
value]. */
+ private class KeyWithIndexToValueStore extends
StateStoreHandler(KeyWithIndexToValuesType) {
+ private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
+ private val keyWithIndexSchema =
StructType(keySchema.fields).add("index", LongType)
--- End diff --
can't you just call `keySchema.add(..)`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]