This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git
commit 88187c64d70a5566bb8f07ce4b133a33e63ce5fc
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Fri Nov 5 16:41:24 2021 +0900

    WIP still need to add e2e test and address FIXME/TODOs
---
 .../sql/execution/streaming/state/RocksDB.scala    | 172 ++++-------------
 .../streaming/state/RocksDBFileManager.scala       |  35 +++-
 .../streaming/state/RocksDBStateEncoder.scala      |  96 +---------
 .../state/RocksDBStateStoreProvider.scala          | 213 +++++++++++++--------
 .../execution/benchmark/StateStoreBenchmark.scala  |  25 ++-
 .../execution/streaming/state/RocksDBSuite.scala   |   8 +-
 6 files changed, 227 insertions(+), 322 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index eed7827..105a446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
-import java.util
 import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
@@ -51,12 +50,9 @@ import org.apache.spark.util.{NextIterator, Utils}
  * @param hadoopConf Hadoop configuration for talking to the remote file system
  * @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs
  */
-// FIXME: optionally receiving column families
 class RocksDB(
     dfsRootDir: String,
     val conf: RocksDBConf,
-    // TODO: change "default" to constant
-    columnFamilies: Seq[String] = Seq("default"),
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "") extends Logging {
@@ -69,10 +65,16 @@ class RocksDB(
   private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
   private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
 
-  private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
-  dbOptions.setCreateIfMissing(true)
-  dbOptions.setCreateMissingColumnFamilies(true)
+  private val bloomFilter = new BloomFilter()
+  private val tableFormatConfig = new BlockBasedTableConfig()
+  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
+  tableFormatConfig.setFilterPolicy(bloomFilter)
+  tableFormatConfig.setFormatVersion(conf.formatVersion)
+  private val dbOptions = new Options() // options to open the RocksDB
+  dbOptions.setCreateIfMissing(true)
+  dbOptions.setTableFormatConfig(tableFormatConfig)
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   dbOptions.setStatistics(new Statistics())
   private val nativeStats = dbOptions.statistics()
@@ -85,18 +87,18 @@ class RocksDB(
   private val acquireLock = new Object
 
   @volatile private var db: NativeRocksDB = _
-  @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _
-  @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _
   @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
   @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS
+  @volatile private var customMetadataOnLoadedVersion: Map[String, String] = Map.empty
+  @volatile private var customMetadataOnWritingVersion: Map[String, String] = Map.empty
 
   @GuardedBy("acquireLock")
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
   private val prefixScanReuseIter =
-    new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()
+    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
 
   /**
    * Load the given version of data in a native RocksDB instance.
@@ -114,6 +116,8 @@ class RocksDB(
       openDB()
       numKeysOnWritingVersion = metadata.numKeys
       numKeysOnLoadedVersion = metadata.numKeys
+      customMetadataOnLoadedVersion = metadata.customMetadata
+      customMetadataOnWritingVersion = metadata.customMetadata
       loadedVersion = version
       fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
     }
@@ -137,28 +141,7 @@ class RocksDB(
   * @note This will return the last written value even if it was uncommitted.
   */
  def get(key: Array[Byte]): Array[Byte] = {
-    get(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def get(cf: String, key: Array[Byte]): Array[Byte] = {
-    get(findColumnFamilyHandle(cf), key)
-  }
-
-  private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-  }
-
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(defaultColumnFamilyHandle, key, value)
-  }
-
-  def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
-    writeBatch.merge(cfHandle, key, value)
+    writeBatch.getFromBatchAndDB(db, readOptions, key)
   }
 
  /**
@@ -166,20 +149,8 @@ class RocksDB(
   * @note This update is not committed to disk until commit() is called.
   */
  def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(defaultColumnFamilyHandle, key, value)
-  }
-
-  // FIXME: method doc
-  def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def put(
-      cfHandle: ColumnFamilyHandle,
-      key: Array[Byte],
-      value: Array[Byte]): Array[Byte] = {
-    val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-    writeBatch.put(cfHandle, key, value)
+    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    writeBatch.put(key, value)
    if (oldValue == null) {
      numKeysOnWritingVersion += 1
    }
@@ -191,18 +162,9 @@ class RocksDB(
   * @note This update is not committed to disk until commit() is called.
   */
  def remove(key: Array[Byte]): Array[Byte] = {
-    remove(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def remove(cf: String, key: Array[Byte]): Array[Byte] = {
-    remove(findColumnFamilyHandle(cf), key)
-  }
-
-  private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
    if (value != null) {
-      writeBatch.delete(cfHandle, key)
+      writeBatch.remove(key)
      numKeysOnWritingVersion -= 1
    }
    value
@@ -211,17 +173,8 @@ class RocksDB(
  /**
   * Get an iterator of all committed and uncommitted key-value pairs.
   */
-  def iterator(): NextIterator[ByteArrayPair] = {
-    iterator(defaultColumnFamilyHandle)
-  }
-
-  // FIXME: doc
-  def iterator(cf: String): NextIterator[ByteArrayPair] = {
-    iterator(findColumnFamilyHandle(cf))
-  }
-
-  private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
-    val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+  def iterator(): Iterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(db.newIterator())
    logInfo(s"Getting iterator from version $loadedVersion")

    iter.seekToFirst()
@@ -248,20 +201,11 @@ class RocksDB(
  }

  def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(defaultColumnFamilyHandle, prefix)
-  }
-
-  def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(findColumnFamilyHandle(cf), prefix)
-  }
-
-  private def prefixScan(
-      cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
    val threadId = Thread.currentThread().getId
-    val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => {
-      val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+      val it = writeBatch.newIteratorWithBase(db.newIterator())
      logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
-        s"thread ID ${key._1} and column family ID ${key._2}")
+        s"thread ID $tid")
      it
    })
@@ -283,12 +227,8 @@ class RocksDB(
    }
  }

-  private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
-    val cfHandle = columnFamilyHandles.get(cf)
-    if (cfHandle == null) {
-      throw new IllegalArgumentException(s"Handle for column family $cf is not found")
-    }
-    cfHandle
+  def setCustomMetadata(metadata: Map[String, String]): Unit = {
+    customMetadataOnWritingVersion = metadata
  }

  /**
@@ -310,16 +250,11 @@ class RocksDB(
      val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }

      logInfo(s"Flushing updates for $newVersion")
-      val flushTimeMs = timeTakenMs {
-        db.flush(flushOptions,
-          new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
-      }
+      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }

      val compactTimeMs = if (conf.compactOnCommit) {
        logInfo("Compacting")
-        timeTakenMs {
-          columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
-        }
+        timeTakenMs { db.compactRange() }
      } else 0

      logInfo("Pausing background work")
@@ -335,9 +270,11 @@ class RocksDB(

      logInfo(s"Syncing checkpoint for $newVersion to DFS")
      val fileSyncTimeMs = timeTakenMs {
-        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion)
+        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion,
+          customMetadataOnWritingVersion.toMap)
      }
      numKeysOnLoadedVersion = numKeysOnWritingVersion
+      customMetadataOnLoadedVersion = customMetadataOnWritingVersion
      loadedVersion = newVersion
      fileManagerMetrics = fileManager.latestSaveCheckpointMetrics
      commitLatencyMs ++= Map(
@@ -352,7 +289,6 @@ class RocksDB(
      loadedVersion
    } catch {
      case t: Throwable =>
-        logWarning(s"ERROR! exc: $t", t)
        loadedVersion = -1  // invalidate loaded version
        throw t
    } finally {
@@ -369,6 +305,7 @@ class RocksDB(
    closePrefixScanIterators()
    writeBatch.clear()
    numKeysOnWritingVersion = numKeysOnLoadedVersion
+    customMetadataOnWritingVersion = customMetadataOnLoadedVersion
    release()
    logInfo(s"Rolled back to $loadedVersion")
  }
@@ -404,6 +341,8 @@ class RocksDB(
  /** Get the latest version available in the DFS */
  def getLatestVersion(): Long = fileManager.getLatestVersion()

+  def getCustomMetadata(): Map[String, String] = customMetadataOnWritingVersion
+
  /** Get current instantaneous statistics */
  def metrics: RocksDBMetrics = {
    import HistogramType._
@@ -496,43 +435,12 @@ class RocksDB(

  private def openDB(): Unit = {
    assert(db == null)
-
-    val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
-    columnFamilies.foreach { cf =>
-      val bloomFilter = new BloomFilter()
-      val tableFormatConfig = new BlockBasedTableConfig()
-      tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
-      tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
-      tableFormatConfig.setFilterPolicy(bloomFilter)
-      tableFormatConfig.setFormatVersion(conf.formatVersion)
-
-      val columnFamilyOptions = new ColumnFamilyOptions()
-      columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
-      columnFamilyDescriptors.add(new ColumnFamilyDescriptor(cf.getBytes(), columnFamilyOptions))
-    }
-
-    val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
-    db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
-      cfHandles)
-
-    columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
-    columnFamilies.indices.foreach { idx =>
-      columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
-    }
-
-    // FIXME: constant
-    defaultColumnFamilyHandle = columnFamilyHandles.get("default")
-
+    db = NativeRocksDB.open(dbOptions, workingDir.toString)
    logInfo(s"Opened DB with conf ${conf}")
  }

  private def closeDB(): Unit = {
    if (db != null) {
-      columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
-      columnFamilyHandles.clear()
-      columnFamilyHandles = null
-      defaultColumnFamilyHandle = null
-
      db.close()
      db = null
    }
@@ -546,17 +454,10 @@ class RocksDB(
    // Warn is mapped to info because RocksDB warn is too verbose
    // (e.g. dumps non-warning stuff like stats)
    val loggingFunc: ( => String) => Unit = infoLogLevel match {
-      /*
      case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
      case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
      case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
      case _ => logTrace(_)
-      */
-      case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
-      case InfoLogLevel.WARN_LEVEL => logWarning(_)
-      case InfoLogLevel.INFO_LEVEL => logInfo(_)
-      case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
-      case _ => logTrace(_)
    }
    loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
  }
@@ -702,7 +603,7 @@ object RocksDBMetrics {

/** Class to wrap RocksDB's native histogram */
case class RocksDBNativeHistogram(
-  sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
+    sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
  def json: String = Serialization.write(this)(RocksDBMetrics.format)
}
@@ -733,4 +634,3 @@ case class AcquiredThreadInfo() {
    s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
  }
}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 23cdbd0..567f916 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -152,11 +152,15 @@ class RocksDBFileManager(
  def latestSaveCheckpointMetrics: RocksDBFileManagerMetrics = saveCheckpointMetrics

  /** Save all the files in given local checkpoint directory as a committed version in DFS */
-  def saveCheckpointToDfs(checkpointDir: File, version: Long, numKeys: Long): Unit = {
+  def saveCheckpointToDfs(
+      checkpointDir: File,
+      version: Long,
+      numKeys: Long,
+      customMetadata: Map[String, String] = Map.empty): Unit = {
    logFilesInDir(checkpointDir, s"Saving checkpoint files for version $version")
    val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir)
    val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles)
-    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys)
+    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys, customMetadata)
    val metadataFile = localMetadataFile(checkpointDir)
    metadata.writeToFile(metadataFile)
    logInfo(s"Written metadata for version $version:\n${metadata.prettyJson}")
@@ -184,7 +188,7 @@ class RocksDBFileManager(
    val metadata = if (version == 0) {
      if (localDir.exists) Utils.deleteRecursively(localDir)
      localDir.mkdirs()
-      RocksDBCheckpointMetadata(Seq.empty, 0)
+      RocksDBCheckpointMetadata(Seq.empty, 0, Map.empty)
    } else {
      // Delete all non-immutable files in local dir, and unzip new ones from DFS commit file
      listRocksDBFiles(localDir)._2.foreach(_.delete())
@@ -540,12 +544,20 @@ object RocksDBFileManagerMetrics {
case class RocksDBCheckpointMetadata(
    sstFiles: Seq[RocksDBSstFile],
    logFiles: Seq[RocksDBLogFile],
-    numKeys: Long) {
+    numKeys: Long,
+    customMetadata: Map[String, String]) {
  import RocksDBCheckpointMetadata._

  def json: String = {
-    // We turn this field into a null to avoid write a empty logFiles field in the json.
-    val nullified = if (logFiles.isEmpty) this.copy(logFiles = null) else this
+    // We turn the field into a null to avoid write below fields in the json if they are empty:
+    // - logFiles
+    // - customMetadata
+    val nullified = {
+      var cur = this
+      cur = if (logFiles.isEmpty) cur.copy(logFiles = null) else cur
+      cur = if (customMetadata.isEmpty) cur.copy(customMetadata = null) else cur
+      cur
+    }
    mapper.writeValueAsString(nullified)
  }
@@ -593,11 +605,18 @@ object RocksDBCheckpointMetadata {
    }
  }

-  def apply(rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long): RocksDBCheckpointMetadata = {
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long): RocksDBCheckpointMetadata = apply(rocksDBFiles, numKeys, Map.empty)
+
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long,
+      customMetadata: Map[String, String]): RocksDBCheckpointMetadata = {
    val sstFiles = rocksDBFiles.collect { case file: RocksDBSstFile => file }
    val logFiles = rocksDBFiles.collect { case file: RocksDBLogFile => file }

-    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys)
+    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, customMetadata)
  }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 323826d..84e9a8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -22,7 +22,7 @@ import java.nio.ByteOrder

import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.unsafe.Platform

sealed trait RocksDBStateEncoder {
@@ -30,11 +30,6 @@ sealed trait RocksDBStateEncoder {
  def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
  def extractPrefixKey(key: UnsafeRow): UnsafeRow

-  def supportEventTimeIndex: Boolean
-  def extractEventTime(key: UnsafeRow): Long
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
-
  def encodeKey(row: UnsafeRow): Array[Byte]
  def encodeValue(row: UnsafeRow): Array[Byte]
@@ -47,13 +42,11 @@ object RocksDBStateEncoder {
  def getEncoder(
      keySchema: StructType,
      valueSchema: StructType,
-      numColsPrefixKey: Int,
-      eventTimeColIdx: Array[Int]): RocksDBStateEncoder = {
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
    if (numColsPrefixKey > 0) {
-      // FIXME: need to deal with prefix case as well
      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
    } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
    }
  }
@@ -228,23 +221,6 @@ class PrefixKeyScanStateEncoder(
  }

  override def supportPrefixKeyScan: Boolean = true
-
-  override def supportEventTimeIndex: Boolean = false
-
-  // FIXME: fix me!
-  def extractEventTime(key: UnsafeRow): Long = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
}

/**
@@ -259,8 +235,7 @@ class PrefixKeyScanStateEncoder(
 */
class NoPrefixKeyStateEncoder(
    keySchema: StructType,
-    valueSchema: StructType,
-    eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder {
+    valueSchema: StructType) extends RocksDBStateEncoder {

  import RocksDBStateEncoder._
@@ -269,32 +244,6 @@ class NoPrefixKeyStateEncoder(
  private val valueRow = new UnsafeRow(valueSchema.size)
  private val rowTuple = new UnsafeRowPair()

-  validateColumnTypeOnEventTimeColumn()
-
-  private def validateColumnTypeOnEventTimeColumn(): Unit = {
-    if (eventTimeColIdx.nonEmpty) {
-      var curSchema: StructType = keySchema
-      eventTimeColIdx.dropRight(1).foreach { idx =>
-        curSchema(idx).dataType match {
-          case stType: StructType =>
-            curSchema = stType
-          case _ =>
-            // FIXME: better error message
-            throw new IllegalStateException("event time column is not properly specified! " +
-              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-        }
-      }
-
-      curSchema(eventTimeColIdx.last).dataType match {
-        case _: TimestampType =>
-        case _ =>
-          // FIXME: better error message
-          throw new IllegalStateException("event time column is not properly specified! " +
-            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-      }
-    }
-  }
-
  override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -337,41 +286,4 @@ class NoPrefixKeyStateEncoder(
  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
    throw new IllegalStateException("This encoder doesn't support prefix key!")
  }
-
-  override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
-
-  override def extractEventTime(key: UnsafeRow): Long = {
-    var curRow: UnsafeRow = key
-    var curSchema: StructType = keySchema
-
-    eventTimeColIdx.dropRight(1).foreach { idx =>
-      // validation is done in initialization phase
-      curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
-      curRow = curRow.getStruct(idx, curSchema.length)
-    }
-
-    curRow.getLong(eventTimeColIdx.last) / 1000
-  }
-
-  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    val newKey = new Array[Byte](8 + encodedKey.length)
-
-    Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET,
-      BinarySortable.encodeToBinarySortableLong(timestamp))
-    Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey,
-      Platform.BYTE_ARRAY_OFFSET + 8, encodedKey.length)
-
-    newKey
-  }
-
-  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET)
-    val timestamp = BinarySortable.decodeBinarySortableLong(encoded)
-
-    val key = new Array[Byte](eventTimeBytes.length - 8)
-    Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + 8,
-      key, Platform.BYTE_ARRAY_OFFSET, eventTimeBytes.length - 8)
-
-    (timestamp, key)
-  }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 1d66220..d7c2e0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -24,15 +24,14 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.Platform
+import org.apache.spark.sql.types.{StructType, TimestampType}
import org.apache.spark.util.{NextIterator, Utils}

private[sql] class RocksDBStateStoreProvider
  extends StateStoreProvider with Logging with Closeable {
  import RocksDBStateStoreProvider._

-  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+  class RocksDBStateStore(lastVersion: Long, eventTimeColIdx: Array[Int]) extends StateStore {
    /** Trait and classes representing the internal state of the store */
    trait STATE
    case object UPDATING extends STATE
@@ -42,6 +41,40 @@ private[sql] class RocksDBStateStoreProvider
    @volatile private var state: STATE = UPDATING
    @volatile private var isValidated = false

+    private val supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+    if (supportEventTimeIndex) {
+      validateColumnTypeOnEventTimeColumn()
+    }
+
+    private def validateColumnTypeOnEventTimeColumn(): Unit = {
+      require(eventTimeColIdx.nonEmpty)
+
+      var curSchema: StructType = keySchema
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        curSchema(idx).dataType match {
+          case stType: StructType =>
+            curSchema = stType
+          case _ =>
+            // FIXME: better error message
+            throw new IllegalStateException("event time column is not properly specified! " +
+              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+        }
+      }
+
+      curSchema(eventTimeColIdx.last).dataType match {
+        case _: TimestampType =>
+        case _ =>
+          // FIXME: better error message
+          throw new IllegalStateException("event time column is not properly specified! " +
+            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+      }
+    }
+
+    private var lowestEventTime: Long = Option(rocksDB.getCustomMetadata())
+      .flatMap(_.get(METADATA_KEY_LOWEST_EVENT_TIME).map(_.toLong))
+      .getOrElse(INVALID_LOWEST_EVENT_TIME_VALUE)
+
    override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId

    override def version: Long = lastVersion
@@ -66,19 +99,29 @@ private[sql] class RocksDBStateStoreProvider
      val encodedValue = encoder.encodeValue(value)
      rocksDB.put(encodedKey, encodedValue)

-      if (encoder.supportEventTimeIndex) {
-        val timestamp = encoder.extractEventTime(key)
-        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
-
-        rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && lowestEventTime > eventTimeValue) {
+          lowestEventTime = eventTimeValue
+        }
      }
    }

    override def remove(key: UnsafeRow): Unit = {
      verify(state == UPDATING, "Cannot remove after already committed or aborted")
      verify(key != null, "Key cannot be null")
-      // FIXME: this should reflect the index
      rocksDB.remove(encoder.encodeKey(key))
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime == eventTimeValue) {
+          // We can't track the next lowest event time without scanning entire keys.
+          // Mark the lowest event time value as invalid, so that scan happens in evict phase and
+          // the value is correctly updated later.
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+        }
+      }
    }

    override def iterator(): Iterator[UnsafeRowPair] = {
@@ -100,8 +143,89 @@ private[sql] class RocksDBStateStoreProvider
      rocksDB.prefixScan(prefix).map(kv => encoder.decode(kv))
    }

+    /** FIXME: method doc */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+      if (supportEventTimeIndex) {
+        // convert lowestEventTime to milliseconds, and compare to watermarkMs
+        // retract 1 ms to avoid edge-case on conversion from microseconds to milliseconds
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && ((lowestEventTime / 1000) - 1 > watermarkMs)) {
+          Iterator.empty
+        } else {
+          // start with invalidating the lowest event time
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+
+          new NextIterator[UnsafeRowPair] {
+            private val iter = rocksDB.iterator()
+
+            // here we use Long.MaxValue as invalid value
+            private var lowestEventTimeInIter = Long.MaxValue
+
+            override protected def getNext(): UnsafeRowPair = {
+              var result: UnsafeRowPair = null
+              while (result == null && iter.hasNext) {
+                val kv = iter.next()
+                val rowPair = encoder.decode(kv)
+                if (altPred(rowPair)) {
+                  rocksDB.remove(kv.key)
+                  result = rowPair
+                } else {
+                  val eventTime = extractEventTime(rowPair.key)
+                  if (lowestEventTimeInIter > eventTime) {
+                    lowestEventTimeInIter = eventTime
+                  }
+                }
+              }
+
+              if (result == null) {
+                finished = true
+                null
+              } else {
+                result
+              }
+            }
+
+            override protected def close(): Unit = {
+              if (lowestEventTimeInIter != Long.MaxValue) {
+                lowestEventTime = lowestEventTimeInIter
+              }
+            }
+          }
+        }
+      } else {
+        rocksDB.iterator().flatMap { kv =>
+          val rowPair = encoder.decode(kv)
+          if (altPred(rowPair)) {
+            rocksDB.remove(kv.key)
+            Some(rowPair)
+          } else {
+            None
+          }
+        }
+      }
+    }
+
+    private def extractEventTime(key: UnsafeRow): Long = {
+      var curRow: UnsafeRow = key
+      var curSchema: StructType = keySchema
+
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        // validation is done in initialization phase
+        curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+        curRow = curRow.getStruct(idx, curSchema.length)
+      }
+
+      curRow.getLong(eventTimeColIdx.last)
+    }
+
    override def commit(): Long = synchronized {
      verify(state == UPDATING, "Cannot commit after already committed or aborted")
+
+      // set the metadata to RocksDB instance so that it can be committed as well
+      rocksDB.setCustomMetadata(Map(METADATA_KEY_LOWEST_EVENT_TIME -> lowestEventTime.toString))
+
      val newVersion = rocksDB.commit()
      state = COMMITTED
      logInfo(s"Committed $newVersion for $id")
@@ -173,68 +297,6 @@ private[sql] class RocksDBStateStoreProvider

    /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
    def dbInstance(): RocksDB = rocksDB
-
-    /** FIXME: method doc */
-    override def evictOnWatermark(
-        watermarkMs: Long,
-        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
-
-      if (encoder.supportEventTimeIndex) {
-        val kv = new ByteArrayPair()
-
-        // FIXME: DEBUG
-        // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs")
-
-        new NextIterator[UnsafeRowPair] {
-          private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
-          override protected def getNext(): UnsafeRowPair = {
-            if (iter.hasNext) {
-              val pair = iter.next()
-
-              val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET)
-              val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs)
-
-              // FIXME: DEBUG
-              // logWarning(s"DEBUG: decoded TS: $decodedTs")
-
-              if (decodedTs > watermarkMs) {
-                finished = true
-                null
-              } else {
-                // FIXME: can we leverage deleteRange to bulk delete on index?
-                rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key)
-                val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key)
-                val value = rocksDB.get(encodedKey)
-                if (value == null) {
-                  throw new IllegalStateException("Event time index has been broken!")
-                }
-                kv.set(encodedKey, value)
-                val rowPair = encoder.decode(kv)
-                rocksDB.remove(encodedKey)
-                rowPair
-              }
-            } else {
-              finished = true
-              null
-            }
-          }
-
-          override protected def close(): Unit = {
-            iter.closeIfNeeded()
-          }
-        }
-      } else {
-        rocksDB.iterator().flatMap { kv =>
-          val rowPair = encoder.decode(kv)
-          if (altPred(rowPair)) {
-            rocksDB.remove(kv.key)
-            Some(rowPair)
-          } else {
-            None
-          }
-        }
-      }
-    }
  }

  override def init(
@@ -257,7 +319,7 @@ private[sql] class RocksDBStateStoreProvider
    this.operatorContext = operatorContext

    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
-      operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
+      operatorContext.numColsPrefixKey)

    rocksDB // lazy initialization
  }
@@ -267,7 +329,7 @@ private[sql] class RocksDBStateStoreProvider
  override def getStore(version: Long): StateStore = {
    require(version >= 0, "Version cannot be less than 0")
    rocksDB.load(version)
-    new RocksDBStateStore(version)
+    new RocksDBStateStore(version, operatorContext.eventTimeColIdx)
  }

  override def doMaintenance(): Unit = {
@@ -298,7 +360,6 @@ private[sql] class RocksDBStateStoreProvider
    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
    val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
    new RocksDB(dfsRootDir, RocksDBConf(storeConf),
-      columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
      localRootDir = localRootDir, hadoopConf = hadoopConf, loggingId = storeIdStr)
  }
@@ -315,8 +376,8 @@ object RocksDBStateStoreProvider {
  val STATE_ENCODING_NUM_VERSION_BYTES = 1
  val STATE_ENCODING_VERSION: Byte = 0

-  // reserved column families
-  val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+  val INVALID_LOWEST_EVENT_TIME_VALUE: Long = Long.MinValue
+  val METADATA_KEY_LOWEST_EVENT_TIME: String = "lowestEventTimeInState"

  // Native operation latencies report as latency in microseconds
  // as SQLMetrics support millis. Convert the value to millis
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
index 9a83f5c..12ac517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -66,7 +66,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
  private def runEvictBenchmark(): Unit = {
    runBenchmark("evict rows") {
      val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
-      val numOfTimestamps = Seq(10, 100, 1000)
+      val numOfTimestamps = Seq(100, 1000)
      val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)

      numOfRows.foreach { numOfRow =>
@@ -116,7 +116,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
            val rocksDBStore = rocksDBProvider.getStore(committedVersion)

            timer.startTiming()
-            evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis)
+            evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
            timer.stopTiming()

            rocksDBStore.abort()
@@ -126,7 +126,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
            val rocksDBStore = rocksDBProvider.getStore(committedVersion)

            timer.startTiming()
-            evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis)
+            evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
            timer.stopTiming()

            rocksDBStore.abort()
@@ -487,20 +487,29 @@ object StateStoreBenchmark extends SqlBasedBenchmark {

  private def evictAsFullScanAndRemove(
      store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
    store.iterator().foreach { r =>
-      if (r.key.getLong(1) < maxTimestampToEvict) {
+      if (r.key.getLong(1) / 1000 <= maxTimestampToEvict) {
        store.remove(r.key)
+        removedRows += 1
      }
    }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
  }

  private def evictAsNewEvictApi(
      store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
    store.evictOnWatermark(maxTimestampToEvict, pair => {
-      pair.key.getLong(1) < maxTimestampToEvict
-    }).foreach { _ => }
+      pair.key.getLong(1) / 1000 <= maxTimestampToEvict
+    }).foreach { _ => removedRows += 1 }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
  }

  private def fullScanAndCompareTimestamp(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 1ee2748..31e49ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -355,14 +355,18 @@ class RocksDBSuite extends SparkFunSuite {
      RocksDBCheckpointMetadata(Seq.empty, 0L),
      """{"sstFiles":[],"numKeys":0}"""
    )
-    // shouldn't include the "logFiles" field in json when it's empty
+    // shouldn't include the "logFiles" & "customMetadata" field in json when it's empty
    checkJsonRoundtrip(
      RocksDBCheckpointMetadata(sstFiles, 12345678901234L),
      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"numKeys":12345678901234}"""
    )
+    // shouldn't include the "customMetadata" field in json when it's empty
    checkJsonRoundtrip(
-      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L),
+      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L, Map.empty),
      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234}""")
+
+    // FIXME: test customMetadata here
+
    // scalastyle:on line.size.limit
  }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org