This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0386fe1917 [MINOR] Fix in Python API GatewayServerListener
0386fe1917 is described below

commit 0386fe1917fd71df6008f24c4f000efd0ffece0f
Author: e-strauss <92718421+e-stra...@users.noreply.github.com>
AuthorDate: Mon Mar 31 15:22:01 2025 +0200

    [MINOR] Fix in Python API GatewayServerListener
    
    The current inner class DMLGateWayListener implements functions from the 
GatewayServerListener interface, which are never invoked by the GatewayServer 
(since the GatewayServer, which also implements GatewayServerListener, does not 
implement these methods). Furthermore, DMLGateWayListener previously called 
System.exit(), which I think is not correct, since it breaks the proper shutdown 
of the GatewayServer. Finally, this commit adds a new unit test case, which checks 
the functionality of the D [...]
    
    Closes #2243
---
 .../java/org/apache/sysds/api/PythonDMLScript.java | 42 ++++-------
 .../sysds/test/usertest/pythonapi/StartupTest.java | 88 ++++++++++++++++++++++
 2 files changed, 101 insertions(+), 29 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java 
b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
index c4957d4e9f..80f5ffcd75 100644
--- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java
+++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java
@@ -21,17 +21,20 @@ package org.apache.sysds.api;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
 import org.apache.sysds.api.jmlc.Connection;
 
+import py4j.DefaultGatewayServerListener;
 import py4j.GatewayServer;
-import py4j.GatewayServerListener;
 import py4j.Py4JNetworkException;
-import py4j.Py4JServerConnection;
+
 
 public class PythonDMLScript {
 
        private static final Log LOG = 
LogFactory.getLog(PythonDMLScript.class.getName());
        final private Connection _connection;
+       public static GatewayServer GwS;
 
        /**
         * Entry point for Python API.
@@ -42,7 +45,7 @@ public class PythonDMLScript {
        public static void main(String[] args) throws Exception {
                final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args);
                DMLScript.loadConfiguration(dmlOptions.configFile);
-               final GatewayServer GwS = new GatewayServer(new 
PythonDMLScript(), dmlOptions.pythonPort);
+               GwS = new GatewayServer(new PythonDMLScript(), 
dmlOptions.pythonPort);
                GwS.addListener(new DMLGateWayListener());
                try {
                        GwS.start();
@@ -67,38 +70,20 @@ public class PythonDMLScript {
                _connection = new Connection();
        }
 
+       public static void setDMLGateWayListenerLoggerLevel(Level l){
+               Logger.getLogger(DMLGateWayListener.class).setLevel(l);
+       }
+
        public Connection getConnection() {
                return _connection;
        }
 
-       protected static class DMLGateWayListener implements 
GatewayServerListener {
+       protected static class DMLGateWayListener extends 
DefaultGatewayServerListener {
                private static final Log LOG = 
LogFactory.getLog(DMLGateWayListener.class.getName());
 
-               @Override
-               public void connectionError(Exception e) {
-                       LOG.warn("Connection error: " + e.getMessage());
-                       System.exit(1);
-               }
-
-               @Override
-               public void connectionStarted(Py4JServerConnection 
gatewayConnection) {
-                       LOG.debug("Connection Started: " + 
gatewayConnection.toString());
-               }
-
-               @Override
-               public void connectionStopped(Py4JServerConnection 
gatewayConnection) {
-                       LOG.debug("Connection stopped: " + 
gatewayConnection.toString());
-               }
-
-               @Override
-               public void serverError(Exception e) {
-                       LOG.error("Server Error " + e.getMessage());
-               }
-
                @Override
                public void serverPostShutdown() {
                        LOG.info("Shutdown done");
-                       System.exit(0);
                }
 
                @Override
@@ -108,13 +93,12 @@ public class PythonDMLScript {
 
                @Override
                public void serverStarted() {
-                       LOG.info("GatewayServer Started");
+                       LOG.info("GatewayServer started");
                }
 
                @Override
                public void serverStopped() {
-                       LOG.info("GatewayServer Stopped");
-                       System.exit(0);
+                       LOG.info("GatewayServer stopped");
                }
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java 
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 4b8395107f..9e7cda13ee 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -19,11 +19,55 @@
 
 package org.apache.sysds.test.usertest.pythonapi;
 
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.log4j.spi.LoggingEvent;
 import org.apache.sysds.api.PythonDMLScript;
+import org.apache.sysds.test.LoggingUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
+import py4j.GatewayServer;
+
+import java.security.Permission;
+import java.util.List;
+
 
 /** Simple tests to verify startup of Python Gateway server happens without 
crashes */
 public class StartupTest {
+       private LoggingUtils.TestAppender appender;
+       private SecurityManager sm;
+
+       @Before
+       public void setUp() {
+               appender = LoggingUtils.overwrite();
+               sm = System.getSecurityManager();
+               System.setSecurityManager(new NoExitSecurityManager());
+               PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL);
+               
Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL);
+       }
+
+       @After
+       public void tearDown() {
+               LoggingUtils.reinsert(appender);
+               System.setSecurityManager(sm);
+       }
+
+       private void assertLogMessages(String... expectedMessages) {
+               List<LoggingEvent> log = LoggingUtils.reinsert(appender);
+               log.stream().forEach(l -> System.out.println(l.getMessage()));
+               Assert.assertEquals("Unexpected number of log messages", 
expectedMessages.length, log.size());
+
+               for (int i = 0; i < expectedMessages.length; i++) {
+                       // order does not matter
+                       boolean found = false;
+                       for (String message : expectedMessages) {
+                               found |= 
log.get(i).getMessage().toString().startsWith(message);
+                       }
+                       Assert.assertTrue("Unexpected log message: " + 
log.get(i).getMessage(),found);
+               }
+       }
 
        @Test(expected = Exception.class)
        public void testStartupIncorrect_1() throws Exception {
@@ -50,4 +94,48 @@ public class StartupTest {
                // Number out of range
                PythonDMLScript.main(new String[] {"-python", "918757"});
        }
+
+       @Test
+       public void testStartupIncorrect_6() throws Exception {
+               GatewayServer gws1 = null;
+               try {
+                       PythonDMLScript.main(new String[]{"-python", "4001"});
+                       gws1 = PythonDMLScript.GwS;
+                       Thread.sleep(200);
+                       PythonDMLScript.main(new String[]{"-python", "4001"});
+                       Thread.sleep(200);
+               } catch (SecurityException e) {
+                       assertLogMessages(
+                                       "GatewayServer started",
+                                       "failed startup"
+                       );
+                       gws1.shutdown();
+               }
+       }
+
+       @Test
+       public void testStartupCorrect() throws Exception {
+               PythonDMLScript.main(new String[]{"-python", "4002"});
+               Thread.sleep(200);
+               PythonDMLScript script = (PythonDMLScript) 
PythonDMLScript.GwS.getGateway().getEntryPoint();
+               script.getConnection();
+               PythonDMLScript.GwS.shutdown();
+               Thread.sleep(200);
+               assertLogMessages(
+                               "GatewayServer started",
+                               "Starting JVM shutdown",
+                               "Shutdown done",
+                               "GatewayServer stopped"
+               );
+       }
+
+       class NoExitSecurityManager extends SecurityManager {
+               @Override
+               public void checkPermission(Permission perm) { }
+
+               @Override
+               public void checkExit(int status) {
+                       throw new SecurityException("Intercepted exit()");
+               }
+       }
 }

Reply via email to