This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store in repository https://gitbox.apache.org/repos/asf/spark.git
commit 21d4f96c6c7a56acce30d485a0747e1d62d7ad27 Author: Jungtaek Lim <kabhwan.opensou...@gmail.com> AuthorDate: Thu Sep 30 11:53:39 2021 +0900 WIP: benchmark test code done --- .../streaming/FlatMapGroupsWithStateExec.scala | 7 +- .../state/HDFSBackedStateStoreProvider.scala | 42 +- .../sql/execution/streaming/state/RocksDB.scala | 158 ++++- .../streaming/state/RocksDBStateEncoder.scala | 135 ++++- .../state/RocksDBStateStoreProvider.scala | 100 +++- .../sql/execution/streaming/state/StateStore.scala | 40 +- .../execution/streaming/state/StateStoreRDD.scala | 8 +- .../state/StreamingAggregationStateManager.scala | 23 + .../state/SymmetricHashJoinStateManager.scala | 3 +- .../sql/execution/streaming/state/package.scala | 12 +- .../execution/streaming/statefulOperators.scala | 62 +- .../sql/execution/streaming/streamingLimits.scala | 5 +- .../execution/benchmark/StateStoreBenchmark.scala | 633 +++++++++++++++++++++ ...ngSortWithSessionWindowStateIteratorSuite.scala | 7 +- .../streaming/state/MemoryStateStore.scala | 14 + .../state/RocksDBStateStoreIntegrationSuite.scala | 60 +- .../streaming/state/RocksDBStateStoreSuite.scala | 6 +- .../streaming/state/StateStoreRDDSuite.scala | 18 +- .../streaming/state/StateStoreSuite.scala | 31 +- .../StreamingSessionWindowStateManagerSuite.scala | 4 +- .../apache/spark/sql/streaming/StreamSuite.scala | 4 +- .../sql/streaming/StreamingAggregationSuite.scala | 34 +- 22 files changed, 1270 insertions(+), 136 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index a00a622..381aeb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -224,21 +224,23 @@ case class FlatMapGroupsWithStateExec( val stateStoreId = StateStoreId( stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId) val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId) + // FIXME: would setting prefixScan / evict help? val store = StateStore.get( storeProviderId, groupingAttributes.toStructType, stateManager.stateSchema, - numColsPrefixKey = 0, + StatefulOperatorContext(), stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value) val processor = new InputProcessor(store) processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator)) } } else { + // FIXME: would setting prefixScan / evict help? child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, stateManager.stateSchema, - numColsPrefixKey = 0, + StatefulOperatorContext(), session.sqlContext.sessionState, Some(session.sqlContext.streams.stateStoreCoordinator) ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => @@ -334,6 +336,7 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } + // FIXME: would setting prefixScan / evict help? val timingOutPairs = stateManager.getAllState(store).filter { state => state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 75b7dae..96ba2a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -100,8 +100,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with /** Trait and classes representing the internal state of the store */ trait STATE + case object UPDATING extends STATE + case object COMMITTED extends STATE + case object ABORTED extends STATE private val newVersion = version + 1 @@ -195,6 +198,22 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def toString(): String = { s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]" } + + /** FIXME: method doc */ + override def evictOnWatermark( + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = { + // HDFSBackedStateStore doesn't index event time column + // FIXME: should we do this for in-memory as well? + iterator().filter { pair => + if (altPred.apply(pair)) { + remove(pair.key) + true + } else { + false + } + } + } } def getMetricsForProvider(): Map[String, Long] = synchronized { @@ -219,7 +238,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized { require(version >= 0, "Version cannot be less than 0") - val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) + val newMap = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey) if (version > 0) { newMap.putAll(loadMap(version)) } @@ -230,7 +249,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with stateStoreId: StateStoreId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { this.stateStoreId_ = stateStoreId @@ -240,10 +259,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with this.hadoopConf = hadoopConf this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory - require((keySchema.length == 0 && numColsPrefixKey == 0) || - (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " + - "greater than the number of columns for prefix key!") - this.numColsPrefixKey = numColsPrefixKey + require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) || + (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " + + "must be greater than the number of columns for prefix key!") + + this.operatorContext = operatorContext fm.mkdirs(baseDir) } @@ -283,7 +303,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ @volatile private var numberOfVersionsToRetainInMemory: Int = _ - @volatile private var numColsPrefixKey: Int = 0 + @volatile private var operatorContext: StatefulOperatorContext = _ // TODO: The validation should be moved to a higher level so that it works for all state store // implementations @@ -401,7 +421,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with if (lastAvailableVersion <= 0) { // Use an empty map for versions 0 or less. - lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)) + lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, + operatorContext.numColsPrefixKey)) } else { lastAvailableMap = synchronized { Option(loadedMaps.get(lastAvailableVersion)) } @@ -411,7 +432,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with // Load all the deltas from the version after the last available one up to the target version. // The last available version is the one with a full snapshot, so it doesn't need deltas. - val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) + val resultMap = HDFSBackedStateStoreMap.create(keySchema, + operatorContext.numColsPrefixKey) resultMap.putAll(lastAvailableMap.get) for (deltaVersion <- lastAvailableVersion + 1 to version) { updateFromDeltaFile(deltaVersion, resultMap) @@ -554,7 +576,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = { val fileToRead = snapshotFile(version) - val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey) + val map = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey) var input: DataInputStream = null try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 1ff8b41..eed7827 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.File +import java.util import java.util.Locale import javax.annotation.concurrent.GuardedBy @@ -50,9 +51,12 @@ import org.apache.spark.util.{NextIterator, Utils} * @param hadoopConf Hadoop configuration for talking to the remote file system * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs */ +// FIXME: optionally receiving column families class RocksDB( dfsRootDir: String, val conf: RocksDBConf, + // TODO: change "default" to constant + columnFamilies: Seq[String] = Seq("default"), localRootDir: File = Utils.createTempDir(), hadoopConf: Configuration = new Configuration, loggingId: String = "") extends Logging { @@ -65,16 +69,10 @@ class RocksDB( private val flushOptions = new FlushOptions().setWaitForFlush(true) // wait for flush to complete private val writeBatch = new WriteBatchWithIndex(true) // overwrite multiple updates to a key - private val bloomFilter = new BloomFilter() - private val tableFormatConfig = new BlockBasedTableConfig() - tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024) - tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024)) - tableFormatConfig.setFilterPolicy(bloomFilter) - tableFormatConfig.setFormatVersion(conf.formatVersion) - - private val dbOptions = new Options() // options to open the RocksDB + private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB dbOptions.setCreateIfMissing(true) - dbOptions.setTableFormatConfig(tableFormatConfig) + dbOptions.setCreateMissingColumnFamilies(true) + private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j dbOptions.setStatistics(new Statistics()) private val nativeStats = dbOptions.statistics() @@ -87,6 +85,8 @@ class RocksDB( private val acquireLock = new Object @volatile private var db: NativeRocksDB = _ + @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _ + @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _ @volatile private var loadedVersion = -1L // -1 = nothing valid is loaded @volatile private var numKeysOnLoadedVersion = 0L @volatile private var numKeysOnWritingVersion = 0L @@ -96,7 +96,7 @@ class RocksDB( @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _ private val prefixScanReuseIter = - new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]() + new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]() /** * Load the given version of data in a native RocksDB instance. @@ -137,7 +137,28 @@ class RocksDB( * @note This will return the last written value even if it was uncommitted. */ def get(key: Array[Byte]): Array[Byte] = { - writeBatch.getFromBatchAndDB(db, readOptions, key) + get(defaultColumnFamilyHandle, key) + } + + // FIXME: method doc + def get(cf: String, key: Array[Byte]): Array[Byte] = { + get(findColumnFamilyHandle(cf), key) + } + + private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = { + writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key) + } + + def merge(key: Array[Byte], value: Array[Byte]): Unit = { + merge(defaultColumnFamilyHandle, key, value) + } + + def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = { + merge(findColumnFamilyHandle(cf), key, value) + } + + private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = { + writeBatch.merge(cfHandle, key, value) } /** @@ -145,8 +166,20 @@ class RocksDB( * @note This update is not committed to disk until commit() is called. */ def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = { - val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key) - writeBatch.put(key, value) + put(defaultColumnFamilyHandle, key, value) + } + + // FIXME: method doc + def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = { + put(findColumnFamilyHandle(cf), key, value) + } + + private def put( + cfHandle: ColumnFamilyHandle, + key: Array[Byte], + value: Array[Byte]): Array[Byte] = { + val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key) + writeBatch.put(cfHandle, key, value) if (oldValue == null) { numKeysOnWritingVersion += 1 } @@ -158,9 +191,18 @@ class RocksDB( * @note This update is not committed to disk until commit() is called. */ def remove(key: Array[Byte]): Array[Byte] = { - val value = writeBatch.getFromBatchAndDB(db, readOptions, key) + remove(defaultColumnFamilyHandle, key) + } + + // FIXME: method doc + def remove(cf: String, key: Array[Byte]): Array[Byte] = { + remove(findColumnFamilyHandle(cf), key) + } + + private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = { + val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key) if (value != null) { - writeBatch.remove(key) + writeBatch.delete(cfHandle, key) numKeysOnWritingVersion -= 1 } value @@ -169,8 +211,17 @@ class RocksDB( /** * Get an iterator of all committed and uncommitted key-value pairs. */ - def iterator(): Iterator[ByteArrayPair] = { - val iter = writeBatch.newIteratorWithBase(db.newIterator()) + def iterator(): NextIterator[ByteArrayPair] = { + iterator(defaultColumnFamilyHandle) + } + + // FIXME: doc + def iterator(cf: String): NextIterator[ByteArrayPair] = { + iterator(findColumnFamilyHandle(cf)) + } + + private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = { + val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle)) logInfo(s"Getting iterator from version $loadedVersion") iter.seekToFirst() @@ -197,11 +248,20 @@ class RocksDB( } def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = { + prefixScan(defaultColumnFamilyHandle, prefix) + } + + def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = { + prefixScan(findColumnFamilyHandle(cf), prefix) + } + + private def prefixScan( + cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = { val threadId = Thread.currentThread().getId - val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => { - val it = writeBatch.newIteratorWithBase(db.newIterator()) + val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => { + val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle)) logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " + - s"thread ID $tid") + s"thread ID ${key._1} and column family ID ${key._2}") it }) @@ -223,6 +283,14 @@ class RocksDB( } } + private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = { + val cfHandle = columnFamilyHandles.get(cf) + if (cfHandle == null) { + throw new IllegalArgumentException(s"Handle for column family $cf is not found") + } + cfHandle + } + /** * Commit all the updates made as a version to DFS. The steps it needs to do to commits are: * - Write all the updates to the native RocksDB @@ -242,11 +310,16 @@ class RocksDB( val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) } logInfo(s"Flushing updates for $newVersion") - val flushTimeMs = timeTakenMs { db.flush(flushOptions) } + val flushTimeMs = timeTakenMs { + db.flush(flushOptions, + new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values())) + } val compactTimeMs = if (conf.compactOnCommit) { logInfo("Compacting") - timeTakenMs { db.compactRange() } + timeTakenMs { + columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle)) + } } else 0 logInfo("Pausing background work") @@ -279,6 +352,7 @@ class RocksDB( loadedVersion } catch { case t: Throwable => + logWarning(s"ERROR! exc: $t", t) loadedVersion = -1 // invalidate loaded version throw t } finally { @@ -422,12 +496,43 @@ class RocksDB( private def openDB(): Unit = { assert(db == null) - db = NativeRocksDB.open(dbOptions, workingDir.toString) + + val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]() + columnFamilies.foreach { cf => + val bloomFilter = new BloomFilter() + val tableFormatConfig = new BlockBasedTableConfig() + tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024) + tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024)) + tableFormatConfig.setFilterPolicy(bloomFilter) + tableFormatConfig.setFormatVersion(conf.formatVersion) + + val columnFamilyOptions = new ColumnFamilyOptions() + columnFamilyOptions.setTableFormatConfig(tableFormatConfig) + columnFamilyDescriptors.add(new ColumnFamilyDescriptor(cf.getBytes(), columnFamilyOptions)) + } + + val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size()) + db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors, + cfHandles) + + columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]() + columnFamilies.indices.foreach { idx => + columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx)) + } + + // FIXME: constant + defaultColumnFamilyHandle = columnFamilyHandles.get("default") + logInfo(s"Opened DB with conf ${conf}") } private def closeDB(): Unit = { if (db != null) { + columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue)) + columnFamilyHandles.clear() + columnFamilyHandles = null + defaultColumnFamilyHandle = null + db.close() db = null } @@ -441,10 +546,17 @@ class RocksDB( // Warn is mapped to info because RocksDB warn is too verbose // (e.g. dumps non-warning stuff like stats) val loggingFunc: ( => String) => Unit = infoLogLevel match { + /* case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_) case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_) case InfoLogLevel.DEBUG_LEVEL => logDebug(_) case _ => logTrace(_) + */ + case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_) + case InfoLogLevel.WARN_LEVEL => logWarning(_) + case InfoLogLevel.INFO_LEVEL => logInfo(_) + case InfoLogLevel.DEBUG_LEVEL => logDebug(_) + case _ => logTrace(_) } loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala index 81755e5..323826d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution.streaming.state +import java.lang.{Long => JLong} +import java.nio.ByteOrder + import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION} -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{StructField, StructType, TimestampType} import org.apache.spark.unsafe.Platform sealed trait RocksDBStateEncoder { @@ -27,6 +30,11 @@ sealed trait RocksDBStateEncoder { def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] def extractPrefixKey(key: UnsafeRow): UnsafeRow + def supportEventTimeIndex: Boolean + def extractEventTime(key: UnsafeRow): Long + def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] + def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) + def encodeKey(row: UnsafeRow): Array[Byte] def encodeValue(row: UnsafeRow): Array[Byte] @@ -39,11 +47,13 @@ object RocksDBStateEncoder { def getEncoder( keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int): RocksDBStateEncoder = { + numColsPrefixKey: Int, + eventTimeColIdx: Array[Int]): RocksDBStateEncoder = { if (numColsPrefixKey > 0) { + // FIXME: need to deal with prefix case as well new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey) } else { - new NoPrefixKeyStateEncoder(keySchema, valueSchema) + new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx) } } @@ -86,6 +96,39 @@ object RocksDBStateEncoder { } } +object BinarySortable { + private val NATIVE_BYTE_ORDER: ByteOrder = ByteOrder.nativeOrder() + private val LITTLE_ENDIAN: Boolean = NATIVE_BYTE_ORDER == ByteOrder.LITTLE_ENDIAN + private val SIGN_BIT_LONG: Long = (1L << 63) + + def encodeToBinarySortableLong(value: Long): Long = { + // Flip the sign bit. This simply works with binary form of comparison, as negative values + // are placed in reversed order, and positive values are placed in sequential order. + val encoded = value ^ SIGN_BIT_LONG + + // We have to retain the sequence of bytes as same as BIG_ENDIAN, as the binary form will be + // compared in sequential order (via offset). + if (LITTLE_ENDIAN) { + JLong.reverseBytes(encoded) + } else { + encoded + } + } + + def decodeBinarySortableLong(encoded: Long): Long = { + // The value is based on BIG_ENDIAN. If the system is LITTLE_ENDIAN, we should convert it to + // follow the system. + val decoded = if (LITTLE_ENDIAN) { + JLong.reverseBytes(encoded) + } else { + encoded + } + + // Flip the sign bit as encode function does. + decoded ^ SIGN_BIT_LONG + } +} + class PrefixKeyScanStateEncoder( keySchema: StructType, valueSchema: StructType, @@ -185,6 +228,23 @@ class PrefixKeyScanStateEncoder( } override def supportPrefixKeyScan: Boolean = true + + override def supportEventTimeIndex: Boolean = false + + // FIXME: fix me! + def extractEventTime(key: UnsafeRow): Long = { + throw new IllegalStateException("This encoder doesn't support event time index!") + } + + // FIXME: fix me! + def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = { + throw new IllegalStateException("This encoder doesn't support event time index!") + } + + // FIXME: fix me! + def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = { + throw new IllegalStateException("This encoder doesn't support event time index!") + } } /** @@ -197,8 +257,10 @@ class PrefixKeyScanStateEncoder( * (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes, * then the generated array byte will be N+1 bytes. */ -class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType) - extends RocksDBStateEncoder { +class NoPrefixKeyStateEncoder( + keySchema: StructType, + valueSchema: StructType, + eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder { import RocksDBStateEncoder._ @@ -207,6 +269,32 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType) private val valueRow = new UnsafeRow(valueSchema.size) private val rowTuple = new UnsafeRowPair() + validateColumnTypeOnEventTimeColumn() + + private def validateColumnTypeOnEventTimeColumn(): Unit = { + if (eventTimeColIdx.nonEmpty) { + var curSchema: StructType = keySchema + eventTimeColIdx.dropRight(1).foreach { idx => + curSchema(idx).dataType match { + case stType: StructType => + curSchema = stType + case _ => + // FIXME: better error message + throw new IllegalStateException("event time column is not properly specified! " + + s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema") + } + } + + curSchema(eventTimeColIdx.last).dataType match { + case _: TimestampType => + case _ => + // FIXME: better error message + throw new IllegalStateException("event time column is not properly specified! " + + s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema") + } + } + } + override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row) @@ -249,4 +337,41 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType) override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = { throw new IllegalStateException("This encoder doesn't support prefix key!") } + + override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty + + override def extractEventTime(key: UnsafeRow): Long = { + var curRow: UnsafeRow = key + var curSchema: StructType = keySchema + + eventTimeColIdx.dropRight(1).foreach { idx => + // validation is done in initialization phase + curSchema = curSchema(idx).dataType.asInstanceOf[StructType] + curRow = curRow.getStruct(idx, curSchema.length) + } + + curRow.getLong(eventTimeColIdx.last) / 1000 + } + + override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = { + val newKey = new Array[Byte](8 + encodedKey.length) + + Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET, + BinarySortable.encodeToBinarySortableLong(timestamp)) + Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey, + Platform.BYTE_ARRAY_OFFSET + 8, encodedKey.length) + + newKey + } + + override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = { + val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET) + val timestamp = BinarySortable.decodeBinarySortableLong(encoded) + + val key = new Array[Byte](eventTimeBytes.length - 8) + Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + 8, + key, Platform.BYTE_ARRAY_OFFSET, eventTimeBytes.length - 8) + + (timestamp, key) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index a2b33c2..1d66220 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -25,7 +25,8 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.unsafe.Platform +import org.apache.spark.util.{NextIterator, Utils} private[sql] class RocksDBStateStoreProvider extends StateStoreProvider with Logging with Closeable { @@ -60,12 +61,23 @@ private[sql] class RocksDBStateStoreProvider verify(state == UPDATING, "Cannot put after already committed or aborted") verify(key != null, "Key cannot be null") require(value != null, "Cannot put a null value") - rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value)) + + val encodedKey = encoder.encodeKey(key) + val encodedValue = encoder.encodeValue(value) + rocksDB.put(encodedKey, encodedValue) + + if (encoder.supportEventTimeIndex) { + val timestamp = encoder.extractEventTime(key) + val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey) + + rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty) + } } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") verify(key != null, "Key cannot be null") + // FIXME: this should reflect the index rocksDB.remove(encoder.encodeKey(key)) } @@ -161,13 +173,75 @@ private[sql] class RocksDBStateStoreProvider /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */ def dbInstance(): RocksDB = rocksDB + + /** FIXME: method doc */ + override def evictOnWatermark( + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = { + + if (encoder.supportEventTimeIndex) { + val kv = new ByteArrayPair() + + // FIXME: DEBUG + // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs") + + new NextIterator[UnsafeRowPair] { + private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX) + override protected def getNext(): UnsafeRowPair = { + if (iter.hasNext) { + val pair = iter.next() + + val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET) + val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs) + + // FIXME: DEBUG + // logWarning(s"DEBUG: decoded TS: $decodedTs") + + if (decodedTs > watermarkMs) { + finished = true + null + } else { + // FIXME: can we leverage deleteRange to bulk delete on index? + rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key) + val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key) + val value = rocksDB.get(encodedKey) + if (value == null) { + throw new IllegalStateException("Event time index has been broken!") + } + kv.set(encodedKey, value) + val rowPair = encoder.decode(kv) + rocksDB.remove(encodedKey) + rowPair + } + } else { + finished = true + null + } + } + + override protected def close(): Unit = { + iter.closeIfNeeded() + } + } + } else { + rocksDB.iterator().flatMap { kv => + val rowPair = encoder.decode(kv) + if (altPred(rowPair)) { + rocksDB.remove(kv.key) + Some(rowPair) + } else { + None + } + } + } + } } override def init( stateStoreId: StateStoreId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConf: StateStoreConf, hadoopConf: Configuration): Unit = { this.stateStoreId_ = stateStoreId @@ -176,11 +250,14 @@ private[sql] class RocksDBStateStoreProvider this.storeConf = storeConf this.hadoopConf = hadoopConf - require((keySchema.length == 0 && numColsPrefixKey == 0) || - (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " + - "greater than the number of columns for prefix key!") + require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) || + (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " + + "must be greater than the number of columns for prefix key!") - this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey) + this.operatorContext = operatorContext + + this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, + operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx) rocksDB // lazy initialization } @@ -212,6 +289,7 @@ private[sql] class RocksDBStateStoreProvider @volatile private var valueSchema: StructType = _ @volatile private var storeConf: StateStoreConf = _ @volatile private var hadoopConf: Configuration = _ + @volatile private var operatorContext: StatefulOperatorContext = _ private[sql] lazy val rocksDB = { val dfsRootDir = stateStoreId.storeCheckpointLocation().toString @@ -219,7 +297,10 @@ private[sql] class RocksDBStateStoreProvider s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})" val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr) - new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr) + new RocksDB(dfsRootDir, RocksDBConf(storeConf), + columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX), + localRootDir = localRootDir, + hadoopConf = hadoopConf, loggingId = storeIdStr) } @volatile private var encoder: RocksDBStateEncoder = _ @@ -234,6 +315,9 @@ object RocksDBStateStoreProvider { val STATE_ENCODING_NUM_VERSION_BYTES = 1 val STATE_ENCODING_VERSION: Byte = 0 + // reserved column families + val CF_EVENT_TIME_INDEX: String = "__event_time_idx" + // Native operation latencies report as latency in microseconds // as SQLMetrics support millis. Convert the value to millis val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 5020638..cee9ad1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -130,6 +130,11 @@ trait StateStore extends ReadStateStore { */ override def iterator(): Iterator[UnsafeRowPair] + /** FIXME: method doc */ + def evictOnWatermark( + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] + /** Current metrics of the state store */ def metrics: StateStoreMetrics @@ -229,6 +234,19 @@ class InvalidUnsafeRowException "checkpoint or use the legacy Spark version to process the streaming state.", null) /** + * FIXME: classdoc + * + * @param numColsPrefixKey The number of leftmost columns to be used as prefix key. + * A value not greater than 0 means the operator doesn't activate prefix + * key, and the operator should not call prefixScan method in StateStore. + * @param eventTimeColIdx column specifying event time for the row. only works when the column + * is in the key. array type as the column can be struct type. + */ +case class StatefulOperatorContext( + numColsPrefixKey: Int = 0, + eventTimeColIdx: Array[Int] = Array.empty) + +/** * Trait representing a provider that provide [[StateStore]] instances representing * versions of state data. * @@ -255,9 +273,7 @@ trait StateStoreProvider { * @param stateStoreId Id of the versioned StateStores that this provider will generate * @param keySchema Schema of keys to be stored * @param valueSchema Schema of value to be stored - * @param numColsPrefixKey The number of leftmost columns to be used as prefix key. - * A value not greater than 0 means the operator doesn't activate prefix - * key, and the operator should not call prefixScan method in StateStore. + * @param operatorContext FIXME: ... * @param storeConfs Configurations used by the StateStores * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data */ @@ -265,7 +281,7 @@ trait StateStoreProvider { stateStoreId: StateStoreId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConfs: StateStoreConf, hadoopConf: Configuration): Unit @@ -318,11 +334,11 @@ object StateStoreProvider { providerId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { val provider = create(storeConf.providerClass) - provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey, + provider.init(providerId.storeId, keySchema, valueSchema, operatorContext, storeConf, hadoopConf) provider } @@ -471,13 +487,13 @@ object StateStore extends Logging { storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, version: Long, storeConf: StateStoreConf, hadoopConf: Configuration): ReadStateStore = { require(version >= 0) val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, - numColsPrefixKey, storeConf, hadoopConf) + operatorContext, storeConf, hadoopConf) storeProvider.getReadStore(version) } @@ -486,13 +502,13 @@ object StateStore extends Logging { storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, version: Long, storeConf: StateStoreConf, hadoopConf: Configuration): StateStore = { require(version >= 0) val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema, - numColsPrefixKey, storeConf, hadoopConf) + operatorContext, storeConf, hadoopConf) storeProvider.getStore(version) } @@ -500,7 +516,7 @@ object StateStore extends Logging { storeProviderId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { loadedProviders.synchronized { @@ -527,7 +543,7 @@ object StateStore extends Logging { val provider = loadedProviders.getOrElseUpdate( storeProviderId, StateStoreProvider.createAndInit( - storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeConf, hadoopConf) + storeProviderId, keySchema, valueSchema, operatorContext, storeConf, hadoopConf) ) val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index fbe83ad..33d83b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -74,7 +74,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( storeVersion: Long, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, sessionState: SessionState, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef], extraOptions: Map[String, String] = Map.empty) @@ -87,7 +87,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag]( val storeProviderId = getStateProviderId(partition) val store = StateStore.getReadOnly( - storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion, + storeProviderId, keySchema, valueSchema, operatorContext, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeReadFunction(store, inputIter) @@ -108,7 +108,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( storeVersion: Long, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, sessionState: SessionState, @transient private val storeCoordinator: Option[StateStoreCoordinatorRef], extraOptions: Map[String, String] = Map.empty) @@ -121,7 +121,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( val storeProviderId = getStateProviderId(partition) val store = StateStore.get( - storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion, + storeProviderId, keySchema, valueSchema, operatorContext, storeVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala index 36138f1..c6b63cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -51,6 +51,12 @@ sealed trait StreamingAggregationStateManager extends Serializable { /** Remove a single non-null key from the target state store. */ def remove(store: StateStore, key: UnsafeRow): Unit + // FIXME: method doc! + def evictOnWatermark( + store: StateStore, + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] + /** Return an iterator containing all the key-value pairs in target state store. */ def iterator(store: ReadStateStore): Iterator[UnsafeRowPair] @@ -128,6 +134,13 @@ class StreamingAggregationStateManagerImplV1( override def values(store: ReadStateStore): Iterator[UnsafeRow] = { store.iterator().map(_.value) } + + override def evictOnWatermark( + store: StateStore, + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = { + store.evictOnWatermark(watermarkMs, altPred) + } } /** @@ -186,6 +199,16 @@ class StreamingAggregationStateManagerImplV2( store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) } + + override def evictOnWatermark( + store: StateStore, + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = { + store.evictOnWatermark(watermarkMs, altPred).map { rowPair => + new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)) + } + } + override def values(store: ReadStateStore): Iterator[UnsafeRow] = { store.iterator().map(rowPair => restoreOriginalRow(rowPair)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index f301d23..ce56845 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -363,8 +363,9 @@ class SymmetricHashJoinStateManager( protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = { val storeProviderId = StateStoreProviderId( stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType)) + // FIXME: would setting prefixScan / evict help? val store = StateStore.get( - storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0, + storeProviderId, keySchema, valueSchema, StatefulOperatorContext(), stateInfo.get.storeVersion, storeConf, hadoopConf) logInfo(s"Loaded store ${store.id}") store diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72b..7cd25eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -35,14 +35,14 @@ package object state { stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int)( + operatorContext: StatefulOperatorContext)( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { mapPartitionsWithStateStore( stateInfo, keySchema, valueSchema, - numColsPrefixKey, + operatorContext, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator))( storeUpdateFunction) @@ -53,7 +53,7 @@ package object state { stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef], extraOptions: Map[String, String] = Map.empty)( @@ -77,7 +77,7 @@ package object state { stateInfo.storeVersion, keySchema, valueSchema, - numColsPrefixKey, + operatorContext, sessionState, storeCoordinator, extraOptions) @@ -88,7 +88,7 @@ package object state { stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef], extraOptions: Map[String, String] = Map.empty)( @@ -112,7 +112,7 @@ package object state { stateInfo.storeVersion, keySchema, valueSchema, - numColsPrefixKey, + operatorContext, sessionState, storeCoordinator, extraOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3431823..923d94e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -237,11 +237,10 @@ trait WatermarkSupport extends SparkPlan { protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { if (watermarkPredicateForKeys.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") - store.iterator().foreach { rowPair => - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) - numRemovedStateRows += 1 - } + store.evictOnWatermark(eventTimeWatermark.get, pair => { + watermarkPredicateForKeys.get.eval(pair.key) + }).foreach { _ => + numRemovedStateRows += 1 } } } @@ -251,11 +250,10 @@ trait WatermarkSupport extends SparkPlan { store: StateStore): Unit = { if (watermarkPredicateForKeys.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") - storeManager.keys(store).foreach { keyRow => - if (watermarkPredicateForKeys.get.eval(keyRow)) { - storeManager.remove(store, keyRow) - numRemovedStateRows += 1 - } + storeManager.evictOnWatermark(store, + eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key) + ).foreach { _ => + numRemovedStateRows += 1 } } } @@ -307,7 +305,8 @@ case class StateStoreRestoreExec( getStateInfo, keyExpressions.toStructType, stateManager.getStateValueSchema, - numColsPrefixKey = 0, + // FIXME: set event time column here! + StatefulOperatorContext(), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => val hasInput = iter.hasNext @@ -365,11 +364,26 @@ case class StateStoreSaveExec( assert(outputMode.nonEmpty, "Incorrect planning in IncrementalExecution, outputMode has not been set") + val eventTimeIdx = keyExpressions.indexWhere(_.metadata.contains(EventTimeWatermark.delayKey)) + val eventTimeColIdx = if (eventTimeIdx >= 0) { + keyExpressions.toStructType(eventTimeIdx).dataType match { + // FIXME: for now, we only consider window operation here, as we do the same in + // WatermarkSupport.watermarkExpression + case StructType(_) => Array[Int](eventTimeIdx, 1) + case TimestampType => Array[Int](eventTimeIdx) + case _ => throw new IllegalStateException( + "The type of event time column should be timestamp") + } + } else { + Array.empty[Int] + } + child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, stateManager.getStateValueSchema, - numColsPrefixKey = 0, + // FIXME: set event time column here! + StatefulOperatorContext(eventTimeColIdx = eventTimeColIdx), session.sessionState, Some(session.streams.stateStoreCoordinator)) { (store, iter) => val numOutputRows = longMetric("numOutputRows") @@ -414,18 +428,16 @@ case class StateStoreSaveExec( } val removalStartTimeNs = System.nanoTime - val rangeIter = stateManager.iterator(store) + val evictedIter = stateManager.evictOnWatermark(store, + eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key)) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { var removedValueRow: InternalRow = null - while(rangeIter.hasNext && removedValueRow == null) { - val rowPair = rangeIter.next() - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - stateManager.remove(store, rowPair.key) - numRemovedStateRows += 1 - removedValueRow = rowPair.value - } + while(evictedIter.hasNext && removedValueRow == null) { + val rowPair = evictedIter.next() + numRemovedStateRows += 1 + removedValueRow = rowPair.value } if (removedValueRow == null) { finished = true @@ -541,7 +553,8 @@ case class SessionWindowStateStoreRestoreExec( getStateInfo, stateManager.getStateKeySchema, stateManager.getStateValueSchema, - numColsPrefixKey = stateManager.getNumColsForPrefixKey, + // FIXME: set event time column here! + StatefulOperatorContext(stateManager.getNumColsForPrefixKey), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => @@ -618,7 +631,8 @@ case class SessionWindowStateStoreSaveExec( getStateInfo, stateManager.getStateKeySchema, stateManager.getStateValueSchema, - numColsPrefixKey = stateManager.getNumColsForPrefixKey, + // FIXME: set event time column! + StatefulOperatorContext(numColsPrefixKey = stateManager.getNumColsForPrefixKey), session.sessionState, Some(session.streams.stateStoreCoordinator)) { case (store, iter) => @@ -652,6 +666,7 @@ case class SessionWindowStateStoreSaveExec( val removalStartTimeNs = System.nanoTime new NextIterator[InternalRow] { + // FIXME: can we optimize this case as well? private val removedIter = stateManager.removeByValueCondition( store, watermarkPredicateForData.get.eval) @@ -751,7 +766,8 @@ case class StreamingDeduplicateExec( getStateInfo, keyExpressions.toStructType, child.output.toStructType, - numColsPrefixKey = 0, + // FIXME: set event time column! + StatefulOperatorContext(), session.sessionState, Some(session.streams.stateStoreCoordinator), // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala index 8bba9b8..ba15f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning} import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.streaming.state.StateStoreOps +import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStoreOps} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, NextIterator} @@ -52,7 +52,8 @@ case class StreamingGlobalLimitExec( getStateInfo, keySchema, valueSchema, - numColsPrefixKey = 0, + // FIXME: set event time column! + StatefulOperatorContext(), session.sessionState, Some(session.streams.stateStoreCoordinator)) { (store, iter) => val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala new file mode 100644 index 0000000..9a83f5c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala @@ -0,0 +1,633 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.{util => jutil} + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.rocksdb.RocksDBException + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType} +import org.apache.spark.util.Utils + +/** + * Synthetic benchmark for State Store operations. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class <this class> + * --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar> + * 2. build/sbt "sql/test:runMain <this class>" + * 3. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>" + * Results will be written to "benchmarks/StateStoreBenchmark-results.txt". + * }}} + */ +object StateStoreBenchmark extends SqlBasedBenchmark { + + private val numOfRows: Seq[Int] = Seq(10000, 50000, 100000) // Seq(10000, 100000, 1000000) + + // 200%, 100%, 50%, 25%, 10%, 5%, 1%, no update + // rate is relative to the number of rows in prev. batch + private val updateRates: Seq[Int] = Seq(25, 10, 5) // Seq(200, 100, 50, 25, 10, 5, 1, 0) + + // 100%, 50%, 25%, 10%, 5%, 1%, no evict + // rate is relative to the number of rows in prev. batch + private val evictRates: Seq[Int] = Seq(100, 50, 25, 10, 5, 1, 0) + + private val keySchema = StructType( + Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + private val keyProjection = UnsafeProjection.create(keySchema) + private val valueProjection = UnsafeProjection.create(valueSchema) + + private def runEvictBenchmark(): Unit = { + runBenchmark("evict rows") { + val numOfRows = Seq(10000) // Seq(1000, 10000, 100000) + val numOfTimestamps = Seq(10, 100, 1000) + val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0) + + numOfRows.foreach { numOfRow => + numOfTimestamps.foreach { numOfTimestamp => + val timestampsInMicros = (0L until numOfTimestamp).map(ts => ts * 1000L).toList + + val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0) + + val rocksDBProvider = newRocksDBStateProviderWithEventTimeIdx() + val rocksDBStore = rocksDBProvider.getStore(0) + updateRows(rocksDBStore, testData) + + val committedVersion = try { + rocksDBStore.commit() + } catch { + case exc: RocksDBException => + // scalastyle:off println + System.out.println(s"Exception in RocksDB happen! ${exc.getMessage} / " + + s"status: ${exc.getStatus.getState} / ${exc.getStatus.getCodeString}" ) + exc.printStackTrace() + throw exc + // scalastyle:on println + } + + numOfEvictionRates.foreach { numOfEvictionRate => + val numOfRowsToEvict = numOfRow * numOfEvictionRate / 100 + // scalastyle:off println + System.out.println(s"numOfRowsToEvict: $numOfRowsToEvict / " + + s"timestampsInMicros: $timestampsInMicros / " + + s"numOfEvictionRate: $numOfEvictionRate / " + + s"numOfTimestamp: $numOfTimestamp / " + + s"take: ${numOfTimestamp * numOfEvictionRate / 100}") + + // scalastyle:on println + val maxTimestampToEvictInMillis = timestampsInMicros + .take(numOfTimestamp * numOfEvictionRate / 100) + .lastOption.map(_ / 1000).getOrElse(-1L) + + val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " + + s"(max timestamp to evict in millis: $maxTimestampToEvictInMillis) " + + s"from $numOfRow rows with $numOfTimestamp timestamps " + + s"(${numOfRow / numOfTimestamp} rows" + + s" for the same timestamp)", + numOfRow, minNumIters = 1000, output = output) + + benchmark.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore = rocksDBProvider.getStore(committedVersion) + + timer.startTiming() + evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis) + timer.stopTiming() + + rocksDBStore.abort() + } + + benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer => + val rocksDBStore = rocksDBProvider.getStore(committedVersion) + + timer.startTiming() + evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis) + timer.stopTiming() + + rocksDBStore.abort() + } + + benchmark.run() + } + + rocksDBProvider.close() + } + } + } + } + + private def runPutBenchmark(): Unit = { + runBenchmark("put rows") { + val numOfRows = Seq(10000) // Seq(1000, 10000, 100000) + val numOfTimestamps = Seq(100, 1000, 10000) // Seq(1, 10, 100, 1000, 10000) + numOfRows.foreach { numOfRow => + numOfTimestamps.foreach { numOfTimestamp => + val timestamps = (0L until numOfTimestamp).map(ts => ts * 1000L).toList + + val testData = constructRandomizedTestData(numOfRow, timestamps, 0) + + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx() + + val benchmark = new Benchmark(s"putting $numOfRow rows, with $numOfTimestamp " + + s"timestamps (${numOfRow / numOfTimestamp} rows for the same timestamp)", + numOfRow, minNumIters = 1000, output = output) + + benchmark.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore = rocksDBProvider.getStore(0) + + timer.startTiming() + updateRows(rocksDBStore, testData) + timer.stopTiming() + + rocksDBStore.abort() + } + + benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer => + val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0) + + timer.startTiming() + updateRows(rocksDBWithIdxStore, testData) + timer.stopTiming() + + rocksDBWithIdxStore.abort() + } + + benchmark.run() + + rocksDBProvider.close() + rocksDBWithIdxProvider.close() + } + } + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + // runPutBenchmark() + runEvictBenchmark() + + /* + val testData = constructRandomizedTestData(numOfRows.max) + + skip("scanning and comparing") { + numOfRows.foreach { numOfRow => + val curData = testData.take(numOfRow) + + val inMemoryProvider = newHDFSBackedStateStoreProvider() + val inMemoryStore = inMemoryProvider.getStore(0) + + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBStore = rocksDBProvider.getStore(0) + + updateRows(inMemoryStore, curData) + updateRows(rocksDBStore, curData) + + val newVersionForInMemory = inMemoryStore.commit() + val newVersionForRocksDB = rocksDBStore.commit() + + val benchmark = new Benchmark(s"scanning and comparing $numOfRow rows", + numOfRow, minNumIters = 1000, output = output) + + benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer => + val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory) + + timer.startTiming() + // NOTE: the latency would be quite similar regardless of the rate of eviction + // as we don't remove the actual row, so I simply picked 10 % + fullScanAndCompareTimestamp(inMemoryStore2, (numOfRow * 0.1).toInt) + timer.stopTiming() + + inMemoryStore2.abort() + } + + benchmark.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB) + + timer.startTiming() + // NOTE: the latency would be quite similar regardless of the rate of eviction + // as we don't remove the actual row, so I simply picked 10 % + fullScanAndCompareTimestamp(rocksDBStore2, (numOfRow * 0.1).toInt) + timer.stopTiming() + + rocksDBStore2.abort() + } + + benchmark.run() + + inMemoryProvider.close() + rocksDBProvider.close() + } + } + + // runBenchmark("simulate full operations on eviction") { + skip("simulate full operations on eviction") { + numOfRows.foreach { numOfRow => + val curData = testData.take(numOfRow) + + val inMemoryProvider = newHDFSBackedStateStoreProvider() + val inMemoryStore = inMemoryProvider.getStore(0) + + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBStore = rocksDBProvider.getStore(0) + + val indexForInMemoryStore = new jutil.concurrent.ConcurrentSkipListMap[ + Int, jutil.List[UnsafeRow]]() + val indexForRocksDBStore = new jutil.concurrent.ConcurrentSkipListMap[ + Int, jutil.List[UnsafeRow]]() + + updateRowsWithSortedMapIndex(inMemoryStore, indexForInMemoryStore, curData) + updateRowsWithSortedMapIndex(rocksDBStore, indexForRocksDBStore, curData) + + assert(indexForInMemoryStore.size() == numOfRow) + assert(indexForRocksDBStore.size() == numOfRow) + + val newVersionForInMemory = inMemoryStore.commit() + val newVersionForRocksDB = rocksDBStore.commit() + + val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max, + minIdx = numOfRow + 1) + + updateRates.foreach { updateRate => + val numRowsUpdate = numOfRow / 100 * updateRate + val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate) + + evictRates.foreach { evictRate => + val maxIdxToEvict = numOfRow / 100 * evictRate + + val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " + + s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)", + numOfRow, minNumIters = 100, output = output) + + benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer => + val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory) + + timer.startTiming() + updateRows(inMemoryStore2, curRowsToUpdate) + evictAsFullScanAndRemove(inMemoryStore2, maxIdxToEvict) + timer.stopTiming() + + inMemoryStore2.abort() + } + + benchmark.addTimerCase("HDFSBackedStateStoreProvider - sorted map index") { timer => + + val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory) + + val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int, + jutil.List[UnsafeRow]]() + curIndex.putAll(indexForInMemoryStore) + + assert(curIndex.size() == numOfRow) + + timer.startTiming() + updateRowsWithSortedMapIndex(inMemoryStore2, curIndex, curRowsToUpdate) + + assert(curIndex.size() == numOfRow + curRowsToUpdate.size) + + evictAsScanSortedMapIndexAndRemove(inMemoryStore2, curIndex, maxIdxToEvict) + timer.stopTiming() + + assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict) + + curIndex.clear() + + inMemoryStore2.abort() + } + + benchmark.run() + + val benchmark2 = new Benchmark(s"simulating evict on $numOfRow rows, update " + + s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)", + numOfRow, minNumIters = 100, output = output) + + benchmark2.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB) + + timer.startTiming() + updateRows(rocksDBStore2, curRowsToUpdate) + evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict) + timer.stopTiming() + + rocksDBStore2.abort() + } + + benchmark2.addTimerCase("RocksDBStateStoreProvider - sorted map index") { timer => + + val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB) + + val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int, + jutil.List[UnsafeRow]]() + curIndex.putAll(indexForRocksDBStore) + + assert(curIndex.size() == numOfRow) + + timer.startTiming() + updateRowsWithSortedMapIndex(rocksDBStore2, curIndex, curRowsToUpdate) + + assert(curIndex.size() == numOfRow + curRowsToUpdate.size) + + evictAsScanSortedMapIndexAndRemove(rocksDBStore2, curIndex, maxIdxToEvict) + timer.stopTiming() + + assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict) + + curIndex.clear() + + rocksDBStore2.abort() + } + + benchmark2.run() + } + } + + inMemoryProvider.close() + rocksDBProvider.close() + } + } + + // runBenchmark("simulate full operations on eviction") { + skip("simulate full operations on eviction - scannable index") { + numOfRows.foreach { numOfRow => + val curData = testData.take(numOfRow) + + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBStore = rocksDBProvider.getStore(0) + + val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx() + val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0) + + updateRows(rocksDBStore, curData) + updateRows(rocksDBWithIdxStore, curData) + + val newVersionForRocksDB = rocksDBStore.commit() + val newVersionForRocksDBWithIdx = rocksDBWithIdxStore.commit() + + val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max, + minIdx = numOfRow + 1) + + updateRates.foreach { updateRate => + val numRowsUpdate = numOfRow / 100 * updateRate + val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate) + + evictRates.foreach { evictRate => + val maxIdxToEvict = numOfRow / 100 * evictRate + + val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " + + s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)", + numOfRow, minNumIters = 100, output = output) + + benchmark.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB) + + timer.startTiming() + updateRows(rocksDBStore2, curRowsToUpdate) + evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict) + // evictAsNewEvictApi(rocksDBStore2, maxIdxToEvict) + timer.stopTiming() + + rocksDBStore2.abort() + } + + benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer => + val rocksDBWithIdxStore2 = rocksDBWithIdxProvider.getStore( + newVersionForRocksDBWithIdx) + + timer.startTiming() + updateRows(rocksDBWithIdxStore2, curRowsToUpdate) + evictAsNewEvictApi(rocksDBWithIdxStore2, maxIdxToEvict) + timer.stopTiming() + + rocksDBWithIdxStore2.abort() + } + + benchmark.run() + } + } + + rocksDBProvider.close() + rocksDBWithIdxProvider.close() + } + } + + runBenchmark("put rows") { + numOfRows.foreach { numOfRow => + val curData = testData.take(numOfRow) + + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx() + + val benchmark = new Benchmark(s"putting $numOfRow rows", + numOfRow, minNumIters = 1000, output = output) + + benchmark.addTimerCase("RocksDBStateStoreProvider") { timer => + val rocksDBStore = rocksDBProvider.getStore(0) + + timer.startTiming() + updateRows(rocksDBStore, curData) + timer.stopTiming() + + rocksDBStore.abort() + } + + benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer => + val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0) + + timer.startTiming() + updateRows(rocksDBWithIdxStore, curData) + timer.stopTiming() + + rocksDBWithIdxStore.abort() + } + + benchmark.run() + + rocksDBProvider.close() + rocksDBWithIdxProvider.close() + } + } + */ + } + + final def skip(benchmarkName: String)(func: => Any): Unit = { + output.foreach(_.write(s"$benchmarkName is skipped".getBytes)) + } + + private def updateRows( + store: StateStore, + rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { + rows.foreach { case (key, value) => + store.put(key, value) + } + } + + private def evictAsFullScanAndRemove( + store: StateStore, + maxTimestampToEvict: Long): Unit = { + store.iterator().foreach { r => + if (r.key.getLong(1) < maxTimestampToEvict) { + store.remove(r.key) + } + } + } + + private def evictAsNewEvictApi( + store: StateStore, + maxTimestampToEvict: Long): Unit = { + store.evictOnWatermark(maxTimestampToEvict, pair => { + pair.key.getLong(1) < maxTimestampToEvict + }).foreach { _ => } + } + + private def fullScanAndCompareTimestamp( + store: StateStore, + maxIdxToEvict: Int): Unit = { + var i: Long = 0 + store.iterator().foreach { r => + if (r.key.getInt(1) < maxIdxToEvict) { + // simply to avoid the "if statement" to be no-op + i += 1 + } + } + } + + private def updateRowsWithSortedMapIndex( + store: StateStore, + index: jutil.SortedMap[Int, jutil.List[UnsafeRow]], + rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { + rows.foreach { case (key, value) => + val idx = key.getInt(1) + + // TODO: rewrite this in atomic way? + if (index.containsKey(idx)) { + val list = index.get(idx) + list.add(key) + } else { + val list = new jutil.ArrayList[UnsafeRow]() + list.add(key) + index.put(idx, list) + } + + store.put(key, value) + } + } + + private def evictAsScanSortedMapIndexAndRemove( + store: StateStore, + index: jutil.SortedMap[Int, jutil.List[UnsafeRow]], + maxIdxToEvict: Int): Unit = { + val keysToRemove = index.headMap(maxIdxToEvict + 1) + val keysToRemoveIter = keysToRemove.entrySet().iterator() + while (keysToRemoveIter.hasNext) { + val entry = keysToRemoveIter.next() + val keys = entry.getValue + val keysIter = keys.iterator() + while (keysIter.hasNext) { + val key = keysIter.next() + store.remove(key) + } + keys.clear() + keysToRemoveIter.remove() + } + } + + // FIXME: should the size of key / value be variables? + private def constructTestData(numRows: Int, minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = { + (1 to numRows).map { idx => + val keyRow = new GenericInternalRow(2) + keyRow.setInt(0, 1) + keyRow.setLong(1, (minIdx + idx) * 1000L) // microseconds + val valueRow = new GenericInternalRow(1) + valueRow.setInt(0, minIdx + idx) + + val keyUnsafeRow = keyProjection(keyRow).copy() + val valueUnsafeRow = valueProjection(valueRow).copy() + + (keyUnsafeRow, valueUnsafeRow) + } + } + + // This prevents created keys to be in order, which may affect the performance on RocksDB. + private def constructRandomizedTestData( + numRows: Int, + timestamps: List[Long], + minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = { + assert(numRows >= timestamps.length) + assert(numRows % timestamps.length == 0) + + (1 to numRows).map { idx => + val keyRow = new GenericInternalRow(2) + keyRow.setInt(0, Random.nextInt(Int.MaxValue)) + keyRow.setLong(1, timestamps((minIdx + idx) % timestamps.length)) // microseconds + val valueRow = new GenericInternalRow(1) + valueRow.setInt(0, minIdx + idx) + + val keyUnsafeRow = keyProjection(keyRow).copy() + val valueUnsafeRow = valueProjection(valueRow).copy() + + (keyUnsafeRow, valueUnsafeRow) + } + } + + private def newHDFSBackedStateStoreProvider(): StateStoreProvider = { + val storeId = StateStoreId(newDir(), Random.nextInt(), 0) + val provider = new HDFSBackedStateStoreProvider() + val sqlConf = new SQLConf() + sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd") + val storeConf = new StateStoreConf(sqlConf) + provider.init( + storeId, keySchema, valueSchema, StatefulOperatorContext(), + storeConf, new Configuration) + provider + } + + private def newRocksDBStateProvider(): StateStoreProvider = { + val storeId = StateStoreId(newDir(), Random.nextInt(), 0) + val provider = new RocksDBStateStoreProvider() + val sqlConf = new SQLConf() + sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd") + val storeConf = new StateStoreConf(sqlConf) + provider.init( + storeId, keySchema, valueSchema, StatefulOperatorContext(), + storeConf, new Configuration) + provider + } + + private def newRocksDBStateProviderWithEventTimeIdx(): StateStoreProvider = { + val storeId = StateStoreId(newDir(), Random.nextInt(), 0) + val provider = new RocksDBStateStoreProvider() + val sqlConf = new SQLConf() + sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd") + val storeConf = new StateStoreConf(sqlConf) + provider.init( + storeId, keySchema, valueSchema, StatefulOperatorContext(eventTimeColIdx = Array(1)), + storeConf, new Configuration) + provider + } + + private def newDir(): String = Utils.createTempDir().toString +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala index 81f1a3f..4f7c040 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} @@ -217,9 +217,12 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef stateFormatVersion) val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME) + + // FIXME: event time column? val store = StateStore.get( storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema, - manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration) + StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey), + stateInfo.storeVersion, storeConf, new Configuration) try { f(manager, store) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index e52ccd0..4b53c0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -50,4 +50,18 @@ class MemoryStateStore extends StateStore() { override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = { throw new UnsupportedOperationException("Doesn't support prefix scan!") } + + /** FIXME: method doc */ + override def evictOnWatermark( + watermarkMs: Long, + altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = { + iterator().filter { pair => + if (altPred.apply(pair)) { + remove(pair.key) + true + } else { + false + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala index 2d741d3..7374303 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters import org.scalatest.time.{Minute, Span} import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} -import org.apache.spark.sql.functions.count +import org.apache.spark.sql.functions.{count, timestamp_seconds, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -52,6 +52,64 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest { } } + test("append mode") { + val inputData = MemoryStream[Int] + val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation)( + StartStream(additionalConfs = conf), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer((10, 5)), + // assertNumStateRows(2), + // assertNumRowsDroppedByWatermark(0), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer() + // assertNumStateRows(2), + // assertNumRowsDroppedByWatermark(1) + ) + } + + test("update mode") { + val inputData = MemoryStream[Int] + val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation, OutputMode.Update)( + StartStream(additionalConfs = conf), + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckNewAnswer((10, 5), (15, 1)), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckNewAnswer((25, 1)), + // assertNumStateRows(2), + // assertNumRowsDroppedByWatermark(0), + AddData(inputData, 10, 25), // Ignore 10 as its less than watermark + CheckNewAnswer((25, 2)), + // assertNumStateRows(2), + // assertNumRowsDroppedByWatermark(1), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckNewAnswer() + // assertNumStateRows(2), + // assertNumRowsDroppedByWatermark(1) + ) + } + test("SPARK-36236: query progress contains only the expected RocksDB store custom metrics") { // fails if any new custom metrics are added to remind the author of API changes import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index c93d0f0..c06a9ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -87,8 +87,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5) // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance + // FIXME: event time column? val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf]( - spark.sqlContext, testStateInfo, testSchema, testSchema, 0) { + spark.sqlContext, testStateInfo, testSchema, testSchema, StatefulOperatorContext()) { (store: StateStore, _: Iterator[String]) => // Use reflection to get RocksDB instance val dbInstanceMethod = @@ -144,7 +145,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid numColsPrefixKey: Int): RocksDBStateStoreProvider = { val provider = new RocksDBStateStoreProvider() provider.init( - storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey, + storeId, keySchema, valueSchema, + StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey), new StateStoreConf, new Configuration) provider } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6bb8ebe..0e9b19d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0))) .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0))) .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1)) // Make sure the previous RDD still has the same data. @@ -84,7 +84,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq(("a", 0))).mapPartitionsWithStateStore( sqlContext, operatorStateInfo(path, version = storeVersion), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) } // Generate RDDs and state store data @@ -134,19 +134,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val rddOfGets1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0), - keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets) + keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set(("a", 0) -> None, ("b", 0) -> None, ("c", 0) -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0))) .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0), - keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfPuts) + keySchema, valueSchema, StatefulOperatorContext())(iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set(("a", 0) -> 1, ("a", 0) -> 2, ("b", 0) -> 1)) val rddOfGets2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0))) .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1), - keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets) + keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets) assert(rddOfGets2.collect().toSet === Set(("a", 0) -> Some(2), ("b", 0) -> Some(1), ("c", 0) -> None)) } @@ -172,7 +172,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val rdd = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0))) .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, queryRunId = queryRunId), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) require(rdd.partitions.length === 2) assert( @@ -200,13 +200,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0))) .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0))) .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1), - keySchema, valueSchema, numColsPrefixKey = 0)(increment) + keySchema, valueSchema, StatefulOperatorContext())(increment) assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1)) // Make sure the previous RDD still has the same data. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 601b62b..d89c6fed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -270,8 +270,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def generateStoreVersions(): Unit = { for (i <- 1 to 20) { - val store = StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0, - latestStoreVersion, storeConf, hadoopConf) + val store = StateStore.get(storeProviderId1, keySchema, valueSchema, + StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf) put(store, "a", 0, i) store.commit() latestStoreVersion += 1 @@ -324,7 +324,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } // Reload the store and verify - StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0, + StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeProviderId1)) @@ -336,7 +336,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } // Reload the store and verify - StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0, + StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeProviderId1)) @@ -344,7 +344,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // then this executor should unload inactive instances immediately. coordinatorRef .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty) - StateStore.get(storeProviderId2, keySchema, valueSchema, numColsPrefixKey = 0, + StateStore.get(storeProviderId2, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf) assert(!StateStore.isLoaded(storeProviderId1)) assert(StateStore.isLoaded(storeProviderId2)) @@ -453,7 +453,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Getting the store should not create temp file val store0 = shouldNotCreateTempFile { StateStore.get( - storeId, keySchema, valueSchema, numColsPrefixKey = 0, + storeId, keySchema, valueSchema, StatefulOperatorContext(), version = 0, storeConf, hadoopConf) } @@ -470,7 +470,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Remove should create a temp file val store1 = shouldNotCreateTempFile { StateStore.get( - storeId, keySchema, valueSchema, numColsPrefixKey = 0, + storeId, keySchema, valueSchema, StatefulOperatorContext(), version = 1, storeConf, hadoopConf) } remove(store1, _._1 == "a") @@ -485,7 +485,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] // Commit without any updates should create a delta file val store2 = shouldNotCreateTempFile { StateStore.get( - storeId, keySchema, valueSchema, numColsPrefixKey = 0, + storeId, keySchema, valueSchema, StatefulOperatorContext(), version = 2, storeConf, hadoopConf) } store2.commit() @@ -720,11 +720,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = { val sqlConf = getDefaultSQLConf(minDeltasForSnapshot, numOfVersToRetainInMemory) val provider = new HDFSBackedStateStoreProvider() + // FIXME: event time column? provider.init( StateStoreId(dir, opId, partition), keySchema, valueSchema, - numColsPrefixKey = numColsPrefixKey, + StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey), new StateStoreConf(sqlConf), hadoopConf) provider @@ -1027,31 +1028,31 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] // Verify that trying to get incorrect versions throw errors intercept[IllegalArgumentException] { StateStore.get( - storeId, keySchema, valueSchema, 0, -1, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), -1, storeConf, hadoopConf) } assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store intercept[IllegalStateException] { StateStore.get( - storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf) } // Increase version of the store and try to get again val store0 = StateStore.get( - storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf) assert(store0.version === 0) put(store0, "a", 0, 1) store0.commit() val store1 = StateStore.get( - storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) assert(store1.version === 1) assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1)) // Verify that you can also load older version val store0reloaded = StateStore.get( - storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf) assert(store0reloaded.version === 0) assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty) @@ -1060,7 +1061,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] assert(!StateStore.isLoaded(storeId)) val store1reloaded = StateStore.get( - storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf) + storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) assert(store1reloaded.version === 1) put(store1reloaded, "a", 0, 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala index 096c3bb..dcdbac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala @@ -181,9 +181,11 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA stateFormatVersion) val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME) + // FIXME: event time column? val store = StateStore.get( storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema, - manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration) + StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey), + stateInfo.storeVersion, storeConf, new Configuration) try { f(manager, store) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e89197b..d376942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink} -import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider @@ -1418,7 +1418,7 @@ class TestStateStoreProvider extends StateStoreProvider { stateStoreId: StateStoreId, keySchema: StructType, valueSchema: StructType, - numColsPrefixKey: Int, + operatorContext: StatefulOperatorContext, storeConfs: StateStoreConf, hadoopConf: Configuration): Unit = { throw new Exception("Successfully instantiated") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 77334ad..fbf3aae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemorySink -import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ @@ -53,29 +53,47 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { import testImplicits._ def executeFuncWithStateVersionSQLConf( + providerCls: String, stateVersion: Int, confPairs: Seq[(String, String)], func: => Any): Unit = { withSQLConf(confPairs ++ - Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + Seq( + SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString, + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerCls.stripSuffix("$")): _*) { func } } def testWithAllStateVersions(name: String, confPairs: (String, String)*) (func: => Any): Unit = { - for (version <- StreamingAggregationStateManager.supportedVersions) { - test(s"$name - state format version $version") { - executeFuncWithStateVersionSQLConf(version, confPairs, func) + val providers = Seq( + // FIXME: testing... + // classOf[HDFSBackedStateStoreProvider].getCanonicalName, + classOf[RocksDBStateStoreProvider].getCanonicalName) + + for ( + version <- StreamingAggregationStateManager.supportedVersions; + provider <- providers + ) yield { + test(s"$name - state format version $version / provider: $provider") { + executeFuncWithStateVersionSQLConf(provider, version, confPairs, func) } } } def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) (func: => Any): Unit = { - for (version <- StreamingAggregationStateManager.supportedVersions) { - testQuietly(s"$name - state format version $version") { - executeFuncWithStateVersionSQLConf(version, confPairs, func) + val providers = Seq( + classOf[HDFSBackedStateStoreProvider].getCanonicalName, + classOf[RocksDBStateStoreProvider].getCanonicalName) + + for ( + version <- StreamingAggregationStateManager.supportedVersions; + provider <- providers + ) yield { + testQuietly(s"$name - state format version $version / provider: $provider") { + executeFuncWithStateVersionSQLConf(provider, version, confPairs, func) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org