This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2215cef40043 [SPARK-46353][CORE] Refactor to improve `RegisterWorker` unit test coverage
2215cef40043 is described below

commit 2215cef40043a3205446f8daecafed8f2360a742
Author: Dongjoon Hyun <dh...@apple.com>
AuthorDate: Tue Dec 12 09:57:43 2023 -0800

    [SPARK-46353][CORE] Refactor to improve `RegisterWorker` unit test coverage
    
    ### What changes were proposed in this pull request?
    
    This PR aims to improve the unit test coverage for `RegisterWorker` message handling.
    
    - Add a `handleRegisterWorker` helper method that is easy to test.
    - Add new unit tests for three conditional branches.
    
    ### Why are the changes needed?
    
    It's easy to test and improve. We can add more tests in this way in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This is a refactoring of the main code, plus additions to the test methods.
    
    ### How was this patch tested?
    
    Pass the CIs.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44284 from dongjoon-hyun/SPARK-46353.
    
    Authored-by: Dongjoon Hyun <dh...@apple.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/deploy/master/Master.scala    | 75 +++++++++++++---------
 .../apache/spark/deploy/master/MasterSuite.scala   | 59 ++++++++++++++++-
 2 files changed, 102 insertions(+), 32 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index a550f44fc0a4..c8679c185ad7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -37,7 +37,7 @@ import org.apache.spark.internal.config.Deploy._
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.internal.config.Worker._
 import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances}
-import org.apache.spark.resource.{ResourceProfile, ResourceRequirement, ResourceUtils}
+import org.apache.spark.resource.{ResourceInformation, ResourceProfile, ResourceRequirement, ResourceUtils}
 import org.apache.spark.rpc._
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
 import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils}
@@ -75,7 +75,8 @@ private[deploy] class Master(
   private val waitingApps = new ArrayBuffer[ApplicationInfo]
   val apps = new HashSet[ApplicationInfo]
 
-  private val idToWorker = new HashMap[String, WorkerInfo]
+  // Visible for testing
+  private[master] val idToWorker = new HashMap[String, WorkerInfo]
   private val addressToWorker = new HashMap[RpcAddress, WorkerInfo]
 
   private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo]
@@ -106,7 +107,7 @@ private[deploy] class Master(
 
   private[master] var state = RecoveryState.STANDBY
 
-  private var persistenceEngine: PersistenceEngine = _
+  private[master] var persistenceEngine: PersistenceEngine = _
 
   private var leaderElectionAgent: LeaderElectionAgent = _
 
@@ -281,33 +282,8 @@ private[deploy] class Master(
     case RegisterWorker(
       id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl,
       masterAddress, resources) =>
-      logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
-        workerHost, workerPort, cores, Utils.megabytesToString(memory)))
-      if (state == RecoveryState.STANDBY) {
-        workerRef.send(MasterInStandby)
-      } else if (idToWorker.contains(id)) {
-        if (idToWorker(id).state == WorkerState.UNKNOWN) {
-          logInfo("Worker has been re-registered: " + id)
-          idToWorker(id).state = WorkerState.ALIVE
-        }
-        workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, true))
-      } else {
-        val workerResources =
-          resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses.toImmutableArraySeq))
-        val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
-          workerRef, workerWebUiUrl, workerResources)
-        if (registerWorker(worker)) {
-          persistenceEngine.addWorker(worker)
-          workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, false))
-          schedule()
-        } else {
-          val workerAddress = worker.endpoint.address
-          logWarning("Worker registration failed. Attempted to re-register worker at same " +
-            "address: " + workerAddress)
-          workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
-            + workerAddress))
-        }
-      }
+      handleRegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl,
+        masterAddress, resources)
 
     case RegisterApplication(description, driver) =>
       // TODO Prevent repeated registrations from some driver
@@ -676,6 +652,45 @@ private[deploy] class Master(
     logInfo(f"Recovery complete in ${timeTakenNs / 1000000000d}%.3fs - resuming operations!")
   }
 
+  private[master] def handleRegisterWorker(
+      id: String,
+      workerHost: String,
+      workerPort: Int,
+      workerRef: RpcEndpointRef,
+      cores: Int,
+      memory: Int,
+      workerWebUiUrl: String,
+      masterAddress: RpcAddress,
+      resources: Map[String, ResourceInformation]): Unit = {
+    logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
+      workerHost, workerPort, cores, Utils.megabytesToString(memory)))
+    if (state == RecoveryState.STANDBY) {
+      workerRef.send(MasterInStandby)
+    } else if (idToWorker.contains(id)) {
+      if (idToWorker(id).state == WorkerState.UNKNOWN) {
+        logInfo("Worker has been re-registered: " + id)
+        idToWorker(id).state = WorkerState.ALIVE
+      }
+      workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, true))
+    } else {
+      val workerResources =
+        resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses.toImmutableArraySeq))
+      val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
+        workerRef, workerWebUiUrl, workerResources)
+      if (registerWorker(worker)) {
+        persistenceEngine.addWorker(worker)
+        workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, false))
+        schedule()
+      } else {
+        val workerAddress = worker.endpoint.address
+        logWarning("Worker registration failed. Attempted to re-register worker at same " +
+          "address: " + workerAddress)
+        workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: "
+          + workerAddress))
+      }
+    }
+  }
+
   /**
    * Schedule executors to be launched on the workers.
    * Returns an array containing number of cores assigned to each worker.
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 9fd1991dab02..e15a5db770eb 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -30,12 +30,13 @@ import scala.reflect.ClassTag
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{doNothing, mock, when}
+import org.mockito.ArgumentMatchers.{any, eq => meq}
+import org.mockito.Mockito.{doNothing, mock, times, verify, when}
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually
 import org.scalatest.matchers.must.Matchers
 import org.scalatest.matchers.should.Matchers._
+import org.scalatestplus.mockito.MockitoSugar.{mock => smock}
 import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
@@ -1373,6 +1374,60 @@ class MasterSuite extends SparkFunSuite
         eventLogCodec = None)
     assert(master.invokePrivate(_createApplication(desc, null)).id === "spark-45756")
   }
+
+  test("SPARK-46353: handleRegisterWorker in STANDBY mode") {
+    val master = makeMaster()
+    val masterRpcAddress = smock[RpcAddress]
+    val worker = smock[RpcEndpointRef]
+
+    assert(master.state === RecoveryState.STANDBY)
+    master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096,
+      "http://localhost:8081", masterRpcAddress, Map.empty)
+    verify(worker, times(1)).send(meq(MasterInStandby))
+    verify(worker, times(0))
+      .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = true)))
+    verify(worker, times(0))
+      .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = false)))
+    assert(master.workers.isEmpty)
+    assert(master.idToWorker.isEmpty)
+  }
+
+  test("SPARK-46353: handleRegisterWorker in RECOVERING mode without workers") {
+    val master = makeMaster()
+    val masterRpcAddress = smock[RpcAddress]
+    val worker = smock[RpcEndpointRef]
+
+    master.state = RecoveryState.RECOVERING
+    master.persistenceEngine = new BlackHolePersistenceEngine()
+    master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096,
+      "http://localhost:8081", masterRpcAddress, Map.empty)
+    verify(worker, times(0)).send(meq(MasterInStandby))
+    verify(worker, times(1))
+      .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = false)))
+    assert(master.workers.size === 1)
+    assert(master.idToWorker.size === 1)
+  }
+
+  test("SPARK-46353: handleRegisterWorker in RECOVERING mode with a unknown worker") {
+    val master = makeMaster()
+    val masterRpcAddress = smock[RpcAddress]
+    val worker = smock[RpcEndpointRef]
+    val workerInfo = smock[WorkerInfo]
+    when(workerInfo.state).thenReturn(WorkerState.UNKNOWN)
+
+    master.state = RecoveryState.RECOVERING
+    master.workers.add(workerInfo)
+    master.idToWorker("worker-0") = workerInfo
+    master.persistenceEngine = new BlackHolePersistenceEngine()
+    master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096,
+      "http://localhost:8081", masterRpcAddress, Map.empty)
+    verify(worker, times(0)).send(meq(MasterInStandby))
+    verify(worker, times(1))
+      .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = true)))
+    assert(master.state === RecoveryState.RECOVERING)
+    assert(master.workers.nonEmpty)
+    assert(master.idToWorker.nonEmpty)
+  }
 }
 
 private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to