Copilot commented on code in PR #49032: URL: https://github.com/apache/spark/pull/49032#discussion_r2505648167
########## core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" + + // TODO(holden): This is shared with MapOutputTrackerSuite move to a BlockTestUtils or similar. + private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). + head._1.asInstanceOf[BlockManagerMasterEndpoint] + } + + private def lookupMapOutputTrackerMaster(sc: SparkContext): MapOutputTrackerMaster = { + val bme = lookupBlockManagerMasterEndpoint(sc) + bme.getMapOutputTrackerMaster() + } + + test(s"Test that cache blocks are recorded.") { + val conf = new SparkConf() + .setAppName("test-blockmanager-decommissioner") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, 100L) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, 100L) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).cache() + input.count() + // Check that the blocks were registered with the TTL tracker + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + val trackedRDDBlocks = managerMasterEndpoint.rddAccessTime.asScala.keys + assert(!trackedRDDBlocks.isEmpty) + } + + test(s"Test that shuffle blocks are tracked properly and removed after TTL") { Review Comment: Inconsistent test naming. 
The test name includes `s"..."` string interpolation syntax but doesn't use any variables. The `s` prefix should be removed. ```suggestion test("Test that shuffle blocks are tracked properly and removed after TTL") { ``` ########## core/src/main/scala/org/apache/spark/storage/BlockId.scala: ########## @@ -59,7 +65,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData(). @DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId + with ShuffleId{ Review Comment: Missing space after trait name. Should be `with ShuffleId {` instead of `with ShuffleId{`. ```suggestion with ShuffleId { ``` ########## core/src/main/scala/org/apache/spark/MapOutputTracker.scala: ########## @@ -758,6 +773,50 @@ private[spark] class MapOutputTrackerMaster( private val availableProcessors = Runtime.getRuntime.availableProcessors() + def updateShuffleAtime(shuffleId: Int): Unit = { + if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + shuffleAccessTime.put(shuffleId, System.currentTimeMillis()) + } + } + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + // Poll the shuffle access times if we're configured for it. + conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + val toBeRemoved = shuffleAccessTime.asScala.flatMap { case (shuffleId, atime) => Review Comment: Potential race condition: The TTL cleaner thread reads from `shuffleAccessTime` and calls `unregisterAllMapAndMergeOutput()`, which also modifies `shuffleAccessTime`. Since `shuffleAccessTime` is a non-concurrent `JHashMap`, concurrent access from the cleaner thread and other threads updating access times could lead to `ConcurrentModificationException` or data corruption. Consider using a `ConcurrentHashMap` or adding proper synchronization. ########## core/src/main/scala/org/apache/spark/MapOutputTracker.scala: ########## @@ -758,6 +773,50 @@ private[spark] class MapOutputTrackerMaster( private val availableProcessors = Runtime.getRuntime.availableProcessors() + def updateShuffleAtime(shuffleId: Int): Unit = { + if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + shuffleAccessTime.put(shuffleId, System.currentTimeMillis()) + } + } + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + // Poll the shuffle access times if we're configured for it. 
+ conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + val toBeRemoved = shuffleAccessTime.asScala.flatMap { case (shuffleId, atime) => + if (atime < maxAge) { + Some(shuffleId) + } else { + if (atime < oldest) { + oldest = atime + } + None + } + }.toList + toBeRemoved.map { shuffleId => + try { + unregisterAllMapAndMergeOutput(shuffleId) + } catch { + case NonFatal(e) => + logError( + log"Error removing shuffle ${MDC(SHUFFLE_ID, shuffleId)} with TTL cleaner", e) + } + } + // Wait until the next possible element to be removed + val delay = math.max((oldest + ttl) - System.currentTimeMillis(), 1) + Thread.sleep(delay) + } + case None => + logDebug("Tried to start TTL cleaner when not configured.") + } + } + } Review Comment: The TTL cleaner thread continues running indefinitely with `while (true)` but there's no mechanism to stop it gracefully. When `stop()` is called, it only calls `cleanerThreadpool.map(_.shutdown())` which will not interrupt the thread sleeping in `Thread.sleep(delay)`. Consider using `shutdownNow()` instead of `shutdown()` or implementing a proper shutdown flag to allow graceful termination. ########## core/src/main/scala/org/apache/spark/internal/config/package.scala: ########## @@ -2904,4 +2904,24 @@ package object config { .checkValue(v => v.forall(Set("stdout", "stderr").contains), "The value only can be one or more of 'stdout, stderr'.") .createWithDefault(Seq("stdout", "stderr")) + + private[spark] val SPARK_TTL_BLOCK_CLEANER = + ConfigBuilder("spark.cleaner.ttl.all") + .doc("Add a TTL for all blocks tracked in Spark. By default blocks are only removed after " + + " GC on driver which with DataFrames or RDDs at the global scope will not occur. " + + "This must be configured before starting the SparkContext (e.g. can not be added to a" + + "a running Spark instance.)") + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + private[spark] val SPARK_TTL_SHUFFLE_BLOCK_CLEANER = + ConfigBuilder("spark.cleaner.ttl.shuffle") + .doc("Add a TTL for shuffle blocks tracked in Spark. By default blocks are only removed " + + "after GC on driver which with DataFrames or RDDs at the global scope will not occur." + + "This must be configured before starting the SparkContext (e.g. can not be added to a" + + "a running Spark instance.)") Review Comment: Documentation contains "a a" which should be "a". The phrase should be "can not be added to a running Spark instance." Also, there's a missing space after "occur." at the end of line 2921. ########## core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala: ########## @@ -249,6 +267,62 @@ class BlockManagerMasterEndpoint( context.reply(updateRDDBlockVisibility(taskId, visible)) } + private def updateBlockAtime(blockId: BlockId) = { + // First handle "regular" blocks + if (!blockId.isShuffle) { + // Only update access times if we have the cleaner enabled. + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + // Note: we don't _really_ care about concurrency here too much, if we have + // conflicting updates in time they're going to "close enough" to be a wash + // so we don't bother checking the return value here. + // For now we only do RDD blocks, because I'm not convinced it's safe to TTL + // clean Broadcast blocks, but maybe we can revisit that. 
+ blockId.asRDDId.map { r => rddAccessTime.put(r.rddId, System.currentTimeMillis()) } + } + } else if (conf.get(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + // We track shuffles in the mapoutput tracker. + blockId.asShuffleId.map(s => mapOutputTracker.updateShuffleAtime(s.shuffleId)) + } + } + + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + // Poll the shuffle access times if we're configured for it. + conf.get(config.SPARK_TTL_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + val toBeRemoved = rddAccessTime.asScala.flatMap { case (rddId, atime) => + if (atime < maxAge) { + Some(rddId) + } else { + if (atime < oldest) { + oldest = atime + } + None + } + }.toList + toBeRemoved.map { rddId => + try { + removeRdd(rddId) + } catch { + case NonFatal(e) => + logWarning(log"Error removing rdd ${MDC(RDD_ID, rddId)} with TTL cleaner", e) + } + } + // Wait until the next possible element to be removed + val delay = math.max((oldest + ttl) - System.currentTimeMillis(), 1) + Thread.sleep(delay) + } + case None => + logDebug("Tried to start TTL cleaner when not configured.") + } + } Review Comment: The TTL cleaner thread continues running indefinitely with `while (true)` but there's no mechanism to stop it gracefully. When `onStop()` is called, it only calls `cleanerThreadpool.map(_.shutdown())` which will not interrupt the thread sleeping in `Thread.sleep(delay)`. Consider using `shutdownNow()` instead of `shutdown()` or implementing a proper shutdown flag to allow graceful termination. ########## core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala: ########## @@ -345,7 +419,17 @@ class BlockManagerMasterEndpoint( } private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // Drop the RDD from TTL tracking. + try { + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + rddAccessTime.remove(rddId) + } + } catch { + case NonFatal(e) => + logWarning(log"Error removing ${MDC(RDD_ID, rddId)} from RDD TTL tracking", e) Review Comment: The `try-catch` block around `rddAccessTime.remove(rddId)` is overly broad. A simple `remove()` operation on a HashMap should not throw exceptions in normal circumstances. If this is meant to handle concurrent access issues, it indicates the underlying concurrency problem needs to be fixed rather than caught. Consider removing this defensive catch or documenting what specific failure scenario this is protecting against. ```suggestion if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { rddAccessTime.remove(rddId) ``` ########## core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala: ########## @@ -85,6 +85,11 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + // Keep track of last access times if we're using block TTLs + // We intentionally use a non-concurrent datastructure since "close" + // is good enough for atimes and reducing update cost matters. Review Comment: Misleading or incomplete documentation. The comment states "We intentionally use a non-concurrent datastructure since 'close' is good enough for atimes and reducing update cost matters." 
However, this approach is problematic when combined with the TTL cleaner thread that iterates over this map. The comment should explain how concurrent access is handled safely, or the implementation should be changed to use proper concurrency controls. ########## core/src/main/scala/org/apache/spark/MapOutputTracker.scala: ########## @@ -874,6 +935,14 @@ private[spark] class MapOutputTrackerMaster( /** Unregister all map and merge output information of the given shuffle. */ def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = { + try { + if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + shuffleAccessTime.remove(shuffleId) + } + } catch { + case NonFatal(e) => + logWarning(log"Error removing ${MDC(SHUFFLE_ID, shuffleId)} from Shuffle TTL tracking", e) Review Comment: The `try-catch` block around `shuffleAccessTime.remove(shuffleId)` is overly broad. A simple `remove()` operation on a HashMap should not throw exceptions in normal circumstances. If this is meant to handle concurrent access issues, it indicates the underlying concurrency problem needs to be fixed rather than caught. Consider removing this defensive catch or documenting what specific failure scenario this is protecting against. ```suggestion if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { shuffleAccessTime.remove(shuffleId) ``` ########## core/src/main/scala/org/apache/spark/MapOutputTracker.scala: ########## @@ -711,6 +712,10 @@ private[spark] class MapOutputTrackerMaster( private[spark] val isLocal: Boolean) extends MapOutputTracker(conf) { + // Keep track of last access times for shuffle based TTL. Note: we don't use concurrent + // here because we don't care about overwriting times that are "close." + private[spark] val shuffleAccessTime = new JHashMap[Int, Long] Review Comment: The comment "we don't use concurrent here because we don't care about overwriting times that are 'close.'" is misleading given that a TTL cleaner thread iterates over this map. While overwriting times might be acceptable, concurrent iteration and modification can lead to `ConcurrentModificationException`. The comment should clarify the concurrency safety or the implementation should use proper thread-safe structures. ```suggestion // Keep track of last access times for shuffle based TTL. We use a concurrent map here to // ensure thread safety, as a TTL cleaner thread may iterate over this map while other threads // modify it. Overwriting times that are "close" is acceptable, but concurrency safety is required. private[spark] val shuffleAccessTime = new ConcurrentHashMap[Int, Long]() ``` ########## core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" Review Comment: Unused variables: `numExecs`, `numParts`, `TaskStarted`, `TaskEnded`, and `JobEnded` are defined but never used in any of the tests. Consider removing these unused declarations. ```suggestion ``` ########## core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" + + // TODO(holden): This is shared with MapOutputTrackerSuite move to a BlockTestUtils or similar. + private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). 
+ head._1.asInstanceOf[BlockManagerMasterEndpoint] + } + + private def lookupMapOutputTrackerMaster(sc: SparkContext): MapOutputTrackerMaster = { + val bme = lookupBlockManagerMasterEndpoint(sc) + bme.getMapOutputTrackerMaster() + } + + test(s"Test that cache blocks are recorded.") { + val conf = new SparkConf() + .setAppName("test-blockmanager-decommissioner") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, 100L) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, 100L) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).cache() + input.count() + // Check that the blocks were registered with the TTL tracker + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + val trackedRDDBlocks = managerMasterEndpoint.rddAccessTime.asScala.keys + assert(!trackedRDDBlocks.isEmpty) + } + + test(s"Test that shuffle blocks are tracked properly and removed after TTL") { + val ttl = 250L + val conf = new SparkConf() + .setAppName("test-blockmanager-ttls-shuffle-only") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, ttl) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, ttl) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + // Make sure it's empty at the start + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + assert(mapOutputTracker.shuffleAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10) + input.count() + // Make sure we've got the tracker threads defined + assert(mapOutputTracker.cleanerThreadpool.isDefined) + // Check that the shuffle blocks were NOT registered with the RDD TTL tracker. + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Check that the shuffle blocks are registered with the map output TTL + eventually { assert(!mapOutputTracker.shuffleAccessTime.isEmpty) } + // It should be expired! + val t = System.currentTimeMillis() + eventually { assert( + mapOutputTracker.shuffleAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${ttl}") } + } + + + test(s"Test that all blocks are tracked properly and removed after TTL") { Review Comment: Inconsistent test naming. The test name includes `s"..."` string interpolation syntax but doesn't use any variables. The `s` prefix should be removed. ```suggestion test("Test that all blocks are tracked properly and removed after TTL") { ``` ########## core/src/main/scala/org/apache/spark/internal/config/package.scala: ########## @@ -2904,4 +2904,24 @@ package object config { .checkValue(v => v.forall(Set("stdout", "stderr").contains), "The value only can be one or more of 'stdout, stderr'.") .createWithDefault(Seq("stdout", "stderr")) + + private[spark] val SPARK_TTL_BLOCK_CLEANER = + ConfigBuilder("spark.cleaner.ttl.all") + .doc("Add a TTL for all blocks tracked in Spark. By default blocks are only removed after " + + " GC on driver which with DataFrames or RDDs at the global scope will not occur. " + + "This must be configured before starting the SparkContext (e.g. 
can not be added to a" + + "a running Spark instance.)") + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + + private[spark] val SPARK_TTL_SHUFFLE_BLOCK_CLEANER = + ConfigBuilder("spark.cleaner.ttl.shuffle") + .doc("Add a TTL for shuffle blocks tracked in Spark. By default blocks are only removed " + + "after GC on driver which with DataFrames or RDDs at the global scope will not occur." + + "This must be configured before starting the SparkContext (e.g. can not be added to a" + + "a running Spark instance.)") Review Comment: Documentation has spacing issue. There should be no space before "GC" - change " GC" to "GC". Also, the phrase "This must be configured before starting the SparkContext (e.g. can not be added to a a running Spark instance.)" contains "a a" which should be "a". ```suggestion .doc("Add a TTL for all blocks tracked in Spark. By default blocks are only removed after" + " GC on driver which with DataFrames or RDDs at the global scope will not occur." + "This must be configured before starting the SparkContext (e.g. can not be added to a running Spark instance.)") .version("4.1.0") .timeConf(TimeUnit.MILLISECONDS) .createOptional private[spark] val SPARK_TTL_SHUFFLE_BLOCK_CLEANER = ConfigBuilder("spark.cleaner.ttl.shuffle") .doc("Add a TTL for shuffle blocks tracked in Spark. By default blocks are only removed" + "after GC on driver which with DataFrames or RDDs at the global scope will not occur." + "This must be configured before starting the SparkContext (e.g. can not be added to a running Spark instance.)") ``` ########## core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" + + // TODO(holden): This is shared with MapOutputTrackerSuite move to a BlockTestUtils or similar. 
+ private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). + head._1.asInstanceOf[BlockManagerMasterEndpoint] + } + + private def lookupMapOutputTrackerMaster(sc: SparkContext): MapOutputTrackerMaster = { + val bme = lookupBlockManagerMasterEndpoint(sc) + bme.getMapOutputTrackerMaster() + } + + test(s"Test that cache blocks are recorded.") { + val conf = new SparkConf() + .setAppName("test-blockmanager-decommissioner") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, 100L) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, 100L) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).cache() + input.count() + // Check that the blocks were registered with the TTL tracker + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + val trackedRDDBlocks = managerMasterEndpoint.rddAccessTime.asScala.keys + assert(!trackedRDDBlocks.isEmpty) + } + + test(s"Test that shuffle blocks are tracked properly and removed after TTL") { + val ttl = 250L + val conf = new SparkConf() + .setAppName("test-blockmanager-ttls-shuffle-only") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, ttl) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, ttl) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + // Make sure it's empty at the start + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + assert(mapOutputTracker.shuffleAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10) + input.count() + // Make sure we've got the tracker threads defined + assert(mapOutputTracker.cleanerThreadpool.isDefined) + // Check that the shuffle blocks were NOT registered with the RDD TTL tracker. + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Check that the shuffle blocks are registered with the map output TTL + eventually { assert(!mapOutputTracker.shuffleAccessTime.isEmpty) } + // It should be expired! 
+ val t = System.currentTimeMillis() + eventually { assert( + mapOutputTracker.shuffleAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${ttl}") } + } + + + test(s"Test that all blocks are tracked properly and removed after TTL") { + val ttl = 250L + val conf = new SparkConf() + .setAppName("test-blockmanager-ttls-enabled") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, ttl) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, ttl) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10) + val cachedInput = input.cache() + cachedInput.count() + // Check that we have both shuffle & RDD blocks registered + eventually { assert(!managerMasterEndpoint.rddAccessTime.isEmpty) } + eventually { assert(!mapOutputTracker.shuffleAccessTime.isEmpty) } + // Both should be expired! + val t = System.currentTimeMillis() + eventually { + assert(mapOutputTracker.shuffleAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${ttl}") + assert(managerMasterEndpoint.rddAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${ttl}") + } + // And redoing the count should work and everything should come back. + input.count() + eventually { + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + assert(!mapOutputTracker.shuffleAccessTime.isEmpty) + } + } + + test(s"Test that blocks TTLS are not tracked when not enabled") { Review Comment: Inconsistent test naming. The test name includes `s"..."` string interpolation syntax but doesn't use any variables. The `s` prefix should be removed. ########## core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala: ########## @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" + + // TODO(holden): This is shared with MapOutputTrackerSuite move to a BlockTestUtils or similar. + private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). + head._1.asInstanceOf[BlockManagerMasterEndpoint] + } + + private def lookupMapOutputTrackerMaster(sc: SparkContext): MapOutputTrackerMaster = { + val bme = lookupBlockManagerMasterEndpoint(sc) + bme.getMapOutputTrackerMaster() + } + + test(s"Test that cache blocks are recorded.") { Review Comment: Inconsistent test naming. The test name includes `s"..."` string interpolation syntax but doesn't use any variables. The `s` prefix should be removed for all tests that don't use string interpolation: lines 63, 83, 115, and 151. ########## core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala: ########## @@ -249,6 +267,62 @@ class BlockManagerMasterEndpoint( context.reply(updateRDDBlockVisibility(taskId, visible)) } + private def updateBlockAtime(blockId: BlockId) = { + // First handle "regular" blocks + if (!blockId.isShuffle) { + // Only update access times if we have the cleaner enabled. + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + // Note: we don't _really_ care about concurrency here too much, if we have + // conflicting updates in time they're going to "close enough" to be a wash + // so we don't bother checking the return value here. + // For now we only do RDD blocks, because I'm not convinced it's safe to TTL + // clean Broadcast blocks, but maybe we can revisit that. + blockId.asRDDId.map { r => rddAccessTime.put(r.rddId, System.currentTimeMillis()) } + } + } else if (conf.get(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + // We track shuffles in the mapoutput tracker. + blockId.asShuffleId.map(s => mapOutputTracker.updateShuffleAtime(s.shuffleId)) + } + } + + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + // Poll the shuffle access times if we're configured for it. 
+ conf.get(config.SPARK_TTL_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + val toBeRemoved = rddAccessTime.asScala.flatMap { case (rddId, atime) => Review Comment: Potential race condition: The TTL cleaner thread reads from `rddAccessTime` and calls `removeRdd()`, which also modifies `rddAccessTime`. Since `rddAccessTime` is a non-concurrent `JHashMap`, concurrent access from the cleaner thread and other threads updating access times could lead to `ConcurrentModificationException` or data corruption. Consider using a `ConcurrentHashMap` or adding proper synchronization.
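A minimal sketch of the direction these race-condition comments point in, assuming the plain `JHashMap` access-time maps are swapped for `java.util.concurrent.ConcurrentHashMap`; the `AccessTimeTracker` name and its methods are illustrative, not the PR's actual fields or API:

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

// Illustrative stand-in for rddAccessTime / shuffleAccessTime: a concurrent map can be
// iterated by the cleaner thread while other threads keep recording access times, so
// there is no ConcurrentModificationException to defend against with try-catch.
class AccessTimeTracker {
  private val accessTime = new ConcurrentHashMap[Int, Long]()

  // Record an access; last-writer-wins is fine since "close" timestamps are equivalent
  // for TTL purposes, matching the intent of the existing comments.
  def touch(id: Int): Unit = accessTime.put(id, System.currentTimeMillis())

  // Drop an id once its blocks have been unregistered elsewhere.
  def forget(id: Int): Unit = accessTime.remove(id)

  // Collect ids whose last access is older than ttlMs and stop tracking them; the caller
  // would then unregister the blocks (removeRdd / unregisterAllMapAndMergeOutput).
  def expire(ttlMs: Long): Seq[Int] = {
    val cutoff = System.currentTimeMillis() - ttlMs
    val expired = accessTime.asScala.collect {
      case (id, atime) if atime < cutoff => id
    }.toSeq
    expired.foreach(id => accessTime.remove(id))
    expired
  }
}

object AccessTimeTrackerExample {
  def main(args: Array[String]): Unit = {
    val tracker = new AccessTimeTracker
    tracker.touch(1)
    Thread.sleep(50)
    tracker.touch(2)
    // Under typical timing only id 1 has aged past a 25 ms TTL.
    println(tracker.expire(25))
  }
}
```

Last-writer-wins on `put` keeps the "close enough" timestamp behaviour the PR comments describe, while `ConcurrentHashMap` iteration is weakly consistent, so the cleaner thread can scan the map safely without extra locking.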

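For the `while (true)` / `shutdown()` comments above, here is a sketch of one possible graceful-shutdown shape, again with illustrative names rather than the PR's actual classes: a volatile stop flag plus an interruptible sleep, so that `shutdownNow()` (which interrupts the worker) actually ends the loop instead of leaving the thread asleep.

```scala
import java.util.concurrent.{Executors, TimeUnit}

// Illustrative: a cleaner loop that can be stopped. The loop checks a volatile flag and
// treats interruption as a stop signal, so interrupting the worker ends Thread.sleep
// instead of leaving the thread running forever.
class StoppableCleaner(sweepIntervalMs: Long, sweep: () => Unit) {
  @volatile private var stopped = false
  private val pool = Executors.newSingleThreadExecutor()

  def start(): Unit = {
    pool.submit(new Runnable {
      override def run(): Unit = {
        try {
          while (!stopped) {
            sweep()                       // e.g. expire old entries, unregister their blocks
            Thread.sleep(sweepIntervalMs) // interrupt() from stop() lands here
          }
        } catch {
          case _: InterruptedException => // stop requested; exit the loop
        }
      }
    })
  }

  def stop(): Unit = {
    stopped = true
    pool.shutdownNow()                         // interrupts the sleeping worker thread
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}

object StoppableCleanerExample {
  def main(args: Array[String]): Unit = {
    val cleaner = new StoppableCleaner(100L, () => println("sweep"))
    cleaner.start()
    Thread.sleep(350)
    cleaner.stop() // returns promptly even though the worker is mid-sleep
  }
}
```

With this shape, `stop()` on the tracker or `onStop()` on the endpoint would call the cleaner's `stop()` rather than the bare `cleanerThreadpool.map(_.shutdown())` the comments flag.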