Keep a separate lock for each node. The lock limits how many connections
to the same node can be set up at the same time. This avoids races
between parallel connection attempts and keeps their number below the
limit set by the MaxStartups option of the node's sshd.

Signed-off-by: Juraj Linkeš <juraj.lin...@pantheon.tech>
---
 dts/framework/ssh_pexpect.py | 15 +++++--
 dts/framework/utils.py       | 81 ++++++++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/dts/framework/ssh_pexpect.py b/dts/framework/ssh_pexpect.py
index bccc6fae94..cbdbb91b64 100644
--- a/dts/framework/ssh_pexpect.py
+++ b/dts/framework/ssh_pexpect.py
@@ -3,7 +3,7 @@
 from pexpect import pxssh
 
 from .exception import SSHConnectionException, SSHSessionDeadException, TimeoutException
-from .utils import GREEN, RED
+from .utils import GREEN, RED, parallel_lock
 
 """
 Module handles ssh sessions to TG and SUT.
@@ -12,7 +12,7 @@
 
 
 class SSHPexpect:
-    def __init__(self, node, username, password):
+    def __init__(self, node, username, password, sut_id=0):
         self.magic_prompt = "MAGIC PROMPT"
         self.logger = None
 
@@ -20,11 +20,18 @@ def __init__(self, node, username, password):
         self.username = username
         self.password = password
 
-        self._connect_host()
+        self._connect_host(sut_id=sut_id)
 
-    def _connect_host(self):
+    @parallel_lock(num=8)
+    def _connect_host(self, sut_id=0):
         """
         Create connection to assigned node.
+        sut_id must be passed as a keyword argument: parallel_lock uses it
+        to keep a separate lock for each node.
+        Parallel ssh connections are limited by the MaxStartups option in the
+        sshd configuration file. By default sshd may refuse new connections
+        once 10 are unauthenticated, so the lock allows at most 8 at a time.
+        The lock number can be changed along with the MaxStartups value.
         """
         retry_times = 10
         try:
diff --git a/dts/framework/utils.py b/dts/framework/utils.py
index 0ffd992952..a8e739f7b2 100644
--- a/dts/framework/utils.py
+++ b/dts/framework/utils.py
@@ -2,6 +2,87 @@
 # Copyright(c) 2010-2014 Intel Corporation
 #
 
+import threading
+from functools import wraps
+
+
+def parallel_lock(num=1):
+    """
+    Decorator limiting how many threads may run a function at the same time.
+
+    The limit is kept separately for each node: calls are grouped by the
+    'sut_id' keyword argument of the decorated function (0 when it is not
+    passed as a keyword argument) and at most `num` calls from the same
+    group run concurrently. Calls for different nodes never wait for each
+    other and every decorated function has its own set of limits. A message
+    is printed when a call has to wait for a free slot.
+
+    The limit is implemented with a semaphore, which is not reentrant:
+    a decorated function must not (directly or indirectly) call itself
+    with the same sut_id, otherwise it may end up waiting for itself.
+
+    Example:
+        @parallel_lock(num=8)
+        def _connect_host(self, sut_id=0):
+            ...
+
+        self._connect_host(sut_id=sut_id)
+
+    Parameter:
+        num: Maximum number of concurrent calls per node, at least 1
+    Returns:
+        The decorator to apply to the function.
+    Raises:
+        ValueError: if num is less than 1
+    """
+    if num < 1:
+        raise ValueError("num must be at least 1, got %r" % (num,))
+
+    def decorate(func):
+        # sut_id -> semaphore allowing at most `num` concurrent calls
+        semaphores = {}
+        # protects `semaphores` so that two threads can't create two
+        # different semaphores for the same node
+        semaphores_guard = threading.Lock()
+
+        # return the semaphore for node sut_id, creating it on first use
+        def get_semaphore(sut_id):
+            with semaphores_guard:
+                semaphore = semaphores.get(sut_id)
+                if semaphore is None:
+                    semaphore = threading.BoundedSemaphore(num)
+                    semaphores[sut_id] = semaphore
+                return semaphore
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            # sut_id has to be passed as a keyword argument to select the
+            # node; a positional or missing sut_id falls back to node 0
+            sut_id = kwargs.get("sut_id", 0)
+            semaphore = get_semaphore(sut_id)
+
+            # try without blocking first so that waiting can be reported
+            if not semaphore.acquire(blocking=False):
+                thread_name = threading.current_thread().name
+                print(
+                    RED(
+                        "SUT%s %s waiting for func lock %s"
+                        % (sut_id, thread_name, func.__name__)
+                    )
+                )
+                semaphore.acquire()
+
+            # release in all cases so that an exception raised by func
+            # doesn't permanently use up one of the `num` slots
+            try:
+                return func(*args, **kwargs)
+            finally:
+                semaphore.release()
+
+        return wrapper
+
+    return decorate
+
 
 def RED(text):
     return "\x1B[" + "31;1m" + str(text) + "\x1B[" + "0m"
-- 
2.20.1

Reply via email to