This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 0386fe1917 [MINOR] Fix in Python API GatewayServerListener 0386fe1917 is described below commit 0386fe1917fd71df6008f24c4f000efd0ffece0f Author: e-strauss <92718421+e-stra...@users.noreply.github.com> AuthorDate: Mon Mar 31 15:22:01 2025 +0200 [MINOR] Fix in Python API GatewayServerListener The inner class DMLGateWayListener implemented functions from the GatewayServerListener interface that are never invoked by the GatewayServer (the GatewayServer, which also implements GatewayServerListener, does not implement these methods). Furthermore, DMLGateWayListener previously called System.exit(), which is incorrect because it prevents the GatewayServer from shutting down properly. Finally, this commit adds a new unit test, which checks the functionality of the D [...] Closes #2243 --- .../java/org/apache/sysds/api/PythonDMLScript.java | 42 ++++------- .../sysds/test/usertest/pythonapi/StartupTest.java | 88 ++++++++++++++++++++++ 2 files changed, 101 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java index c4957d4e9f..80f5ffcd75 100644 --- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java +++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java @@ -21,17 +21,20 @@ package org.apache.sysds.api; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.sysds.api.jmlc.Connection; +import py4j.DefaultGatewayServerListener; import py4j.GatewayServer; -import py4j.GatewayServerListener; import py4j.Py4JNetworkException; -import py4j.Py4JServerConnection; + public class PythonDMLScript { private static final Log LOG = LogFactory.getLog(PythonDMLScript.class.getName()); final private Connection _connection; + public static GatewayServer GwS; /** * Entry point for 
Python API. @@ -42,7 +45,7 @@ public class PythonDMLScript { public static void main(String[] args) throws Exception { final DMLOptions dmlOptions = DMLOptions.parseCLArguments(args); DMLScript.loadConfiguration(dmlOptions.configFile); - final GatewayServer GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); + GwS = new GatewayServer(new PythonDMLScript(), dmlOptions.pythonPort); GwS.addListener(new DMLGateWayListener()); try { GwS.start(); @@ -67,38 +70,20 @@ public class PythonDMLScript { _connection = new Connection(); } + public static void setDMLGateWayListenerLoggerLevel(Level l){ + Logger.getLogger(DMLGateWayListener.class).setLevel(l); + } + public Connection getConnection() { return _connection; } - protected static class DMLGateWayListener implements GatewayServerListener { + protected static class DMLGateWayListener extends DefaultGatewayServerListener { private static final Log LOG = LogFactory.getLog(DMLGateWayListener.class.getName()); - @Override - public void connectionError(Exception e) { - LOG.warn("Connection error: " + e.getMessage()); - System.exit(1); - } - - @Override - public void connectionStarted(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection Started: " + gatewayConnection.toString()); - } - - @Override - public void connectionStopped(Py4JServerConnection gatewayConnection) { - LOG.debug("Connection stopped: " + gatewayConnection.toString()); - } - - @Override - public void serverError(Exception e) { - LOG.error("Server Error " + e.getMessage()); - } - @Override public void serverPostShutdown() { LOG.info("Shutdown done"); - System.exit(0); } @Override @@ -108,13 +93,12 @@ public class PythonDMLScript { @Override public void serverStarted() { - LOG.info("GatewayServer Started"); + LOG.info("GatewayServer started"); } @Override public void serverStopped() { - LOG.info("GatewayServer Stopped"); - System.exit(0); + LOG.info("GatewayServer stopped"); } } diff --git 
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java index 4b8395107f..9e7cda13ee 100644 --- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java +++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java @@ -19,11 +19,55 @@ package org.apache.sysds.test.usertest.pythonapi; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.api.PythonDMLScript; +import org.apache.sysds.test.LoggingUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; import org.junit.Test; +import py4j.GatewayServer; + +import java.security.Permission; +import java.util.List; + /** Simple tests to verify startup of Python Gateway server happens without crashes */ public class StartupTest { + private LoggingUtils.TestAppender appender; + private SecurityManager sm; + + @Before + public void setUp() { + appender = LoggingUtils.overwrite(); + sm = System.getSecurityManager(); + System.setSecurityManager(new NoExitSecurityManager()); + PythonDMLScript.setDMLGateWayListenerLoggerLevel(Level.ALL); + Logger.getLogger(PythonDMLScript.class.getName()).setLevel(Level.ALL); + } + + @After + public void tearDown() { + LoggingUtils.reinsert(appender); + System.setSecurityManager(sm); + } + + private void assertLogMessages(String... 
expectedMessages) { + List<LoggingEvent> log = LoggingUtils.reinsert(appender); + log.stream().forEach(l -> System.out.println(l.getMessage())); + Assert.assertEquals("Unexpected number of log messages", expectedMessages.length, log.size()); + + for (int i = 0; i < expectedMessages.length; i++) { + // order does not matter + boolean found = false; + for (String message : expectedMessages) { + found |= log.get(i).getMessage().toString().startsWith(message); + } + Assert.assertTrue("Unexpected log message: " + log.get(i).getMessage(),found); + } + } @Test(expected = Exception.class) public void testStartupIncorrect_1() throws Exception { @@ -50,4 +94,48 @@ public class StartupTest { // Number out of range PythonDMLScript.main(new String[] {"-python", "918757"}); } + + @Test + public void testStartupIncorrect_6() throws Exception { + GatewayServer gws1 = null; + try { + PythonDMLScript.main(new String[]{"-python", "4001"}); + gws1 = PythonDMLScript.GwS; + Thread.sleep(200); + PythonDMLScript.main(new String[]{"-python", "4001"}); + Thread.sleep(200); + } catch (SecurityException e) { + assertLogMessages( + "GatewayServer started", + "failed startup" + ); + gws1.shutdown(); + } + } + + @Test + public void testStartupCorrect() throws Exception { + PythonDMLScript.main(new String[]{"-python", "4002"}); + Thread.sleep(200); + PythonDMLScript script = (PythonDMLScript) PythonDMLScript.GwS.getGateway().getEntryPoint(); + script.getConnection(); + PythonDMLScript.GwS.shutdown(); + Thread.sleep(200); + assertLogMessages( + "GatewayServer started", + "Starting JVM shutdown", + "Shutdown done", + "GatewayServer stopped" + ); + } + + class NoExitSecurityManager extends SecurityManager { + @Override + public void checkPermission(Permission perm) { } + + @Override + public void checkExit(int status) { + throw new SecurityException("Intercepted exit()"); + } + } }