This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 23284cd7c0bbe51cb43a759a9df0234a77a96689
Author: baunsgaard <baunsga...@tugraz.at>
AuthorDate: Mon Nov 15 19:04:54 2021 +0100

    [SYSTEMDS-3220] Stabilize python context creation
    
    This commit adds multiple checks and retries to the Python context
    creation to make startup stable under various stress conditions,
    such as multiple contexts being started and kept open, the same port
    being reused, and parallel thread allocation.
    
    Also included is a small fix for Python Lists that enables reading
    and writing them with type support, instead of having to cast from
    the OperationNode type.
---
 src/main/java/org/apache/sysds/api/DMLOptions.java |  13 +-
 src/main/java/org/apache/sysds/api/DMLScript.java  |   2 +-
 .../java/org/apache/sysds/api/PythonDMLScript.java | 126 +++++----
 .../python/systemds/context/systemds_context.py    | 281 ++++++++++++---------
 .../sysds/test/usertest/pythonapi/StartupTest.java |  29 ++-
 5 files changed, 254 insertions(+), 197 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java 
b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 5e10d73..1b911bc 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -71,6 +71,7 @@ public class DMLOptions {
        public boolean              lineage_debugger = false;         // 
whether enable lineage debugger
        public boolean              fedWorker     = false;
        public int                  fedWorkerPort = -1;
+       public int                  pythonPort    = -1; 
        public boolean              checkPrivacy  = false;            // Check 
which privacy constraints are loaded and checked during federated execution 
        public boolean                          federatedCompilation = false;   
  // Compile federated instructions based on input federation state and privacy 
constraints.
 
@@ -242,6 +243,10 @@ public class DMLOptions {
                        }
                }
 
+               if (line.hasOption("python")){
+                       dmlOptions.pythonPort = 
Integer.parseInt(line.getOptionValue("python"));
+               }
+
                // Named arguments map is created as ("$K, 123), ("$X", 
"X.csv"), etc
                if (line.hasOption("nvargs")){
                        String varNameRegex = "^[a-zA-Z]([a-zA-Z0-9_])*$";
@@ -302,8 +307,8 @@ public class DMLOptions {
                        .hasOptionalArg().create("gpu");
                Option debugOpt = OptionBuilder.withDescription("runs in debug 
mode; default off")
                        .create("debug");
-               Option pythonOpt = OptionBuilder.withDescription("parses 
Python-like DML")
-                       .create("python");
+               Option pythonOpt = OptionBuilder.withDescription("Python 
Context start with port argument for communication to python")
+                       .isRequired().hasArg().create("python");
                Option fileOpt = OptionBuilder.withArgName("filename")
                        .withDescription("specifies dml/pydml file to execute; 
path can be local/hdfs/gpfs (prefixed with appropriate URI)")
                        .isRequired().hasArg().create("f");
@@ -332,7 +337,6 @@ public class DMLOptions {
                options.addOption(execOpt);
                options.addOption(gpuOpt);
                options.addOption(debugOpt);
-               options.addOption(pythonOpt);
                options.addOption(lineageOpt);
                options.addOption(fedOpt);
                options.addOption(checkPrivacy);
@@ -344,7 +348,8 @@ public class DMLOptions {
                        .addOption(fileOpt)
                        .addOption(cleanOpt)
                        .addOption(helpOpt)
-                       .addOption(fedOpt);
+                       .addOption(fedOpt)
+                       .addOption(pythonOpt);
                fileOrScriptOpt.setRequired(true);
                options.addOptionGroup(fileOrScriptOpt);
                
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 658b993..ebb7f3d 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -356,7 +356,7 @@ public class DMLScript
        // (core compilation and execute)
        ////////
 
-       private static void loadConfiguration(String fnameOptConfig) throws 
IOException {
+       public static void loadConfiguration(String fnameOptConfig) throws 
IOException {
                DMLConfig dmlconf = 
DMLConfig.readConfigurationFile(fnameOptConfig);
                ConfigurationManager.setGlobalConfig(dmlconf);
                CompilerConfig cconf = 
OptimizerUtils.constructCompilerConfig(dmlconf);
diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java 
b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
index e7251e7..d93409d 100644
--- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
+++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
@@ -26,40 +26,29 @@ import org.apache.sysds.conf.CompilerConfig;
 
 import py4j.GatewayServer;
 import py4j.GatewayServerListener;
+import py4j.Py4JNetworkException;
 import py4j.Py4JServerConnection;
 
 public class PythonDMLScript {
-       private static final Log LOG = 
LogFactory.getLog(PythonDMLScript.class.getName());
+
        private Connection _connection;
 
        /**
         * Entry point for Python API.
         * 
-        * The system returns with exit code 1, if the startup process fails, 
and 0 if the startup was successful.
-        * 
         * @param args Command line arguments.
+        * @throws Exception Throws exceptions if there is issues in startup or 
while running.
         */
-       public static void main(String[] args) {
-               if(args.length != 1) {
-                       throw new IllegalArgumentException("Python DML Script 
should be initialized with a singe number argument");
-               }
-               else {
-                       int port = Integer.parseInt(args[0]);
-                       start(port);
-               }
+       public static void main(String[] args) throws Exception {
+               final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
+               DMLScript.loadConfiguration(dmlOptions.configFile);
+               start(dmlOptions.pythonPort);
        }
 
-       private static void start(int port) {
-               try {
-                       // TODO Add argument parsing here.
-                       GatewayServer GwS = new GatewayServer(new 
PythonDMLScript(), port);
-                       GwS.addListener(new DMLGateWayListener());
-                       GwS.start();
-               }
-               catch(py4j.Py4JNetworkException ex) {
-                       LOG.error("Py4JNetworkException while executing the 
GateWay. Is a server instance already running?");
-                       System.exit(-1);
-               }
+       private static void start(int port) throws Py4JNetworkException {
+               GatewayServer GwS = new GatewayServer(new PythonDMLScript(), 
port);
+               GwS.addListener(new DMLGateWayListener());
+               GwS.start();
        }
 
        private PythonDMLScript() {
@@ -79,50 +68,53 @@ public class PythonDMLScript {
        public Connection getConnection() {
                return _connection;
        }
-}
-
-class DMLGateWayListener implements GatewayServerListener {
-       private static final Log LOG = 
LogFactory.getLog(DMLGateWayListener.class.getName());
-
-       @Override
-       public void connectionError(Exception e) {
-               LOG.warn("Connection error: " + e.getMessage());
-       }
-
-       @Override
-       public void connectionStarted(Py4JServerConnection gatewayConnection) {
-               LOG.debug("Connection Started: " + 
gatewayConnection.toString());
-       }
-
-       @Override
-       public void connectionStopped(Py4JServerConnection gatewayConnection) {
-               LOG.debug("Connection stopped: " + 
gatewayConnection.toString());
-       }
-
-       @Override
-       public void serverError(Exception e) {
-               LOG.error("Server Error " + e.getMessage());
-       }
-
-       @Override
-       public void serverPostShutdown() {
-               LOG.info("Shutdown done");
-               System.exit(0);
-       }
-
-       @Override
-       public void serverPreShutdown() {
-               LOG.info("Starting JVM shutdown");
-       }
-
-       @Override
-       public void serverStarted() {
-               // message the python interface that the JVM is ready.
-               System.out.println("GatewayServer Started");
-       }
-
-       @Override
-       public void serverStopped() {
-               System.out.println("GatewayServer Stopped");
+       
+       protected static class DMLGateWayListener implements 
GatewayServerListener {
+               private static final Log LOG = 
LogFactory.getLog(DMLGateWayListener.class.getName());
+       
+               @Override
+               public void connectionError(Exception e) {
+                       LOG.warn("Connection error: " + e.getMessage());
+                       System.exit(1);
+               }
+       
+               @Override
+               public void connectionStarted(Py4JServerConnection 
gatewayConnection) {
+                       LOG.debug("Connection Started: " + 
gatewayConnection.toString());
+               }
+       
+               @Override
+               public void connectionStopped(Py4JServerConnection 
gatewayConnection) {
+                       LOG.debug("Connection stopped: " + 
gatewayConnection.toString());
+               }
+       
+               @Override
+               public void serverError(Exception e) {
+                       LOG.error("Server Error " + e.getMessage());
+               }
+       
+               @Override
+               public void serverPostShutdown() {
+                       LOG.info("Shutdown done");
+                       System.exit(0);
+               }
+       
+               @Override
+               public void serverPreShutdown() {
+                       LOG.info("Starting JVM shutdown");
+               }
+       
+               @Override
+               public void serverStarted() {
+                       // message the python interface that the JVM is ready.
+                       System.out.println("GatewayServer Started");
+               }
+       
+               @Override
+               public void serverStopped() {
+                       System.out.println("GatewayServer Stopped");
+                       System.exit(0);
+               }
        }
 }
+
diff --git a/src/main/python/systemds/context/systemds_context.py 
b/src/main/python/systemds/context/systemds_context.py
index afa38c2..14f7260 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -35,7 +35,6 @@ from typing import Dict, Iterable, Sequence, Tuple, Union
 import numpy as np
 import pandas as pd
 from py4j.java_gateway import GatewayParameters, JavaGateway
-from py4j.protocol import Py4JNetworkError
 from systemds.operator import (Frame, List, Matrix, OperationNode, Scalar,
                                Source)
 from systemds.script_building import OutputType
@@ -60,26 +59,13 @@ class SystemDSContext(object):
         Standard out and standard error form the JVM is also handled in this 
class, filling up Queues,
         that can be read from to get the printed statements from the JVM.
         """
-        command = self.__build_startup_command()
-        process, port = self.__try_startup(command, port)
-
-        # Handle Std out from the subprocess.
-        self.__stdout = Queue()
-        self.__stderr = Queue()
-
-        self.__stdout_thread = Thread(target=self.__enqueue_output, args=(
-            process.stdout, self.__stdout), daemon=True)
-
-        self.__stderr_thread = Thread(target=self.__enqueue_output, args=(
-            process.stderr, self.__stderr), daemon=True)
-
-        self.__stdout_thread.start()
-        self.__stderr_thread.start()
-
-        # Py4j connect to the started process.
-        gwp = GatewayParameters(port=port, eager_load=True)
-        self.java_gateway = JavaGateway(
-            gateway_parameters=gwp, java_process=process)
+        actual_port = self.__start(port)
+        process = self.__process
+        if process.poll() is None:
+            self.__start_gateway(actual_port)
+        else:
+            self.exception_and_close(
+                "Java process stopped before gateway could connect")
 
     def get_stdout(self, lines: int = -1):
         """Getter for the stdout of the java subprocess
@@ -103,14 +89,13 @@ class SystemDSContext(object):
         else:
             return [self.__stderr.get() for x in range(lines)]
 
-    def exception_and_close(self, exception_str: str, trace_back_limit : int = 
None):
+    def exception_and_close(self, exception_str: str, trace_back_limit: int = 
None):
         """
         Method for printing exception, printing stdout and error, while also 
closing the context correctly.
 
         :param e: the exception thrown
         """
 
-        # e = sys.exc_info()[0]
         message = ""
         stdOut = self.get_stdout()
         if stdOut:
@@ -118,101 +103,163 @@ class SystemDSContext(object):
         stdErr = self.get_stderr()
         if stdErr:
             message += "standard error  :\n" + "\n".join(stdErr)
+        message += "\n\n"
         message += exception_str
         sys.tracebacklimit = trace_back_limit
         self.close()
         raise RuntimeError(message)
 
-    def __try_startup(self, command, port, rep=0):
-        """ Try to perform startup of system.
+    def __try_startup(self, command) -> bool:
 
-        :param command: The command to execute for starting JMLC content
-        :param port: The port to try to connect to to.
-        :param rep: The number of repeated tries to startup the jvm.
-        """
-        if port == -1:
-            assignedPort = self.__get_open_port()
-        elif rep == 0:
-            assignedPort = port
-        else:
-            assignedPort = self.__get_open_port()
-        fullCommand = []
-        fullCommand.extend(command)
-        fullCommand.append(str(assignedPort))
-        process = Popen(fullCommand, stdout=PIPE, stdin=PIPE, stderr=PIPE)
+        self.__process = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE)
 
-        try:
-            self.__verify_startup(process)
-
-            return process, assignedPort
-        except Exception as e:
-            self.close()
-            if rep > 3:
-                raise Exception(
-                    "Failed to start SystemDS context with " + str(rep) + " 
repeated tries")
-            else:
-                rep += 1
-                print("Failed to startup JVM process, retrying: " + str(rep))
-                sleep(0.5)
-                return self.__try_startup(command, port, rep)
-
-    def __verify_startup(self, process):
-        first_stdout = process.stdout.readline()
-        if(not b"GatewayServer Started" in first_stdout):
-            stderr = process.stderr.readline().decode("utf-8")
-            if(len(stderr) > 1):
-                raise Exception(
-                    "Exception in startup of GatewayServer: " + stderr)
-            outputs = []
-            outputs.append(first_stdout.decode("utf-8"))
-            max_tries = 10
-            for i in range(max_tries):
-                next_line = process.stdout.readline()
-                if(b"GatewayServer Started" in next_line):
-                    print("WARNING: Stdout corrupted by prints: " + 
str(outputs))
-                    print("Startup success")
-                    break
-                else:
-                    outputs.append(next_line)
-
-                if (i == max_tries-1):
-                    raise Exception("Error in startup of systemDS gateway 
process: \n gateway StdOut: " + str(
-                        outputs) + " \n gateway StdErr" + 
process.stderr.readline().decode("utf-8"))
-
-    def __build_startup_command(self):
+        # Handle Std out from the subprocess.
+        self.__stdout = Queue()
+        self.__stderr = Queue()
+
+        self.__stdout_thread = Thread(target=self.__enqueue_output, args=(
+            self.__process.stdout, self.__stdout), daemon=True)
+
+        self.__stderr_thread = Thread(target=self.__enqueue_output, args=(
+            self.__process.stderr, self.__stderr), daemon=True)
+
+        self.__stdout_thread.start()
+        self.__stderr_thread.start()
+
+        return self.__verify_startup(command)
+
+    def __verify_startup(self, command) -> bool:
+        first_stdout = self.get_stdout()
+        if(not "GatewayServer Started" in first_stdout):
+            return self.__verify_startup_retry(command)
+        else:
+            return True
+
+    def __verify_startup_retry(self, command,  retry: int = 1) -> bool:
+        sleep(0.8 * retry)
+        stdout = self.get_stdout()
+        if "GatewayServer Started" in stdout:
+            return True, ""
+        elif retry < 3:  # retry 3 times
+            return self.__verify_startup_retry(command, retry + 1)
+        else:
+            error_message = "Error in startup of systemDS gateway process:"
+            error_message += "\n" + " ".join(command)
+            stderr = self.get_stderr()
+            if len(stderr) > 0:
+                error_message += "\n" + "\n".join(stderr)
+            if len(stdout) > 0:
+                error_message += "\n\n" + "\n".join(stdout)
+            self.__error_message = error_message
+            return False
+
+    def __build_startup_command(self, port: int):
 
         command = ["java", "-cp"]
         root = os.environ.get("SYSTEMDS_ROOT")
         if root == None:
             # If there is no systemds install default to use the PIP packaged 
java files.
-            root = os.path.join(get_module_dir(), "systemds-java")
+            root = os.path.join(get_module_dir())
 
         # nt means its Windows
         cp_separator = ";" if os.name == "nt" else ":"
 
         if os.environ.get("SYSTEMDS_ROOT") != None:
-            lib_cp = os.path.join(root, "target", "lib", "*")
-            systemds_cp = os.path.join(root, "target", "SystemDS.jar")
-            classpath = cp_separator.join([lib_cp, systemds_cp])
-
-            command.append(classpath)
-            files = glob(os.path.join(root, "conf", "log4j*.properties"))
-            if len(files) > 1:
-                print(
-                    "WARNING: Multiple logging files found selecting: " + 
files[0])
-            if len(files) == 0:
-                print("WARNING: No log4j file found at: "
-                      + os.path.join(root, "conf")
-                      + " therefore using default settings")
+            lib_release = os.path.join(root, "lib")
+            lib_cp = os.path.join(root, "target", "lib")
+            if os.path.exists(lib_release):
+                classpath = cp_separator.join([os.path.join(lib_release, '*')])
+            elif os.path.exists(lib_cp):
+                systemds_cp = os.path.join(root, "target", "SystemDS.jar")
+                classpath = cp_separator.join(
+                    [os.path.join(lib_cp, '*'), systemds_cp])
             else:
-                command.append("-Dlog4j.configuration=file:" + files[0])
+                raise ValueError(
+                    "Invalid setup at SYSTEMDS_ROOT env variable path")
+        else:
+            lib1 = os.path.join(root, "lib", "*")
+            lib2 = os.path.join(root, "lib")
+            classpath = cp_separator.join([lib1, lib2])
+
+        command.append(classpath)
+
+        files = glob(os.path.join(root, "conf", "log4j*.properties"))
+        if len(files) > 1:
+            print(
+                "WARNING: Multiple logging files found selecting: " + files[0])
+        if len(files) == 0:
+            print("WARNING: No log4j file found at: "
+                  + os.path.join(root, "conf")
+                  + " therefore using default settings")
         else:
-            lib_cp = os.path.join(root, "lib", "*")
-            command.append(lib_cp)
+            command.append("-Dlog4j.configuration=file:" + files[0])
 
         command.append("org.apache.sysds.api.PythonDMLScript")
 
-        return command
+        files = glob(os.path.join(root, "conf", "SystemDS*.xml"))
+        if len(files) > 1:
+            print(
+                "WARNING: Multiple config files found selecting: " + files[0])
+        if len(files) == 0:
+            print("WARNING: No log4j file found at: "
+                  + os.path.join(root, "conf")
+                  + " therefore using default settings")
+        else:
+            command.append("-config")
+            command.append(files[0])
+
+        if port == -1:
+            actual_port = self.__get_open_port()
+        else:
+            actual_port = port
+
+        command.append("--python")
+        command.append(str(actual_port))
+
+        return command, actual_port
+
+    def __start(self, port: int):
+        command, actual_port = self.__build_startup_command(port)
+        success = self.__try_startup(command)
+
+        if not success:
+            retry = 1
+            while not success and retry < 3:
+                self.__kill_Popen(self.__process)
+                # retry after waiting a bit.
+                sleep(3 * retry)
+                self.close()
+                self.__error_message = None
+                success, command, actual_port = self.__retry_start(retry)
+                retry = retry + 1
+            if not success:
+                self.exception_and_close(self.__error_message)
+        return actual_port
+
+    def __retry_start(self, ret):
+        command, actual_port = self.__build_startup_command(-1)
+        success = self.__try_startup(command)
+        return success, command, actual_port
+
+    def __start_gateway(self, actual_port: int):
+        process = self.__process
+        gwp = GatewayParameters(port=actual_port, eager_load=True)
+        self.__retry_start_gateway(process, gwp)
+
+    def __retry_start_gateway(self, process: Popen, gwp: GatewayParameters, 
retry: int = 0):
+        try:
+            self.java_gateway = JavaGateway(
+                gateway_parameters=gwp, java_process=process)
+            self.__process = None  # On success clear process variable
+            return
+        except:
+            sleep(3 * retry)
+            if retry < 3:
+                self.__retry_start_gateway(process, gwp, retry + 1)
+                return
+            else:
+                e = "Error in startup of Java Gateway"
+        self.exception_and_close(e)
 
     def __enter__(self):
         return self
@@ -224,26 +271,28 @@ class SystemDSContext(object):
 
     def close(self):
         """Close the connection to the java process and do necessary 
cleanup."""
-        if(self.__stdout_thread.is_alive()):
+        if hasattr(self, 'java_gateway'):
+            self.__kill_Popen(self.java_gateway.java_process)
+            self.java_gateway.shutdown()
+        if hasattr(self, '__process'):
+            print("Has process variable")
+            self.__kill_Popen(self.__process)
+        if hasattr(self, '__stdout_thread') and 
self.__stdout_thread.is_alive():
             self.__stdout_thread.join(0)
-        if(self.__stdout_thread.is_alive()):
+        if hasattr(self, '__stderr_thread') and 
self.__stderr_thread.is_alive():
             self.__stderr_thread.join(0)
 
-        pid = self.java_gateway.java_process.pid
-        if self.java_gateway.java_gateway_server is not None:
-            try:
-                self.java_gateway.shutdown(True)
-            except Py4JNetworkError as e:
-                if "Gateway is not connected" not in str(e):
-                    self.java_gateway.java_process.kill()
-        os.kill(pid, 14)
+    def __kill_Popen(self, process: Popen):
+        process.kill()
+        process.__exit__(None, None, None)
 
     def __enqueue_output(self, out, queue):
         """Method for handling the output from java.
         It is locating the string handeling inside a different thread, since 
the 'out.readline' is a blocking command.
         """
         for line in iter(out.readline, b""):
-            queue.put(line.decode("utf-8").strip())
+            line_string = line.decode("utf-8")
+            queue.put(line_string.strip())
 
     def __get_open_port(self):
         """Get a random available port.
@@ -291,7 +340,7 @@ class SystemDSContext(object):
     def rand(self, rows: int, cols: int,
              min: Union[float, int] = None, max: Union[float, int] = None, 
pdf: str = "uniform",
              sparsity: Union[float, int] = None, seed: Union[float, int] = 
None,
-             lambd: Union[float, int] = 1) -> 'Matrix':
+             lamb: Union[float, int] = 1) -> 'Matrix':
         """Generates a matrix filled with random values
 
         :param sds_context: SystemDS context
@@ -299,26 +348,26 @@ class SystemDSContext(object):
         :param cols: number of cols
         :param min: min value for cells
         :param max: max value for cells
-        :param pdf: "uniform"/"normal"/"poison" distribution
+        :param pdf: probability distribution function: 
"uniform"/"normal"/"poison" distribution
         :param sparsity: fraction of non-zero cells
         :param seed: random seed
-        :param lambd: lamda value for "poison" distribution
+        :param lamb: lambda value for "poison" distribution
         :return:
         """
-        available_pdfs = ["uniform", "normal", "poisson"]
+        available_pdf = ["uniform", "normal", "poisson"]
         if rows < 0:
             raise ValueError("In rand statement, can only assign rows a long 
(integer) value >= 0 "
                              "-- attempted to assign value: 
{r}".format(r=rows))
         if cols < 0:
             raise ValueError("In rand statement, can only assign cols a long 
(integer) value >= 0 "
                              "-- attempted to assign value: 
{c}".format(c=cols))
-        if pdf not in available_pdfs:
+        if pdf not in available_pdf:
             raise ValueError("The pdf passed is invalid! given: {g}, expected: 
{e}".format(
-                g=pdf, e=available_pdfs))
+                g=pdf, e=available_pdf))
 
         pdf = '\"' + pdf + '\"'
         named_input_nodes = {
-            'rows': rows, 'cols': cols, 'pdf': pdf, 'lambda': lambd}
+            'rows': rows, 'cols': cols, 'pdf': pdf, 'lambda': lamb}
         if min is not None:
             named_input_nodes['min'] = min
         if max is not None:
@@ -357,7 +406,11 @@ class SystemDSContext(object):
             output_type = OutputType.from_str(kwargs.get("value_type", None))
             kwargs["value_type"] = f'"{output_type.name}"'
             return Scalar(self, "read", [f'"{path}"'], 
named_input_nodes=kwargs, output_type=output_type)
+        elif data_type == "list":
+            # Reading a list have no extra arguments.
+            return List(self, "read", [f'"{path}"'])
 
+        kwargs["data_type"] = None
         print("WARNING: Unknown type read please add a mtd file, or specify in 
arguments")
         return OperationNode(self, "read", [f'"{path}"'], 
named_input_nodes=kwargs)
 
diff --git 
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java 
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 0106099..585c6ef 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -25,29 +25,36 @@ import org.junit.Test;
 /** Simple tests to verify startup of Python Gateway server happens without 
crashes */
 public class StartupTest {
 
-       @Test(expected = IllegalArgumentException.class)
-       public void testStartupIncorrect_1() {
+       @Test(expected = Exception.class)
+       public void testStartupIncorrect_1() throws Exception {
                PythonDMLScript.main(new String[] {});
        }
 
-       @Test(expected = IllegalArgumentException.class)
-       public void testStartupIncorrect_2() {
+       @Test(expected = Exception.class)
+       public void testStartupIncorrect_2() throws Exception {
                PythonDMLScript.main(new String[] {""});
        }
 
-       @Test(expected = IllegalArgumentException.class)
-       public void testStartupIncorrect_3() {
+       @Test(expected = Exception.class)
+       public void testStartupIncorrect_3() throws Exception {
                PythonDMLScript.main(new String[] {"131", "131"});
        }
 
-       @Test(expected = NumberFormatException.class)
-       public void testStartupIncorrect_4() {
+       @Test(expected = Exception.class)
+       public void testStartupIncorrect_4() throws Exception {
                PythonDMLScript.main(new String[] {"Hello"});
        }
 
-       @Test(expected = IllegalArgumentException.class)
-       public void testStartupIncorrect_5() {
+       @Test(expected = Exception.class)
+       public void testStartupIncorrect_5() throws Exception {
                // Number out of range
-               PythonDMLScript.main(new String[] {"918757"});
+               PythonDMLScript.main(new String[] {"-python", "918757"});
+       }
+
+       @Test(expected = Exception.class)
+       public void testStartupCorrectButTwice() throws Exception {
+               // crash if you start two instances on same port.
+               PythonDMLScript.main(new String[] {"-python", "8142"});
+               PythonDMLScript.main(new String[] {"-python", "8142"});
        }
 }

Reply via email to