This is an automated email from the ASF dual-hosted git repository. amichai pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/aries-rsa.git
commit 75c97d75ba65a685f92e036d2a335625877855ae Author: Amichai Rothman <[email protected]> AuthorDate: Thu Oct 5 16:14:56 2023 +0300 ARIES-2122 - Allow multiple services on one port --- .../aries/rsa/provider/tcp/MethodInvoker.java | 4 ++ .../apache/aries/rsa/provider/tcp/TCPProvider.java | 42 +++++++++++++++++-- .../apache/aries/rsa/provider/tcp/TCPServer.java | 44 ++++++++++++++----- .../apache/aries/rsa/provider/tcp/TcpEndpoint.java | 49 ++++++++++++++++++---- .../rsa/provider/tcp/TcpInvocationHandler.java | 9 ++-- .../provider/tcp/ser/BasicObjectInputStream.java | 5 ++- .../aries/rsa/provider/tcp/TcpEndpointTest.java | 6 +-- .../rsa/provider/tcp/TcpProviderIntentTest.java | 2 +- .../aries/rsa/provider/tcp/TcpProviderTest.java | 46 ++++++++++++++++---- .../rsa/provider/tcp/myservice/MyService.java | 3 ++ .../rsa/provider/tcp/myservice/MyServiceImpl.java | 11 +++++ 11 files changed, 184 insertions(+), 37 deletions(-) diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/MethodInvoker.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/MethodInvoker.java index 5ff2f77e..0dcc810c 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/MethodInvoker.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/MethodInvoker.java @@ -40,6 +40,10 @@ public class MethodInvoker { this.primTypes.put(Character.TYPE, Character.class); } + public Object getService() { + return service; + } + public Object invoke(String methodName, Object[] args) throws Exception { Class<?>[] parameterTypesAr = getTypes(args); Method method = getMethod(methodName, parameterTypesAr); diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPProvider.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPProvider.java index a54d6140..bdc85053 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPProvider.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPProvider.java @@ -18,11 +18,13 @@ */ package org.apache.aries.rsa.provider.tcp; +import java.io.IOException; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Proxy; import java.net.URI; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -52,6 +54,8 @@ public class TCPProvider implements DistributionProvider { private Logger logger = LoggerFactory.getLogger(TCPProvider.class); + private Map<Integer, TCPServer> servers = new HashMap<>(); + @Override public String[] getSupportedTypes() { return new String[] {TCP_CONFIG_TYPE}; @@ -79,7 +83,38 @@ public class TCPProvider implements DistributionProvider { logger.warn("Unsupported intents found: {}. Not exporting service", intents); return null; } - return new TcpEndpoint(serviceO, effectiveProperties); + TcpEndpoint endpoint = new TcpEndpoint(serviceO, effectiveProperties, this::removeServer); + addServer(serviceO, endpoint); + return endpoint; + } + + private synchronized void addServer(Object serviceO, TcpEndpoint endpoint) { + // port 0 means dynamically allocated free port + int port = endpoint.getPort(); + TCPServer server = servers.get(port); + if (server == null || port == 0) { + server = new TCPServer(endpoint.getHostname(), port, endpoint.getNumThreads()); + port = server.getPort(); // get the real port + endpoint.setPort(port); + servers.put(port, server); + } + // different services may configure different number of threads - we pick the max + if (endpoint.getNumThreads() > server.getNumThreads()) { + server.setNumThreads(endpoint.getNumThreads()); + } + server.addService(endpoint.description().getId(), serviceO); + } + + private synchronized void removeServer(TcpEndpoint endpoint) { + TCPServer server = servers.get(endpoint.getPort()); + server.removeService(endpoint.description().getId()); + if (server.isEmpty()) { + try { + servers.remove(endpoint.getPort()).close(); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } } @Override @@ -89,9 +124,10 @@ public class TCPProvider implements DistributionProvider { EndpointDescription endpoint) throws IntentUnsatisfiedException { try { - URI address = new URI(endpoint.getId()); + String endpointId = endpoint.getId(); + URI address = new URI(endpointId); int timeout = new EndpointPropertiesParser(endpoint).getTimeoutMillis(); - InvocationHandler handler = new TcpInvocationHandler(cl, address.getHost(), address.getPort(), timeout); + InvocationHandler handler = new TcpInvocationHandler(cl, address.getHost(), address.getPort(), endpointId, timeout); return Proxy.newProxyInstance(cl, interfaces, handler); } catch (Exception e) { throw new RuntimeException(e); diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPServer.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPServer.java index e5fb5b4c..93824222 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPServer.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TCPServer.java @@ -26,6 +26,7 @@ import java.lang.reflect.InvocationTargetException; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; +import java.util.Map; import java.util.concurrent.*; import org.apache.aries.rsa.provider.tcp.ser.BasicObjectOutputStream; @@ -37,14 +38,11 @@ import org.slf4j.LoggerFactory; public class TCPServer implements Closeable, Runnable { private Logger log = LoggerFactory.getLogger(TCPServer.class); private ServerSocket serverSocket; - private Object service; + private Map<String, MethodInvoker> invokers = new ConcurrentHashMap<>(); private volatile boolean running; - private ExecutorService executor; - private MethodInvoker invoker; + private ThreadPoolExecutor executor; - public TCPServer(Object service, String localip, Integer port, int numThreads) { - this.service = service; - this.invoker = new MethodInvoker(service); + public TCPServer(String localip, int port, int numThreads) { try { this.serverSocket = new ServerSocket(port); this.serverSocket.setReuseAddress(true); @@ -62,6 +60,28 @@ public class TCPServer implements Closeable, Runnable { return this.serverSocket.getLocalPort(); } + public void addService(String endpointId, Object service) { + invokers.put(endpointId, new MethodInvoker(service)); + } + + public void removeService(String endpointId) { + invokers.remove(endpointId); + } + + public boolean isEmpty() { + return invokers.isEmpty(); + } + + public void setNumThreads(int numThreads) { + numThreads++; // plus one for server socket accepting thread + executor.setCorePoolSize(numThreads); + executor.setMaximumPoolSize(numThreads); + } + + public int getNumThreads() { + return executor.getMaximumPoolSize() - 1; // excluding socket accepting thread + } + public void run() { while (running) { try { @@ -76,11 +96,15 @@ public class TCPServer implements Closeable, Runnable { } private void handleConnection(Socket socket) { - ClassLoader serviceCL = service.getClass().getClassLoader(); try (Socket sock = socket; - ObjectInputStream in = new BasicObjectInputStream(socket.getInputStream(), serviceCL); + BasicObjectInputStream in = new BasicObjectInputStream(socket.getInputStream()); ObjectOutputStream out = new BasicObjectOutputStream(socket.getOutputStream())) { - handleCall(in, out); + String endpointId = in.readUTF(); + MethodInvoker invoker = invokers.get(endpointId); + if (invoker == null) + throw new IllegalArgumentException("invalid endpoint: " + endpointId); + in.addClassLoader(invoker.getService().getClass().getClassLoader()); + handleCall(invoker, in, out); } catch (SocketException se) { return; // e.g. connection closed by client } catch (Exception e) { @@ -88,7 +112,7 @@ public class TCPServer implements Closeable, Runnable { } } - private void handleCall(ObjectInputStream in, ObjectOutputStream out) throws Exception { + private void handleCall(MethodInvoker invoker, ObjectInputStream in, ObjectOutputStream out) throws Exception { String methodName = (String)in.readObject(); Object[] args = (Object[])in.readObject(); Throwable error = null; diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpEndpoint.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpEndpoint.java index f6d4f7a9..73d1dfdc 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpEndpoint.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpEndpoint.java @@ -20,29 +20,42 @@ package org.apache.aries.rsa.provider.tcp; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.function.Consumer; import org.apache.aries.rsa.spi.Endpoint; import org.osgi.service.remoteserviceadmin.EndpointDescription; import org.osgi.service.remoteserviceadmin.RemoteConstants; public class TcpEndpoint implements Endpoint { + + private String hostname; + private int port; + private int numThreads; + private Consumer<TcpEndpoint> closeCallback; + private EndpointDescription epd; - private TCPServer tcpServer; - public TcpEndpoint(Object service, Map<String, Object> effectiveProperties) { + public TcpEndpoint(Object service, Map<String, Object> effectiveProperties, Consumer<TcpEndpoint> closeCallback) { if (service == null) { throw new NullPointerException("Service must not be null"); } if (effectiveProperties.get(TCPProvider.TCP_CONFIG_TYPE + ".id") != null) { throw new IllegalArgumentException("For the tck .. Just to please you!"); } + this.closeCallback = closeCallback; EndpointPropertiesParser parser = new EndpointPropertiesParser(effectiveProperties); - Integer port = parser.getPort(); - String hostName = parser.getHostname(); - int numThreads = parser.getNumThreads(); - tcpServer = new TCPServer(service, hostName, port, numThreads); - String endpointId = String.format("tcp://%s:%s/%s", hostName, tcpServer.getPort(), parser.getId()); + port = parser.getPort(); // this may initially be 0 for dynamic port + hostname = parser.getHostname(); + numThreads = parser.getNumThreads(); + updateEndpointDescription(effectiveProperties); + } + + private void updateEndpointDescription(Map<String, Object> effectiveProperties) { + effectiveProperties = new HashMap<>(effectiveProperties); + EndpointPropertiesParser parser = new EndpointPropertiesParser(effectiveProperties); + String endpointId = String.format("tcp://%s:%s/%s", hostname, port, parser.getId()); effectiveProperties.put(RemoteConstants.ENDPOINT_ID, endpointId); effectiveProperties.put(RemoteConstants.SERVICE_EXPORTED_CONFIGS, ""); effectiveProperties.put(RemoteConstants.SERVICE_INTENTS, Arrays.asList("osgi.basic", "osgi.async")); @@ -52,6 +65,25 @@ public class TcpEndpoint implements Endpoint { this.epd = new EndpointDescription(effectiveProperties); } + public String getHostname() { + return hostname; + } + + public int getPort() { + return port; + } + + public void setPort(int port) { + if (this.port == port) + return; + this.port = port; + updateEndpointDescription(epd.getProperties()); + } + + public int getNumThreads() { + return numThreads; + } + @Override public EndpointDescription description() { return this.epd; @@ -59,6 +91,7 @@ public class TcpEndpoint implements Endpoint { @Override public void close() throws IOException { - tcpServer.close(); + if (closeCallback != null) + closeCallback.accept(this); } } diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpInvocationHandler.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpInvocationHandler.java index 241908a2..9f08fbd2 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpInvocationHandler.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/TcpInvocationHandler.java @@ -19,7 +19,6 @@ package org.apache.aries.rsa.provider.tcp; import java.io.IOException; -import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; @@ -45,14 +44,16 @@ import org.osgi.util.promise.Promise; public class TcpInvocationHandler implements InvocationHandler { private String host; private int port; + private String endpointId; private ClassLoader cl; private int timeoutMillis; - public TcpInvocationHandler(ClassLoader cl, String host, int port, int timeoutMillis) + public TcpInvocationHandler(ClassLoader cl, String host, int port, String endpointId, int timeoutMillis) throws UnknownHostException, IOException { this.cl = cl; this.host = host; this.port = port; + this.endpointId = endpointId; this.timeoutMillis = timeoutMillis; } @@ -106,11 +107,13 @@ public class TcpInvocationHandler implements InvocationHandler { ObjectOutputStream out = new BasicObjectOutputStream(socket.getOutputStream()) ) { socket.setSoTimeout(timeoutMillis); + out.writeUTF(endpointId); out.writeObject(method.getName()); out.writeObject(args); out.flush(); - try (ObjectInputStream in = new BasicObjectInputStream(socket.getInputStream(), cl)) { + try (BasicObjectInputStream in = new BasicObjectInputStream(socket.getInputStream())) { + in.addClassLoader(cl); error = (Throwable) in.readObject(); result = readReplaceVersion(in.readObject()); } diff --git a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/ser/BasicObjectInputStream.java b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/ser/BasicObjectInputStream.java index ffa1d1cb..fe48ee16 100644 --- a/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/ser/BasicObjectInputStream.java +++ b/provider/tcp/src/main/java/org/apache/aries/rsa/provider/tcp/ser/BasicObjectInputStream.java @@ -36,7 +36,7 @@ public class BasicObjectInputStream extends ObjectInputStream { private final Set<ClassLoader> loaders = new LinkedHashSet<>(); // retains insertion order - public BasicObjectInputStream(InputStream in, ClassLoader loader) throws IOException { + public BasicObjectInputStream(InputStream in) throws IOException { super(in); AccessController.doPrivileged(new PrivilegedAction<Void>() { public Void run() { @@ -44,6 +44,9 @@ public class BasicObjectInputStream extends ObjectInputStream { return null; } }); + } + + public void addClassLoader(ClassLoader loader) { loaders.add(loader); // the original classloader goes first } diff --git a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpEndpointTest.java b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpEndpointTest.java index 5bf1b374..6b518f33 100644 --- a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpEndpointTest.java +++ b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpEndpointTest.java @@ -43,7 +43,7 @@ public class TcpEndpointTest { props = new HashMap<>(); props.put(Constants.OBJECTCLASS, new String[]{MyService.class.getName()}); props.put(RemoteConstants.SERVICE_IMPORTED_CONFIGS, ""); - service = new MyServiceImpl(); + service = new MyServiceImpl(null); } @Test @@ -51,7 +51,7 @@ public class TcpEndpointTest { props.put("aries.rsa.port", PORT); props.put("aries.rsa.hostname", HOSTNAME); props.put("aries.rsa.id", "testme"); - TcpEndpoint tcpEndpoint = new TcpEndpoint(service, props); + TcpEndpoint tcpEndpoint = new TcpEndpoint(service, props, null); EndpointDescription epd = tcpEndpoint.description(); Assert.assertEquals("tcp://" + HOSTNAME + ":" + PORT + "/testme", epd.getId()); tcpEndpoint.close(); @@ -59,7 +59,7 @@ public class TcpEndpointTest { @Test public void testEndpointPropertiesDefault() throws IOException { - TcpEndpoint tcpEndpoint = new TcpEndpoint(service, props); + TcpEndpoint tcpEndpoint = new TcpEndpoint(service, props, null); EndpointDescription epd = tcpEndpoint.description(); Assert.assertNotNull(epd.getId()); tcpEndpoint.close(); diff --git a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderIntentTest.java b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderIntentTest.java index 3c18f2b3..b2486ba6 100644 --- a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderIntentTest.java +++ b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderIntentTest.java @@ -47,7 +47,7 @@ public class TcpProviderIntentTest { exportedInterfaces = new Class[] {MyService.class}; bc = EasyMock.mock(BundleContext.class); provider = new TCPProvider(); - myService = new MyServiceImpl(); + myService = new MyServiceImpl(null); } @Test diff --git a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderTest.java b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderTest.java index 73af5051..600506e6 100644 --- a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderTest.java +++ b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/TcpProviderTest.java @@ -27,6 +27,8 @@ import static org.junit.Assert.assertTrue; import java.io.IOException; import java.lang.reflect.InvocationTargetException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.HashMap; @@ -62,26 +64,45 @@ public class TcpProviderTest { private static final int TIMEOUT = 200; private static final int NUM_CALLS = 100; private static MyService myServiceProxy; + private static MyService myServiceProxy2; private static Endpoint ep; + private static Endpoint ep2; + + protected static int getFreePort() throws IOException { + try (ServerSocket socket = new ServerSocket()) { + socket.setReuseAddress(true); // enables quickly reopening socket on same port + socket.bind(new InetSocketAddress(0)); // zero finds a free port + return socket.getLocalPort(); + } + } @BeforeClass - public static void createServerAndProxy() { + public static void createServerAndProxy() throws IOException { Class<?>[] exportedInterfaces = new Class[] {MyService.class}; TCPProvider provider = new TCPProvider(); Map<String, Object> props = new HashMap<>(); EndpointHelper.addObjectClass(props, exportedInterfaces); + int port = getFreePort(); props.put("aries.rsa.hostname", "localhost"); + props.put("aries.rsa.port", port); props.put("aries.rsa.numThreads", "10"); props.put("osgi.basic.timeout", TIMEOUT); - MyService myService = new MyServiceImpl(); BundleContext bc = EasyMock.mock(BundleContext.class); - ep = provider.exportService(myService, bc, props, exportedInterfaces); + props.put("aries.rsa.id", "service1"); + ep = provider.exportService(new MyServiceImpl("service1"), bc, props, exportedInterfaces); + props.put("aries.rsa.id", "service2"); + ep2 = provider.exportService(new MyServiceImpl("service2"), bc, props, exportedInterfaces); Assert.assertThat(ep.description().getId(), startsWith("tcp://localhost:")); - System.out.println(ep.description()); - myServiceProxy = (MyService)provider.importEndpoint(MyService.class.getClassLoader(), - bc, - exportedInterfaces, - ep.description()); + myServiceProxy = (MyService)provider.importEndpoint( + MyService.class.getClassLoader(), + bc, + exportedInterfaces, + ep.description()); + myServiceProxy2 = (MyService)provider.importEndpoint( + MyService.class.getClassLoader(), + bc, + exportedInterfaces, + ep2.description()); } @Test @@ -113,6 +134,15 @@ public class TcpProviderTest { myServiceProxy.echo("test"); } + @Test + public void testCallSharedPort() { + Object port1 = ep.description().getProperties().get("aries.rsa.port"); + Object port2 = ep2.description().getProperties().get("aries.rsa.port"); + assertEquals(port1, port2); + assertEquals("service1", myServiceProxy.getId()); + assertEquals("service2", myServiceProxy2.getId()); + } + @Test public void testCallOneway() { myServiceProxy.callOneWay("test"); diff --git a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyService.java b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyService.java index 5d9320bd..0063c25b 100644 --- a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyService.java +++ b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyService.java @@ -27,6 +27,9 @@ import javax.jws.Oneway; import org.osgi.util.promise.Promise; public interface MyService { + + String getId(); + String echo(String msg); void callSlow(int delay); diff --git a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyServiceImpl.java b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyServiceImpl.java index c91814fd..f54d0fb4 100644 --- a/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyServiceImpl.java +++ b/provider/tcp/src/test/java/org/apache/aries/rsa/provider/tcp/myservice/MyServiceImpl.java @@ -30,6 +30,17 @@ import org.osgi.util.promise.Promise; public class MyServiceImpl implements MyService { + String id; + + public MyServiceImpl(String id) { + this.id = id; + } + + @Override + public String getId() { + return id; + } + @Override public String echo(String msg) { return msg;
