Github user brkyvz commented on a diff in the pull request: https://github.com/apache/spark/pull/19271#discussion_r140017131 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala --- @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BindReferences, Expression, LessThanOrEqual, Literal, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.Predicate +import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec} +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ +import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.util.NextIterator + +/** + * Helper class to manage state required by a single side of [[StreamingSymmetricHashJoinExec]]. + * The interface of this class is basically that of a multi-map: + * - Get: Returns an iterator of multiple values for given key + * - Append: Append a new value to the given key + * - Remove Data by predicate: Drop any state using a predicate condition on keys or values + * + * @param joinSide Defines the join side + * @param inputValueAttributes Attributes of the input row which will be stored as value + * @param joinKeys Expressions to generate rows that will be used to key the value rows + * @param stateInfo Information about how to retrieve the correct version of state + * @param storeConf Configuration for the state store. + * @param hadoopConf Hadoop configuration for reading state data from storage + * + * Internally, the key -> multiple values is stored in two [[StateStore]]s. 
+ * - Store 1 ([[KeyToNumValuesStore]]) maintains mapping between key -> number of values + * - Store 2 ([[KeyWithIndexToValueStore]]) maintains mapping between (key, index) -> value + * - Put: update count in KeyToNumValuesStore, + * insert new (key, count) -> value in KeyWithIndexToValueStore + * - Get: read count from KeyToNumValuesStore, + * read each of the n values in KeyWithIndexToValueStore + * - Remove state by predicate on keys: + * scan all keys in KeyToNumValuesStore to find keys that do match the predicate, + * delete the key from KeyToNumValuesStore, delete values in KeyWithIndexToValueStore + * - Remove state by condition on values: + * scan all [(key, index) -> value] in KeyWithIndexToValueStore to find values that match + * the predicate, delete corresponding (key, indexToDelete) from KeyWithIndexToValueStore + * by overwriting with the value of (key, maxIndex), and removing [(key, maxIndex), + * decrement corresponding num values in KeyToNumValuesStore + */ +class SymmetricHashJoinStateManager( + val joinSide: JoinSide, + inputValueAttributes: Seq[Attribute], + joinKeys: Seq[Expression], + stateInfo: Option[StatefulOperatorStateInfo], + storeConf: StateStoreConf, + hadoopConf: Configuration) extends Logging { + + import SymmetricHashJoinStateManager._ + + // Clean up any state store resources if necessary at the end of the task + Option(TaskContext.get()).foreach { _.addTaskCompletionListener { _ => abortIfNeeded() } } + + /* + ===================================================== + Public methods + ===================================================== + */ + + /** Get all the values of a key */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + val numValues = keyToNumValues.get(key) + keyWithIndexToValue.getAll(key, numValues) + } + + /** Append a new value to the key */ + def append(key: UnsafeRow, value: UnsafeRow): Unit = { + val numExistingValues = keyToNumValues.get(key) + keyWithIndexToValue.put(key, numExistingValues, value) + 
keyToNumValues.put(key, numExistingValues + 1) + } + + /** + * Remove using a predicate on keys. See class docs for more context and implement details. + */ + def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { + val allKeyToNumValues = keyToNumValues.iterator + + while (allKeyToNumValues.hasNext) { + val keyToNumValue = allKeyToNumValues.next + if (condition(keyToNumValue.key)) { + keyToNumValues.remove(keyToNumValue.key) + keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + } + } + } + + /** + * Remove using a predicate on values. See class docs for more context and implementation details. + */ + def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { + val allKeyToNumValues = keyToNumValues.iterator + var numValues: Long = 0L + var index: Long = 0L + var valueRemoved = false + var valueForIndex: UnsafeRow = null + + while (allKeyToNumValues.hasNext) { + val keyToNumValue = allKeyToNumValues.next + val key = keyToNumValue.key + + numValues = keyToNumValue.numValue + index = 0L + valueRemoved = false + valueForIndex = null + + while (index < numValues) { + if (valueForIndex == null) { + valueForIndex = keyWithIndexToValue.get(key, index) + } + if (condition(valueForIndex)) { + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1) + keyWithIndexToValue.put(key, index, valueAtMaxIndex) + keyWithIndexToValue.remove(key, numValues - 1) + valueForIndex = valueAtMaxIndex + } else { + keyWithIndexToValue.remove(key, 0) + valueForIndex = null + } + numValues -= 1 + valueRemoved = true + } else { + valueForIndex = null + index += 1 + } + } + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(key, numValues) + } else { + keyToNumValues.remove(key) + } + } + } + } + + def iterator(): Iterator[UnsafeRowPair] = { + val pair = new UnsafeRowPair() + keyWithIndexToValue.iterator.map { x => + pair.withRows(x.key, x.value) + } + } + + /** Commit all the changes to all the 
state stores */ + def commit(): Unit = { + keyToNumValues.commit() + keyWithIndexToValue.commit() + } + + /** Abort any changes to the state stores if needed */ + def abortIfNeeded(): Unit = { + // Abort BOTH underlying stores; the original erroneously aborted keyWithIndexToValue twice + // and left keyToNumValues un-aborted, which could leak its uncommitted state store resources. + keyToNumValues.abortIfNeeded() + keyWithIndexToValue.abortIfNeeded() + } + + /** Get the combined metrics of all the state stores */ + def metrics: StateStoreMetrics = { + val keyToNumValuesMetrics = keyToNumValues.metrics + val keyWithIndexToValueMetrics = keyWithIndexToValue.metrics + def newDesc(desc: String): String = s"${joinSide.toString.toUpperCase}: $desc" + + StateStoreMetrics( + keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once + keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes, + keyWithIndexToValueMetrics.customMetrics.map { + case (s @ StateStoreCustomSizeMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + case (s @ StateStoreCustomTimingMetric(_, desc), value) => + s.copy(desc = newDesc(desc)) -> value + } + ) + } + + /* + ===================================================== + Private methods and inner classes + ===================================================== + */ + + private val keySchema = StructType( + joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) + private val keyAttributes = keySchema.toAttributes + private val keyToNumValues = new KeyToNumValuesStore() + private val keyWithIndexToValue = new KeyWithIndexToValueStore() + + /** Helper trait for invoking common functionalities of a state store.
*/ + private abstract class StateStoreHandler(stateStoreType: StateStoreType) extends Logging { + + /** StateStore that the subclasses of this class are going to operate on */ + protected def stateStore: StateStore + + def commit(): Unit = { + stateStore.commit() + logDebug("Committed, metrics = " + stateStore.metrics) + } + + def abortIfNeeded(): Unit = { + if (!stateStore.hasCommitted) { + logInfo(s"Aborted store ${stateStore.id}") + stateStore.abort() + } + } + + def metrics: StateStoreMetrics = stateStore.metrics + + /** Get the StateStore with the given schema */ + protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { + val storeProviderId = StateStoreProviderId( + stateInfo.get, TaskContext.getPartitionId(), getStateStoreName(joinSide, stateStoreType)) + val store = StateStore.get( + storeProviderId, keySchema, valueSchema, None, + stateInfo.get.storeVersion, storeConf, hadoopConf) + logInfo(s"Loaded store ${store.id}") + store + } + } + + /** + * Helper class for representing data returned by [[KeyToNumValuesStore]]. + * Designed for object reuse. + */ + private case class KeyAndNumValues(var key: UnsafeRow = null, var numValue: Long = 0) { + def withNew(newKey: UnsafeRow, newNumValues: Long): this.type = { + this.key = newKey + this.numValue = newNumValues + this + } + } + + + /** A wrapper around a [[StateStore]] that stores [key -> number of values].
*/ + private class KeyToNumValuesStore extends StateStoreHandler(KeyToNumValuesType) { + private val longValueSchema = new StructType().add("value", "long") + private val longToUnsafeRow = UnsafeProjection.create(longValueSchema) + private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema)) + protected val stateStore: StateStore = getStateStore(keySchema, longValueSchema) + + /** Get the number of values the key has */ + def get(key: UnsafeRow): Long = { + val longValueRow = stateStore.get(key) + if (longValueRow != null) longValueRow.getLong(0) else 0L + } + + /** Set the number of values the key has */ + def put(key: UnsafeRow, numValues: Long): Unit = { + require(numValues > 0) + valueRow.setLong(0, numValues) + stateStore.put(key, valueRow) + } + + def remove(key: UnsafeRow): Unit = { + stateStore.remove(key) + } + + def iterator: Iterator[KeyAndNumValues] = { + val keyAndNumValues = new KeyAndNumValues() + stateStore.getRange(None, None).map { case pair => + keyAndNumValues.withNew(pair.key, pair.value.getLong(0)) + } + } + } + + /** + * Helper class for representing data returned by [[KeyWithIndexToValueStore]]. + * Designed for object reuse. + */ + private case class KeyWithIndexAndValue( + var key: UnsafeRow = null, var valueIndex: Long = -1, var value: UnsafeRow = null) { + def withNew(newKey: UnsafeRow, newIndex: Long, newValue: UnsafeRow): this.type = { + this.key = newKey + this.valueIndex = newIndex + this.value = newValue + this + } + } + + /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. */ + private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) { + private val keyWithIndexExprs = keyAttributes :+ Literal(1L) + private val keyWithIndexSchema = StructType(keySchema.fields).add("index", LongType) --- End diff -- can't you just call `keySchema.add(..)`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org