This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8356ea6  [SYSTEMDS-2600,2626] Fix federated backend request 
interference
8356ea6 is described below

commit 8356ea6861ed5dc5cc15430bce26cf5e4bdd7c9e
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Aug 18 19:47:16 2020 +0200

    [SYSTEMDS-2600,2626] Fix federated backend request interference
    
    This patch fixes two major issues of request interference from multiple
    coordinator threads.
    
    First, we now properly maintain separate execution contexts at the
    federated site for different request streams from parfor workers, which
    otherwise could interfere (e.g., on rmvar instructions for shared input
    variables).
    
    Second, even requests within a single stream of federated requests (e.g.,
    execute and cleanup) could overtake each other if there are no data
    dependencies or synchronization between them. We have now added barriers
    for federated requests wherever this was necessary.
    
    Lastly, this patch also suppresses unnecessary warning messages of the
    parfor optimizer, specifically in a setting with forced single-node execution.
    
    Closes #1028.
---
 .../runtime/controlprogram/ParForProgramBlock.java |  1 +
 .../controlprogram/caching/CacheableData.java      | 10 +++-
 .../controlprogram/context/ExecutionContext.java   | 11 +++-
 .../context/SparkExecutionContext.java             |  2 +-
 .../federated/ExecutionContextMap.java             | 61 ++++++++++++++++++++++
 .../controlprogram/federated/FederatedRequest.java | 13 ++++-
 .../controlprogram/federated/FederatedWorker.java  | 12 ++---
 .../federated/FederatedWorkerHandler.java          | 44 ++++++++--------
 .../controlprogram/federated/FederationMap.java    | 45 ++++++++++++----
 .../parfor/opt/CostEstimatorHops.java              |  5 +-
 .../instructions/cp/VariableCPInstruction.java     |  2 +-
 .../fed/AggregateBinaryFEDInstruction.java         | 12 ++---
 .../fed/AggregateUnaryFEDInstruction.java          |  4 +-
 .../instructions/fed/AppendFEDInstruction.java     |  2 +-
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  8 +--
 .../fed/BinaryMatrixScalarFEDInstruction.java      |  4 +-
 .../runtime/instructions/fed/FEDInstruction.java   |  9 ++++
 .../instructions/fed/FEDInstructionUtils.java      | 20 ++++---
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  2 +-
 .../instructions/fed/TsmmFEDInstruction.java       |  4 +-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  2 +-
 .../functions/federated/FederatedKmeansTest.java   |  3 +-
 22 files changed, 200 insertions(+), 76 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 9e15139..c6d7e6e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -1177,6 +1177,7 @@ public class ParForProgramBlock extends ForProgramBlock
                        
                        //deep copy execution context (including prepare parfor 
update-in-place)
                        ExecutionContext cpEc = 
ProgramConverter.createDeepCopyExecutionContext(ec);
+                       cpEc.setTID(pwID);
 
                        // If GPU mode is enabled, gets a GPUContext from the 
pool of GPUContexts
                        // and sets it in the ExecutionContext of the parfor
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index c809a84..949e60a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -629,6 +629,10 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                }
        }
        
+       public void clearData() {
+               clearData(-1);
+       }
+       
        /**
         * Sets the cache block reference to <code>null</code>, abandons the 
old block.
         * Makes the "envelope" empty.  Run it to finalize the object 
(otherwise the
@@ -637,8 +641,10 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
         * In-Status:  EMPTY, EVICTABLE, EVICTED;
         * Out-Status: EMPTY.
         * 
+        * @param tid thread ID
+        * 
         */
-       public synchronized void clearData() 
+       public synchronized void clearData(long tid) 
        {
                // check if cleanup enabled and possible 
                if( !isCleanupEnabled() ) 
@@ -669,7 +675,7 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                
                //clear federated matrix
                if( _fedMapping != null )
-                       _fedMapping.cleanup(_fedMapping.getID());
+                       _fedMapping.cleanup(tid, _fedMapping.getID());
                
                // change object state EMPTY
                setDirty(false);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index fcb5db3..a34b77e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -70,6 +70,7 @@ public class ExecutionContext {
        
        //symbol table
        protected LocalVariableMap _variables;
+       protected long _tid = -1;
        protected boolean _autoCreateVars;
        
        //lineage map, cache, prepared dedup blocks
@@ -131,6 +132,14 @@ public class ExecutionContext {
        public void setAutoCreateVars(boolean flag) {
                _autoCreateVars = flag;
        }
+       
+       public void setTID(long tid) {
+               _tid = tid;
+       }
+       
+       public long getTID() {
+               return _tid;
+       }
 
        /**
         * Get the i-th GPUContext
@@ -750,7 +759,7 @@ public class ExecutionContext {
                try {
                        //compute ref count only if matrix cleanup actually 
necessary
                        if ( mo.isCleanupEnabled() && 
!getVariables().hasReferences(mo) )  {
-                               mo.clearData(); //clean cached data
+                               mo.clearData(getTID()); //clean cached data
                                if( fileExists ) {
                                        
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
                                        
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index bb6a94d..65348f1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1350,7 +1350,7 @@ public class SparkExecutionContext extends 
ExecutionContext
                        //compute ref count only if matrix cleanup actually 
necessary
                        if( !getVariables().hasReferences(mo) ) {
                                //clean cached data
-                               mo.clearData();
+                               mo.clearData(getTID());
 
                                //clean hdfs data if no pending rdd operations 
on it
                                if( mo.isHDFSFileExists() && 
mo.getFileName()!=null ) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
new file mode 100644
index 0000000..1d06f46
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
+
+public class ExecutionContextMap {
+       private final ExecutionContext _main;
+       private final Map<Long, ExecutionContext> _parEc;
+       
+       public ExecutionContextMap() {
+               _main = createExecutionContext();
+               _parEc = new ConcurrentHashMap<>();
+       }
+       
+       public ExecutionContext get(long tid) {
+               //return main execution context
+               if( tid <= 0 )
+                       return _main;
+               
+               //atomic probe, create if necessary, and return
+               return _parEc.computeIfAbsent(tid,
+                       k -> deriveExecutionContext(_main));
+       }
+       
+       private static ExecutionContext createExecutionContext() {
+               ExecutionContext ec = ExecutionContextFactory.createContext();
+               ec.setAutoCreateVars(true); //w/o createvar inst
+               return ec;
+       }
+       
+       private static ExecutionContext deriveExecutionContext(ExecutionContext 
ec) {
+               //derive execution context from main to make shared variables 
available
+               //but allow normal instruction processing and removal if 
necessary
+               ExecutionContext ec2 = ExecutionContextFactory
+                       .createContext(ec.getVariables(), ec.getProgram());
+               ec2.setAutoCreateVars(true); //w/o createvar inst
+               return ec2;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 5618d36..d62e6f6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -41,6 +41,7 @@ public class FederatedRequest implements Serializable {
        
        private RequestType _method;
        private long _id;
+       private long _tid;
        private List<Object> _data;
        private boolean _checkPrivacy;
        
@@ -73,6 +74,14 @@ public class FederatedRequest implements Serializable {
                return _id;
        }
        
+       public long getTID() {
+               return _tid;
+       }
+       
+       public void setTID(long tid) {
+               _tid = tid;
+       }
+       
        public Object getParam(int i) {
                return _data.get(i);
        }
@@ -112,7 +121,9 @@ public class FederatedRequest implements Serializable {
                StringBuilder sb = new StringBuilder("FederatedRequest[");
                sb.append(_method); sb.append(";");
                sb.append(_id); sb.append(";");
-               sb.append(Arrays.toString(_data.toArray())); sb.append("]");
+               sb.append("t"); sb.append(_tid); sb.append(";");
+               if( _method != RequestType.PUT_VAR )
+                       sb.append(Arrays.toString(_data.toArray())); 
sb.append("]");
                return sb.toString();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index 1eca3a9..dae75e4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -32,21 +32,15 @@ import io.netty.handler.codec.serialization.ObjectDecoder;
 import io.netty.handler.codec.serialization.ObjectEncoder;
 import org.apache.log4j.Logger;
 import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 
 public class FederatedWorker {
        protected static Logger log = Logger.getLogger(FederatedWorker.class);
 
        private int _port;
-       private final ExecutionContext _ec;
-       private final BasicProgramBlock _pb;
+       private final ExecutionContextMap _ecm;
        
        public FederatedWorker(int port) {
-               _ec = ExecutionContextFactory.createContext();
-               _ec.setAutoCreateVars(true); //w/o createvar inst
-               _pb = new BasicProgramBlock(null);
+               _ecm = new ExecutionContextMap();
                _port = (port == -1) ?
                        Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT) : 
port;
        }
@@ -65,7 +59,7 @@ public class FederatedWorker {
                                                        new 
ObjectDecoder(Integer.MAX_VALUE,
                                                                
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())))
                                                .addLast("ObjectEncoder", new 
ObjectEncoder())
-                                               
.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_ec, _pb));
+                                               
.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_ecm));
                                }
                        }).option(ChannelOption.SO_BACKLOG, 
128).childOption(ChannelOption.SO_KEEPALIVE, true);
                try {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 1afbfb1..00a8685 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -62,12 +62,13 @@ import java.util.Arrays;
 public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
        protected static Logger log = 
Logger.getLogger(FederatedWorkerHandler.class);
 
-       private final ExecutionContext _ec;
-       private final BasicProgramBlock _pb;
+       private final ExecutionContextMap _ecm;
        
-       public FederatedWorkerHandler(ExecutionContext ec, BasicProgramBlock 
pb) {
-               _ec = ec;
-               _pb = pb;
+       public FederatedWorkerHandler(ExecutionContextMap ecm) {
+               //Note: federated worker handler created for every command;
+               //and concurrent parfor threads at coordinator need separate
+               //execution contexts at the federated sites too
+               _ecm = ecm;
        }
 
        @Override
@@ -131,10 +132,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                checkNumParams(request.getNumParams(), 2);
                String filename = (String) request.getParam(0);
                DataType dt = DataType.valueOf((String)request.getParam(1));
-               return readData(filename, dt, request.getID());
+               return readData(filename, dt, request.getID(), 
request.getTID());
        }
 
-       private FederatedResponse readData(String filename, Types.DataType 
dataType, long id) {
+       private FederatedResponse readData(String filename, Types.DataType 
dataType, long id, long tid) {
                MatrixCharacteristics mc = new MatrixCharacteristics();
                mc.setBlocksize(ConfigurationManager.getBlocksize());
                CacheableData<?> cd;
@@ -180,7 +181,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                cd.release();
                
                //TODO spawn async load of data, otherwise on first access
-               _ec.setVariable(String.valueOf(id), cd);
+               _ecm.get(tid).setVariable(String.valueOf(id), cd);
                cd.enableCleanup(false); //guard against deletion
                
                if (dataType == Types.DataType.FRAME) {
@@ -193,7 +194,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        private FederatedResponse putVariable(FederatedRequest request) {
                checkNumParams(request.getNumParams(), 1);
                String varname = String.valueOf(request.getID());
-               if( _ec.containsVariable(varname) ) {
+               ExecutionContext ec = _ecm.get(request.getTID());
+               if( ec.containsVariable(varname) ) {
                        return new FederatedResponse(ResponseType.ERROR,
                                "Variable "+request.getID()+" already 
existing.");
                }
@@ -206,22 +208,19 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        data = (ScalarObject) request.getParam(0);
                
                //set variable and construct empty response
-               _ec.setVariable(varname, data);
+               ec.setVariable(varname, data);
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
        
        private FederatedResponse getVariable(FederatedRequest request) {
                checkNumParams(request.getNumParams(), 0);
-               if( !_ec.containsVariable(String.valueOf(request.getID())) ) {
+               ExecutionContext ec = _ecm.get(request.getTID());
+               if( !ec.containsVariable(String.valueOf(request.getID())) ) {
                        return new FederatedResponse(ResponseType.ERROR,
                                "Variable "+request.getID()+" does not exist at 
federated worker.");
                }
                //get variable and construct response
-               return getVariableData(request.getID());
-       }
-       
-       private FederatedResponse getVariableData(long varID) {
-               Data dataObject = _ec.getVariable(String.valueOf(varID));
+               Data dataObject = 
ec.getVariable(String.valueOf(request.getID()));
                dataObject = PrivacyMonitor.handlePrivacy(dataObject);
                switch (dataObject.getDataType()) {
                        case TENSOR:
@@ -240,11 +239,13 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        }
        
        private FederatedResponse execInstruction(FederatedRequest request) {
-               _pb.getInstructions().clear();
-               _pb.getInstructions().add(InstructionParser
+               ExecutionContext ec = _ecm.get(request.getTID());
+               BasicProgramBlock pb = new BasicProgramBlock(null);
+               pb.getInstructions().clear();
+               pb.getInstructions().add(InstructionParser
                        .parseSingleInstruction((String)request.getParam(0)));
                try {
-                       _pb.execute(_ec); //execute single instruction
+                       pb.execute(ec); //execute single instruction
                }
                catch(Exception ex) {
                        return new FederatedResponse(ResponseType.ERROR, 
ex.getMessage());
@@ -254,16 +255,17 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        
        private FederatedResponse execUDF(FederatedRequest request) {
                checkNumParams(request.getNumParams(), 1);
+               ExecutionContext ec = _ecm.get(request.getTID());
                
                //get function and input parameters
                FederatedUDF udf = (FederatedUDF) request.getParam(0);
                Data[] inputs = Arrays.stream(udf.getInputIDs())
-                       .mapToObj(id -> _ec.getVariable(String.valueOf(id)))
+                       .mapToObj(id -> ec.getVariable(String.valueOf(id)))
                        .toArray(Data[]::new);
                
                //execute user-defined function
                try {
-                       return udf.execute(_ec, inputs);
+                       return udf.execute(ec, inputs);
                }
                catch(Exception ex) {
                        return new FederatedResponse(ResponseType.ERROR, 
ex.getMessage());
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index d323bad..371c3ff 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.controlprogram.federated;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -95,22 +96,37 @@ public class FederationMap
                return ret.toArray(new FederatedRequest[0]);
        }
        
-       @SuppressWarnings("unchecked")
-       public Future<FederatedResponse>[] execute(FederatedRequest... fr) {
-               List<Future<FederatedResponse>> ret = new ArrayList<>();
-               for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-                       ret.add(e.getValue().executeFederatedOperation(fr));
-               return ret.toArray(new Future[0]);
+       public Future<FederatedResponse>[] execute(long tid, 
FederatedRequest... fr) {
+               return execute(tid, false, fr);
+       }
+       
+       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest... fr) {
+               return execute(tid, wait, null, fr);
+       }
+       
+       public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] 
frSlices, FederatedRequest... fr) {
+               return execute(tid, false, frSlices, fr);
        }
        
        @SuppressWarnings("unchecked")
-       public Future<FederatedResponse>[] execute(FederatedRequest[] frSlices, 
FederatedRequest... fr) {
-               //executes step1[] - step 2 - ... step4 (only first step 
federated-data-specific)
+       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest[] frSlices, FederatedRequest... fr) {
+               // executes step1[] - step 2 - ... step4 (only first step 
federated-data-specific)
+               setThreadID(tid, frSlices, fr);
                List<Future<FederatedResponse>> ret = new ArrayList<>(); 
                int pos = 0;
                for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-                       
ret.add(e.getValue().executeFederatedOperation(addAll(frSlices[pos++], fr)));
-               return ret.toArray(new Future[0]);
+                       ret.add(e.getValue().executeFederatedOperation(
+                               (frSlices!=null) ? addAll(frSlices[pos++], fr) 
: fr));
+               
+               // prepare results (future federated responses), with optional 
wait to ensure the 
+               // order of requests without data dependencies (e.g., cleanup 
RPCs)
+               Future<FederatedResponse>[] ret2 = ret.toArray(new Future[0]);
+               if( wait ) {
+                       Arrays.stream(ret2).forEach(e -> {
+                               try {e.get();} catch(Exception ex) {throw new 
DMLRuntimeException(ex);}
+                       });
+               }
+               return ret2;
        }
        
        public List<Pair<FederatedRange, Future<FederatedResponse>>> 
requestFederatedData() {
@@ -125,9 +141,10 @@ public class FederationMap
                return readResponses;
        }
        
-       public void cleanup(long... id) {
+       public void cleanup(long tid, long... id) {
                FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
                        
VariableCPInstruction.prepareRemoveInstruction(id).toString());
+               request.setTID(tid);
                for(FederatedData fd : _fedMap.values())
                        fd.executeFederatedOperation(request);
        }
@@ -204,6 +221,12 @@ public class FederationMap
                fedMapCopy._ID = newVarID;
                return fedMapCopy;
        }
+       
+       private static void setThreadID(long tid, FederatedRequest[]... frsets) 
{
+               for( FederatedRequest[] frset : frsets )
+                       if( frset != null )
+                               Arrays.stream(frset).forEach(fr -> 
fr.setTID(tid));
+       }
 
        private static class MappingTask implements Callable<Void> {
                private final FederatedRange _range;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
index a881d37..eb70d0c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
@@ -58,6 +58,7 @@ public class CostEstimatorHops extends CostEstimator
                
                //handle specific cases 
                double DEFAULT_MEM_REMOTE = 
OptimizerUtils.isSparkExecutionMode() ? DEFAULT_MEM_SP : 0;
+               boolean forcedExec =  DMLScript.getGlobalExecMode() == 
ExecMode.SINGLE_NODE || h.getForcedExecType()!=null;
                
                if( value >= DEFAULT_MEM_REMOTE )
                {
@@ -67,7 +68,7 @@ public class CostEstimatorHops extends CostEstimator
                        }
                        //check for invalid cp memory estimate
                        else if ( h.getExecType()==ExecType.CP && value >= 
OptimizerUtils.getLocalMemBudget() ) {
-                               if( DMLScript.getGlobalExecMode() != 
ExecMode.SINGLE_NODE && h.getForcedExecType()==null )
+                               if( !forcedExec )
                                        LOG.warn("Memory estimate larger than 
budget but CP exec type (op="+h.getOpString()+", name="+h.getName()+", 
memest="+h.getMemEstimate()+").");
                                value = DEFAULT_MEM_REMOTE;
                        }
@@ -84,7 +85,7 @@ public class CostEstimatorHops extends CostEstimator
                        value = DEFAULT_MEM_REMOTE;
                }
                
-               if( value <= 0 ) { //no mem estimate
+               if( value <= 0 && !forcedExec ) { //no mem estimate
                        LOG.warn("Cannot get memory estimate for hop 
(op="+h.getOpString()+", name="+h.getName()+", 
memest="+h.getMemEstimate()+").");
                        value = CostEstimator.DEFAULT_MEM_ESTIMATE_CP;
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 96cb4c6..c2a05da 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -741,7 +741,7 @@ public class VariableCPInstruction extends CPInstruction 
implements LineageTrace
                        // no other variable in the symbol table points to the 
same Data object as that of input1.getName()
                        
                        //remove matrix object from cache
-                       m.clearData();
+                       m.clearData(ec.getTID());
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 14f81bf..6fd6173 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -69,15 +69,15 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                        if( mo2.getNumColumns() == 1 ) { //MV
                                FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
                                //execute federated operations and aggregate
-                               Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(fr1, fr2, fr3);
+                               Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
                                MatrixBlock ret = FederationUtils.rbind(tmp);
-                               mo1.getFedMapping().cleanup(fr1.getID(), 
fr2.getID());
+                               mo1.getFedMapping().cleanup(getTID(), 
fr1.getID(), fr2.getID());
                                ec.setMatrixOutput(output.getName(), ret);
                        }
                        else { //MM
                                //execute federated operations and aggregate
-                               mo1.getFedMapping().execute(fr1, fr2);
-                               mo1.getFedMapping().cleanup(fr1.getID());
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
+                               mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
                                MatrixObject out = ec.getMatrixObject(output);
                                
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), 
(int)mo1.getBlocksize());
                                
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), 
mo2.getNumColumns()));
@@ -91,9 +91,9 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                                new CPOperand[]{input1, input2}, new 
long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo2.getFedMapping().execute(fr1, fr2, fr3);
+                       Future<FederatedResponse>[] tmp = 
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo2.getFedMapping().cleanup(fr1[0].getID(), 
fr2.getID());
+                       mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), 
fr2.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else { //other combinations
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index a9b655b..e87bf57 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -63,11 +63,11 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                
                //execute federated commands and cleanups
                FederationMap map = in.getFedMapping();
-               Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
-               map.cleanup(fr1.getID());
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
                if( output.isScalar() )
                        ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
                else
                        ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
+               map.cleanup(getTID(), fr1.getID());
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 8fed7f7..985d117 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -80,7 +80,7 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
                        FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
                        FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
-                       mo1.getFedMapping().execute(fr1, fr2);
+                       mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
                        //derive new fed mapping for output
                        MatrixObject out = ec.getMatrixObject(output);
                        out.getDataCharacteristics().set(dc1.getRows(), dc1.getCols()+dc2.getCols(),
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 7813f6a..7166373 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -51,16 +51,16 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
                        fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
                                new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
                        //execute federated instruction and cleanup intermediates
-                       mo1.getFedMapping().execute(fr1, fr2);
-                       mo1.getFedMapping().cleanup(fr1[0].getID());
+                       mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+                       mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
                }
                else { //MM or MV col vector
                        FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
                        fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
                                new long[]{mo1.getFedMapping().getID(), fr1.getID()});
                        //execute federated instruction and cleanup intermediates
-                       mo1.getFedMapping().execute(fr1, fr2);
-                       mo1.getFedMapping().cleanup(fr1.getID());
+                       mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+                       mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                }
                
                //derive new fed mapping for output
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 0e05ca8..75bfe33 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -46,10 +46,10 @@ public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
                        new CPOperand[]{matrix, (fr1 != null)?scalar:null},
                        new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1});
                
-               mo.getFedMapping().execute((fr1!=null) ?
+               mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
                        new FederatedRequest[]{fr1, fr2}: new FederatedRequest[]{fr2});
                if( fr1 != null )
-                       mo.getFedMapping().cleanup(fr1.getID());
+                       mo.getFedMapping().cleanup(getTID(), fr1.getID());
                
                //derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 9e58e52..6df1b1e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -39,6 +39,7 @@ public abstract class FEDInstruction extends Instruction {
        
        protected final FEDType _fedType;
        protected final Operator _optr;
+       protected long _tid = -1; //main
        
        protected FEDInstruction(FEDType type, String opcode, String istr) {
                this(type, null, opcode, istr);
@@ -60,6 +61,14 @@ public abstract class FEDInstruction extends Instruction {
                return _fedType;
        }
        
+       public long getTID() {
+               return _tid;
+       }
+       
+       public void setTID(long tid) {
+               _tid = tid;
+       }
+       
        @Override
        public Instruction preprocessInstruction(ExecutionContext ec) {
                Instruction tmp = super.preprocessInstruction(ec);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 0a5a2a2..4325456 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -34,13 +34,14 @@ public class FEDInstructionUtils {
        // counterpart, since we do not propagate the information that a matrix is federated, therefore we can not decide
        // to choose a federated instruction earlier.
        public static Instruction checkAndReplaceCP(Instruction inst, ExecutionContext ec) {
+               FEDInstruction fedinst = null;
                if (inst instanceof AggregateBinaryCPInstruction) {
                        AggregateBinaryCPInstruction instruction = (AggregateBinaryCPInstruction) inst;
                        if( instruction.input1.isMatrix() && instruction.input2.isMatrix() ) {
                                MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
                                MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
                                if (mo1.isFederated() || mo2.isFederated()) {
-                                       return AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                                       fedinst = AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                                }
                        }
                }
@@ -49,7 +50,7 @@ public class FEDInstructionUtils {
                        if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
                                MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
                                if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
-                                       return AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                                       fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
                        }
                }
                else if (inst instanceof BinaryCPInstruction) {
@@ -57,13 +58,13 @@ public class FEDInstructionUtils {
                        if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated())
                                || (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
                                if(!instruction.getOpcode().equals("append")) //TODO support rbind/cbind
-                                       return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                                       fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                        }
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
                        ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction)inst;
                        if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
-                               return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
+                               fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
                        }
                }
                else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
@@ -71,7 +72,7 @@ public class FEDInstructionUtils {
                        if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
                                CacheableData<?> fo = ec.getCacheableData(minst.input1);
                                if(fo.isFederated()) {
-                                       return MultiReturnParameterizedBuiltinFEDInstruction
+                                       fedinst = MultiReturnParameterizedBuiltinFEDInstruction
                                                .parseInstruction(minst.getInstructionString());
                                }
                        }
@@ -80,8 +81,15 @@ public class FEDInstructionUtils {
                        MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
                        MatrixObject mo = ec.getMatrixObject(linst.input1);
                        if( mo.isFederated() )
-                               return TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
+                               fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
                }
+               
+               //set thread id for federated context management
+               if( fedinst != null ) {
+                       fedinst.setTID(ec.getTID());
+                       return fedinst;
+               }
+               
                return inst;
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 3a5ff8a..ec28965 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -99,7 +99,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                        MatrixObject mo = getTarget(ec);
                        FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
-                       mo.getFedMapping().execute(fr1);
+                       mo.getFedMapping().execute(getTID(), true, fr1);
                        
                        //derive new fed mapping for output
                        MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index a3061ed..292bced 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -69,9 +69,9 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
                        FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
                        
                        //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2);
+                       Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
-                       mo1.getFedMapping().cleanup(fr1.getID());
+                       mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                        ec.setMatrixOutput(output.getName(), ret);
                }
                else { //other combinations
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 7e63127..b40a637 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -100,7 +100,7 @@ public abstract class AutomatedTestBase {
        public static final boolean TEST_GPU = false;
        public static final double GPU_TOLERANCE = 1e-9;
 
-       public static final int FED_WORKER_WAIT = 500; // in ms
+       public static final int FED_WORKER_WAIT = 750; // in ms
 
        // With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
        // to set the native library paths internally thus breaking the code.
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index a216fb3..6991797 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -62,8 +62,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
                        {10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
-                       //TODO support for multi-threaded federated interactions
-                       //{10000, 10, 16}, {2000, 50, 16}, {1000, 100, 16}, //concurrent requests
+                       {10000, 10, 4}, {2000, 50, 4}, {1000, 100, 4}, //concurrent requests
                });
        }
 

Reply via email to