This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0745c164c0 [MINOR] Cleanup federated netty setup
0745c164c0 is described below

commit 0745c164c05ec200532435950b170ebf8b713c9c
Author: baunsgaard <[email protected]>
AuthorDate: Mon May 2 17:24:12 2022 +0200

    [MINOR] Cleanup federated netty setup
    
    This commit simply moves a bit of the netty setup around to make the
    code cleaner. Some of these moves also give slight improvements for
    small federated requests, allowing slightly faster startup of transfer.
    
    Closes #1599
---
 src/main/java/org/apache/sysds/api/DMLScript.java  | 14 +---
 .../apache/sysds/conf/ConfigurationManager.java    |  4 +
 .../controlprogram/federated/FederatedData.java    | 71 ++++++++++--------
 .../controlprogram/federated/FederatedWorker.java  | 86 ++++++++++++----------
 .../federated/FederatedWorkerHandler.java          | 44 ++++++-----
 .../federated/FederatedWorkloadAnalyzer.java       |  3 +-
 .../controlprogram/federated/FederationUtils.java  |  8 ++
 .../test/component/federated/FedWorkerBase.java    |  2 +-
 .../test/component/federated/FedWorkerMatrix.java  |  4 +-
 .../test/component/federated/FedWorkerScalar.java  |  1 -
 10 files changed, 134 insertions(+), 103 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index bfd5a089fe..d74cf59bf7 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -25,7 +25,6 @@ import java.io.FileReader;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
-import java.security.cert.CertificateException;
 import java.text.DateFormat;
 import java.text.SimpleDateFormat;
 import java.util.Date;
@@ -73,14 +72,14 @@ import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
-import org.apache.sysds.runtime.util.LocalFileUtils;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.runtime.util.LocalFileUtils;
 import org.apache.sysds.utils.Explain;
-import org.apache.sysds.utils.NativeHelper;
-import org.apache.sysds.utils.Statistics;
 import org.apache.sysds.utils.Explain.ExplainCounts;
 import org.apache.sysds.utils.Explain.ExplainType;
+import org.apache.sysds.utils.NativeHelper;
+import org.apache.sysds.utils.Statistics;
 
 public class DMLScript 
 {
@@ -281,12 +280,7 @@ public class DMLScript
                        
                        if(dmlOptions.fedWorker) {
                                loadConfiguration(fnameOptConfig);
-                               try {
-                                       new 
FederatedWorker(dmlOptions.fedWorkerPort, dmlOptions.debug).run();
-                               }
-                               catch(CertificateException e) {
-                                       e.printStackTrace();
-                               }
+                               new FederatedWorker(dmlOptions.fedWorkerPort, 
dmlOptions.debug);
                                return true;
                        }
 
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 5b7bf7ac65..505f33b19f 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -215,6 +215,10 @@ public class ConfigurationManager
                return getDMLConfig().getIntValue(DMLConfig.FEDERATED_TIMEOUT);
        }
 
+       public static boolean isFederatedSSL(){
+               return 
getDMLConfig().getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION);
+       }
+
        ///////////////////////////////////////
        // Thread-local classes
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 95901b4a2e..70e41a9e9b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -49,11 +49,10 @@ import io.netty.channel.EventLoopGroup;
 import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
-import io.netty.handler.codec.serialization.ClassResolvers;
-import io.netty.handler.codec.serialization.ObjectDecoder;
 import io.netty.handler.codec.serialization.ObjectEncoder;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.SslHandler;
 import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
 import io.netty.handler.timeout.ReadTimeoutHandler;
 import io.netty.util.concurrent.Promise;
@@ -142,9 +141,8 @@ public class FederatedData {
                if(!_dataType.isMatrix() && !_dataType.isFrame())
                        throw new DMLRuntimeException("Federated datatype \"" + 
_dataType.toString() + "\" is not supported.");
                _varID = id;
-               FederatedRequest request = (mtd != null ) ? 
-                       new FederatedRequest(RequestType.READ_VAR, id, mtd) :
-                       new FederatedRequest(RequestType.READ_VAR, id);
+               FederatedRequest request = (mtd != null) ? new 
FederatedRequest(RequestType.READ_VAR, id,
+                       mtd) : new FederatedRequest(RequestType.READ_VAR, id);
                request.appendParam(_filepath);
                request.appendParam(_dataType.name());
                return executeFederatedOperation(request);
@@ -165,42 +163,44 @@ public class FederatedData {
                FederatedRequest... request) {
                try {
                        final Bootstrap b = new Bootstrap();
-
                        if(workerGroup == null)
                                createWorkGroup();
-
+                       b.group(workerGroup);
+                       b.channel(NioSocketChannel.class);
                        final DataRequestHandler handler = new 
DataRequestHandler();
                        // Client Netty
-                       
b.group(workerGroup).channel(NioSocketChannel.class).handler(new 
ChannelInitializer<SocketChannel>() {
-                               @Override
-                               protected void initChannel(SocketChannel ch) 
throws Exception {
-                                       final ChannelPipeline cp = 
ch.pipeline();
-                                       
if(ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION))
-                                               
cp.addLast(SslConstructor().context.newHandler(ch.alloc(), 
address.getAddress().getHostAddress(),
-                                                       address.getPort()));
-
-                                       final int timeout = 
ConfigurationManager.getFederatedTimeout();
-                                       if(timeout > -1)
-                                               cp.addLast("timeout", new 
ReadTimeoutHandler(timeout));
-
-                                       cp.addLast("ObjectDecoder", new 
ObjectDecoder(Integer.MAX_VALUE,
-                                               
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
-                                       cp.addLast("FederatedOperationHandler", 
handler);
-                                       cp.addLast("FederatedRequestEncoder", 
new FederatedRequestEncoder());
-                               }
-                       });
+
+                       b.handler(createChannel(address, handler));
 
                        ChannelFuture f = b.connect(address).sync();
                        Promise<FederatedResponse> promise = 
f.channel().eventLoop().newPromise();
                        handler.setPromise(promise);
                        f.channel().writeAndFlush(request);
-                       return promise;
+
+                       return handler.getProm();
                }
                catch(Exception e) {
                        throw new DMLRuntimeException("Failed sending federated 
operation", e);
                }
        }
 
+       private static ChannelInitializer<SocketChannel> 
createChannel(InetSocketAddress address, DataRequestHandler handler){
+               final int timeout = ConfigurationManager.getFederatedTimeout();
+               final boolean ssl = ConfigurationManager.isFederatedSSL();
+
+               return new ChannelInitializer<SocketChannel>() {
+                       @Override
+                       protected void initChannel(SocketChannel ch) throws 
Exception {
+                               final ChannelPipeline cp = ch.pipeline();
+                               if(ssl)
+                                       cp.addLast(createSSLHandler(ch, 
address));
+                               if(timeout > -1)
+                                       cp.addLast(new 
ReadTimeoutHandler(timeout));
+                               cp.addLast(FederationUtils.decoder(), new 
FederatedRequestEncoder(), handler);
+                       }
+               };
+       }
+
        public static void clearFederatedWorkers() {
                if(_allFedSites.isEmpty())
                        return;
@@ -223,17 +223,22 @@ public class FederatedData {
                }
        }
 
+       private static SslHandler createSSLHandler(SocketChannel ch, 
InetSocketAddress address){
+               return SslConstructor().context.newHandler(ch.alloc(), 
address.getAddress().getHostAddress(),
+                                                       address.getPort());
+       }
+
        public static void resetFederatedSites() {
                _allFedSites.clear();
        }
 
-       public static void clearWorkGroup(){
+       public static void clearWorkGroup() {
                if(workerGroup != null)
                        workerGroup.shutdownGracefully();
                workerGroup = null;
        }
 
-       public synchronized static void createWorkGroup(){
+       public synchronized static void createWorkGroup() {
                if(workerGroup == null)
                        workerGroup = new 
NioEventLoopGroup(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
        }
@@ -250,11 +255,13 @@ public class FederatedData {
 
                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) {
-                       if(_prom == null)
-                               throw new DMLRuntimeException("Read while no 
message was sent");
                        _prom.setSuccess((FederatedResponse) msg);
                        ctx.close();
                }
+
+               public Promise<FederatedResponse> getProm() {
+                       return _prom;
+               }
        }
 
        private static class SslContextMan {
@@ -271,9 +278,9 @@ public class FederatedData {
        }
 
        private static SslContextMan SslConstructor() {
-               if(sslInstance == null) 
+               if(sslInstance == null)
                        return new SslContextMan();
-               else 
+               else
                        return sslInstance;
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index 9a8fe38f6e..d090d0553c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -27,6 +27,14 @@ import java.util.concurrent.TimeUnit;
 
 import javax.net.ssl.SSLException;
 
+import org.apache.log4j.Logger;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelFuture;
@@ -37,18 +45,10 @@ import io.netty.channel.ChannelPipeline;
 import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioServerSocketChannel;
-import io.netty.handler.codec.serialization.ClassResolvers;
-import io.netty.handler.codec.serialization.ObjectDecoder;
 import io.netty.handler.codec.serialization.ObjectEncoder;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.util.SelfSignedCertificate;
-import org.apache.sysds.api.DMLScript;
-import org.apache.log4j.Logger;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 
 public class FederatedWorker {
        protected static Logger log = Logger.getLogger(FederatedWorker.class);
@@ -73,39 +73,28 @@ public class FederatedWorker {
                LineageCacheConfig.setConfig(DMLScript.LINEAGE_REUSE);
                LineageCacheConfig.setCachePolicy(DMLScript.LINEAGE_POLICY);
                LineageCacheConfig.setEstimator(DMLScript.LINEAGE_ESTIMATE);
+
+               run();
        }
 
-       public void run() throws CertificateException, SSLException {
+       private void run() {
                log.info("Setting up Federated Worker on port " + _port);
                int par_conn = 
ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.FEDERATED_PAR_CONN);
                final int EVENT_LOOP_THREADS = (par_conn > 0) ? par_conn : 
InfrastructureAnalyzer.getLocalParallelism();
                NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
-               ThreadPoolExecutor workerTPE = new ThreadPoolExecutor(1, 
Integer.MAX_VALUE,
-                       10, TimeUnit.SECONDS, new 
SynchronousQueue<Runnable>(true));
+               ThreadPoolExecutor workerTPE = new ThreadPoolExecutor(1, 
Integer.MAX_VALUE, 10, TimeUnit.SECONDS,
+                       new SynchronousQueue<Runnable>(true));
                NioEventLoopGroup workerGroup = new 
NioEventLoopGroup(EVENT_LOOP_THREADS, workerTPE);
-               ServerBootstrap b = new ServerBootstrap();
-               // TODO add ability to use real ssl files, not self signed 
certificates.
-               SelfSignedCertificate cert = new SelfSignedCertificate();
-               final SslContext cont2 = 
SslContextBuilder.forServer(cert.certificate(), cert.privateKey()).build();
 
+               final boolean ssl = ConfigurationManager.isFederatedSSL();
                try {
-                       b.group(bossGroup, 
workerGroup).channel(NioServerSocketChannel.class)
-                               .childHandler(new 
ChannelInitializer<SocketChannel>() {
-                                       @Override
-                                       public void initChannel(SocketChannel 
ch) {
-                                               ChannelPipeline cp = 
ch.pipeline();
-
-                                               
if(ConfigurationManager.getDMLConfig()
-                                                       
.getBooleanValue(DMLConfig.USE_SSL_FEDERATED_COMMUNICATION)) {
-                                                       
cp.addLast(cont2.newHandler(ch.alloc()));
-                                               }
-                                               cp.addLast("ObjectDecoder",
-                                                       new 
ObjectDecoder(Integer.MAX_VALUE,
-                                                               
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
-                                               
cp.addLast("FederatedResponseEncoder", new FederatedResponseEncoder());
-                                               
cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt, _frc, 
_fan));
-                                       }
-                               }).option(ChannelOption.SO_BACKLOG, 
128).childOption(ChannelOption.SO_KEEPALIVE, true);
+                       final ServerBootstrap b = new ServerBootstrap();
+                       b.group(bossGroup, workerGroup);
+                       b.channel(NioServerSocketChannel.class);
+                       b.childHandler(createChannel(ssl));
+                       b.option(ChannelOption.SO_BACKLOG, 128);
+                       b.childOption(ChannelOption.SO_KEEPALIVE, true);
+
                        log.info("Starting Federated Worker server at port: " + 
_port);
                        ChannelFuture f = b.bind(_port).sync();
                        log.info("Started Federated Worker at port: " + _port);
@@ -113,7 +102,7 @@ public class FederatedWorker {
                }
                catch(Exception e) {
                        log.info("Federated worker interrupted");
-                       if ( _debug ){
+                       if(_debug) {
                                log.error(e.getMessage());
                                e.printStackTrace();
                        }
@@ -127,14 +116,15 @@ public class FederatedWorker {
 
        public static class FederatedResponseEncoder extends ObjectEncoder {
                @Override
-               protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, 
Serializable msg,
-                       boolean preferDirect) throws Exception {
+               protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, 
Serializable msg, boolean preferDirect)
+                       throws Exception {
                        int initCapacity = 256; // default initial capacity
                        if(msg instanceof FederatedResponse) {
-                               FederatedResponse response = 
(FederatedResponse)msg;
+                               FederatedResponse response = 
(FederatedResponse) msg;
                                try {
                                        initCapacity = 
Math.toIntExact(response.estimateSerializationBufferSize());
-                               } catch(ArithmeticException ae) { // size of 
cache block exceeds integer limits
+                               }
+                               catch(ArithmeticException ae) { // size of 
cache block exceeds integer limits
                                        initCapacity = Integer.MAX_VALUE;
                                }
                        }
@@ -144,4 +134,26 @@ public class FederatedWorker {
                                return ctx.alloc().heapBuffer(initCapacity);
                }
        }
+
+       private ChannelInitializer<SocketChannel> createChannel(boolean ssl) {
+               try {
+                       // TODO add ability to use real ssl files, not self 
signed certificates.
+                       final SelfSignedCertificate cert = new 
SelfSignedCertificate();
+                       final SslContext cont2 = 
SslContextBuilder.forServer(cert.certificate(), cert.privateKey()).build();
+
+                       return new ChannelInitializer<SocketChannel>() {
+                               @Override
+                               public void initChannel(SocketChannel ch) {
+                                       final ChannelPipeline cp = 
ch.pipeline();
+                                       if(ssl)
+                                               
cp.addLast(cont2.newHandler(ch.alloc()));
+                                       cp.addLast(FederationUtils.decoder(), 
new FederatedResponseEncoder());
+                                       cp.addLast(new 
FederatedWorkerHandler(_flt, _frc, _fan));
+                               }
+                       };
+               }
+               catch(CertificateException | SSLException e) {
+                       throw new DMLRuntimeException("Failed creating channel 
SSL", e);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 9fbc7d7804..769dabc173 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -428,26 +428,32 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        }
 
        private FederatedResponse getVariable(FederatedRequest request, 
ExecutionContextMap ecm) {
-               checkNumParams(request.getNumParams(), 0);
-               ExecutionContext ec = ecm.get(request.getTID());
-               if(!ec.containsVariable(String.valueOf(request.getID())))
-                       throw new FederatedWorkerHandlerException(
-                               "Variable " + request.getID() + " does not 
exist at federated worker.");
+               try{
 
-               // get variable and construct response
-               Data dataObject = 
ec.getVariable(String.valueOf(request.getID()));
-               dataObject = PrivacyMonitor.handlePrivacy(dataObject);
-               switch(dataObject.getDataType()) {
-                       case TENSOR:
-                       case MATRIX:
-                       case FRAME:
-                               return new 
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) 
dataObject).acquireReadAndRelease());
-                       case LIST:
-                               return new 
FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
-                       case SCALAR:
-                               return new 
FederatedResponse(ResponseType.SUCCESS, dataObject);
-                       default:
-                               throw new 
FederatedWorkerHandlerException("Unsupported return datatype " + 
dataObject.getDataType().name());
+                       checkNumParams(request.getNumParams(), 0);
+                       ExecutionContext ec = ecm.get(request.getTID());
+                       
if(!ec.containsVariable(String.valueOf(request.getID())))
+                               throw new FederatedWorkerHandlerException(
+                                       "Variable " + request.getID() + " does 
not exist at federated worker.");
+       
+                       // get variable and construct response
+                       Data dataObject = 
ec.getVariable(String.valueOf(request.getID()));
+                       dataObject = PrivacyMonitor.handlePrivacy(dataObject);
+                       switch(dataObject.getDataType()) {
+                               case TENSOR:
+                               case MATRIX:
+                               case FRAME:
+                                       return new 
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) 
dataObject).acquireReadAndRelease());
+                               case LIST:
+                                       return new 
FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
+                               case SCALAR:
+                                       return new 
FederatedResponse(ResponseType.SUCCESS, dataObject);
+                               default:
+                                       throw new 
FederatedWorkerHandlerException("Unsupported return datatype " + 
dataObject.getDataType().name());
+                       }
+               }
+               catch(Exception e){
+                       throw new FederatedWorkerHandlerException("Failed to 
getVariable " , e);
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
index 78a2eab7e8..1db1a458be 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
@@ -32,7 +32,7 @@ import 
org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
 
 public class FederatedWorkloadAnalyzer {
-       private static final Log LOG = 
LogFactory.getLog(FederatedWorkerHandler.class.getName());
+       protected static final Log LOG = 
LogFactory.getLog(FederatedWorkerHandler.class.getName());
 
        /** Frequency value for how many instructions before we do a pass for 
compression */
        private static int compressRunFrequency = 10;
@@ -84,7 +84,6 @@ public class FederatedWorkloadAnalyzer {
                                getOrMakeCounter(mm, 
Long.parseLong(n2)).incLMM(c2);
                                counter++;
                        }
-                       LOG.error(mm + " " + Long.parseLong(n2));
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index a31d8beaf1..671cd0b744 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -56,6 +56,9 @@ import 
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 
+import io.netty.handler.codec.serialization.ClassResolvers;
+import io.netty.handler.codec.serialization.ObjectDecoder;
+
 public class FederationUtils {
        protected static Logger log = Logger.getLogger(FederationUtils.class);
        private static final IDSequence _idSeq = new IDSequence();
@@ -555,4 +558,9 @@ public class FederationUtils {
                        dataParts.add(readResponse.getValue());
                return FederationUtils.aggAdd(dataParts.toArray(new Future[0]));
        }
+
+       public static ObjectDecoder decoder() {
+               return new ObjectDecoder(Integer.MAX_VALUE,
+                       
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()));
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerBase.java 
b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerBase.java
index c7d7a0a596..1bf5d33006 100644
--- a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerBase.java
+++ b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerBase.java
@@ -50,7 +50,7 @@ public abstract class FedWorkerBase {
 
        protected static int startWorker(String confPath) {
                final int port = AutomatedTestBase.getRandomAvailablePort();
-               AutomatedTestBase.startLocalFedWorkerThread(port, new String[] 
{"-config", confPath}, 3000);
+               AutomatedTestBase.startLocalFedWorkerThread(port, new String[] 
{"-config", confPath}, 5000);
                return port;
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerMatrix.java 
b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerMatrix.java
index 7bacda79a0..f0c7ec3a30 100644
--- 
a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerMatrix.java
+++ 
b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerMatrix.java
@@ -52,6 +52,9 @@ public class FedWorkerMatrix extends FedWorkerBase {
                final MatrixBlock mb10x1000 = 
TestUtils.generateTestMatrixBlock(10, 1000, 0.5, 9.5, 1.0, 1342);
                tests.add(new Object[] {port, mb10x1000, 10});
 
+               // final MatrixBlock mb1000x1000 = 
TestUtils.generateTestMatrixBlock(1000, 1000, 0.5, 9.5, 1.0, 1342);
+               // tests.add(new Object[] {port, mb1000x1000, 300});
+
                return tests;
        }
 
@@ -81,5 +84,4 @@ public class FedWorkerMatrix extends FedWorkerBase {
                                "Not equivalent matrix block returned from 
federated site");
                }
        }
-
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerScalar.java 
b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerScalar.java
index 16afe13cb1..9709699bfc 100644
--- 
a/src/test/java/org/apache/sysds/test/component/federated/FedWorkerScalar.java
+++ 
b/src/test/java/org/apache/sysds/test/component/federated/FedWorkerScalar.java
@@ -76,5 +76,4 @@ public class FedWorkerScalar extends FedWorkerBase {
                        assertEquals("values not equivalent", vrInit, vr, 
0.0000001);
                }
        }
-
 }

Reply via email to