This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 46cdf83  [SYSTEMDS-2629,2630] Extended federated backend (r', perf, correctness)
46cdf83 is described below

commit 46cdf8374cac8122eca3368776945a3324beb1cb
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Aug 20 00:59:49 2020 +0200

    [SYSTEMDS-2629,2630] Extended federated backend (r', perf, correctness)
    
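    This patch adds a federated transpose instruction (r'), support for
    aligned federated-federated matrix multiplication, and now explicitly
    checks and maintains the partitioning scheme (row/column) of federated
    matrices. For performance, federated workers now defer reading matrix
    data until the first operation (frames are still read to obtain their
    schema), and the default number of federated worker threads increases
    from 1 to 2. For correctness, variable cleanup now waits for all
    federated responses before continuing, and the number of non-zeros of a
    newly initialized federated matrix is set to unknown instead of assuming
    a dense matrix.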
---
 src/main/java/org/apache/sysds/conf/DMLConfig.java |  4 +-
 .../controlprogram/caching/CacheableData.java      |  5 ++
 .../controlprogram/federated/FederatedData.java    |  7 +-
 .../controlprogram/federated/FederatedRange.java   | 10 +++
 .../controlprogram/federated/FederatedWorker.java  |  3 +-
 .../federated/FederatedWorkerHandler.java          | 33 +++++----
 .../controlprogram/federated/FederationMap.java    | 78 +++++++++++++++++++---
 .../fed/AggregateBinaryFEDInstruction.java         | 21 +++++-
 .../instructions/fed/AppendFEDInstruction.java     |  5 +-
 .../runtime/instructions/fed/FEDInstruction.java   |  1 +
 .../instructions/fed/FEDInstructionUtils.java      |  9 ++-
 .../instructions/fed/InitFEDInstruction.java       | 24 ++++---
 .../instructions/fed/ReorgFEDInstruction.java      | 69 +++++++++++++++++++
 .../federated/FederatedUrlParserTest.java          |  2 +-
 .../federated/FederatedConstructionTest.java       |  1 +
 .../functions/federated/FederatedKmeansTest.java   |  5 +-
 .../functions/federated/FederatedL2SVMTest.java    |  3 +-
 .../test/functions/federated/FederatedPCATest.java |  3 +-
 .../functions/federated/FederatedKmeansTest.dml    |  2 +-
 .../federated/FederatedKmeansTestReference.dml     |  2 +-
 20 files changed, 233 insertions(+), 54 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java 
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 653d7eb..74d4457 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -89,8 +89,8 @@ public class DMLConfig
        public static final String PRINT_GPU_MEMORY_INFO = 
"sysds.gpu.print.memoryInfo";
        public static final String EVICTION_SHADOW_BUFFERSIZE = 
"sysds.gpu.eviction.shadow.bufferSize";
 
-       public static final String DEFAULT_FEDERATED_PORT = "4040"; // borrowed 
default Spark Port
-       public static final String DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 
"1";
+       public static final int DEFAULT_FEDERATED_PORT = 4040; // borrowed 
default Spark Port
+       public static final int DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 2;
        
        //internal config
        public static final String DEFAULT_SHARED_DIR_PERMISSION = "777"; //for 
local fs and DFS
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 949e60a..f7e893f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -33,6 +33,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.instructions.cp.Data;
@@ -334,6 +335,10 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                return _fedMapping != null;
        }
        
+       public boolean isFederated(FType type) {
+               return isFederated() && _fedMapping.getType() == type;
+       }
+       
        /**
         * Gets the mapping of indices ranges to federated objects.
         * @return fedMapping mapping
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 2c5f902..296e6f2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -49,7 +49,7 @@ public class FederatedData {
         * The ID of default matrix/tensor on which operations get executed if 
no other ID is given.
         */
        private long _varID = -1; // -1 is never valid since varIDs start at 0
-       private int _nrThreads = 
Integer.parseInt(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
+       private int _nrThreads = 
DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS;
 
 
        public FederatedData(Types.DataType dataType, InetSocketAddress 
address, String filepath) {
@@ -88,6 +88,11 @@ public class FederatedData {
                return _varID != -1;
        }
        
+       boolean equalAddress(FederatedData that) {
+               return _address != null && that != null && that._address != 
null 
+                       && _address.equals(that._address);
+       }
+       
        public synchronized Future<FederatedResponse> initFederatedData(long 
id) {
                if(isInitialized())
                        throw new DMLRuntimeException("Tried to init already 
initialized data");
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 46ebce2..23d0269 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -109,4 +109,14 @@ public class FederatedRange implements 
Comparable<FederatedRange> {
                _endDims[1] += cshift;
                return this;
        }
+       
+       public FederatedRange transpose() {
+               long tmpBeg = _beginDims[0];
+               long tmpEnd = _endDims[0];
+               _beginDims[0] = _beginDims[1];
+               _endDims[0] = _endDims[1];
+               _beginDims[1] = tmpBeg;
+               _endDims[1] = tmpEnd;
+               return this;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index dae75e4..c51254b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -41,8 +41,7 @@ public class FederatedWorker {
        
        public FederatedWorker(int port) {
                _ecm = new ExecutionContextMap();
-               _port = (port == -1) ?
-                       Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT) : 
port;
+               _port = (port == -1) ? DMLConfig.DEFAULT_FEDERATED_PORT : port;
        }
 
        public void run() {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 00a8685..bb64acc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -45,7 +45,6 @@ import 
org.apache.sysds.runtime.instructions.InstructionParser;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
@@ -115,14 +114,15 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                        return execUDF(request);
                                default:
                                        String message = String.format("Method 
%s is not supported.", method);
-                                       return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR, new 
FederatedWorkerHandlerException(message));
+                                       return new 
FederatedResponse(ResponseType.ERROR,
+                                               new 
FederatedWorkerHandlerException(message));
                        }
                }
                catch (DMLPrivacyException | FederatedWorkerHandlerException 
ex) {
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR, ex);
+                       return new FederatedResponse(ResponseType.ERROR, ex);
                }
                catch (Exception ex) {
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR,
+                       return new FederatedResponse(ResponseType.ERROR,
                                new FederatedWorkerHandlerException("Exception 
of type "
                                + ex.getClass() + " thrown when processing 
request", ex));
                }
@@ -148,7 +148,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                break;
                        default:
                                // should NEVER happen (if we keep request 
codes in sync with actual behaviour)
-                               return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR,
+                               return new FederatedResponse(ResponseType.ERROR,
                                        new 
FederatedWorkerHandlerException("Could not recognize datatype"));
                }
                
@@ -161,7 +161,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                try (BufferedReader br = new BufferedReader(new 
InputStreamReader(fs.open(path)))) {
                                        JSONObject mtd = JSONHelper.parse(br);
                                        if (mtd == null)
-                                               return new 
FederatedResponse(FederatedResponse.ResponseType.ERROR, new 
FederatedWorkerHandlerException("Could not parse metadata file"));
+                                               return new 
FederatedResponse(ResponseType.ERROR,
+                                                       new 
FederatedWorkerHandlerException("Could not parse metadata file"));
                                        
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
                                        
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
                                        cd = (CacheableData<?>) 
PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
@@ -172,23 +173,21 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                catch (Exception ex) {
                        throw new DMLRuntimeException(ex);
                }
-               cd.setMetaData(new MetaDataFormat(mc, fmt));
-               // TODO send FileFormatProperties with request and use them for 
CSV, this is currently a workaround so reading
-               //  of CSV files works
-               cd.setFileFormatProperties(new FileFormatPropertiesCSV());
-               cd.acquireRead();
-               cd.refreshMetaData(); //in pinned state
-               cd.release();
                
-               //TODO spawn async load of data, otherwise on first access
-               _ecm.get(tid).setVariable(String.valueOf(id), cd);
+               //put meta data object in symbol table, read on first operation
+               cd.setMetaData(new MetaDataFormat(mc, fmt));
                cd.enableCleanup(false); //guard against deletion
+               _ecm.get(tid).setVariable(String.valueOf(id), cd);
                
                if (dataType == Types.DataType.FRAME) {
                        FrameObject frameObject = (FrameObject) cd;
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {id, 
frameObject.getSchema()});
+                       frameObject.acquireRead();
+                       frameObject.refreshMetaData(); //get block schema
+                       frameObject.release();
+                       return new FederatedResponse(ResponseType.SUCCESS,
+                               new Object[] {id, frameObject.getSchema()});
                }
-               return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, id);
+               return new FederatedResponse(ResponseType.SUCCESS, id);
        }
        
        private FederatedResponse putVariable(FederatedRequest request) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 371c3ff..b25f8b9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -43,26 +43,46 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
 
 public class FederationMap
 {
+       public enum FType {
+               ROW, //row partitioned, groups of rows
+               COL, //column partitioned, groups of columns
+               OTHER,
+       }
+       
        private long _ID = -1;
        private final Map<FederatedRange, FederatedData> _fedMap;
+       private FType _type;
        
        public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
                this(-1, fedMap);
        }
        
        public FederationMap(long ID, Map<FederatedRange, FederatedData> 
fedMap) {
+               this(ID, fedMap, FType.OTHER);
+       }
+       
+       public FederationMap(long ID, Map<FederatedRange, FederatedData> 
fedMap, FType type) {
                _ID = ID;
                _fedMap = fedMap;
+               _type = type;
        }
        
        public long getID() {
                return _ID;
        }
        
+       public FType getType() {
+               return _type;
+       }
+       
        public boolean isInitialized() {
                return _ID >= 0;
        }
        
+       public void setType(FType type) {
+               _type = type;
+       }
+       
        public FederatedRange[] getFederatedRanges() {
                return _fedMap.keySet().toArray(new FederatedRange[0]);
        }
@@ -96,6 +116,19 @@ public class FederationMap
                return ret.toArray(new FederatedRequest[0]);
        }
        
+       public boolean isAligned(FederationMap that, boolean transposed) {
+               //determines if the two federated data are aligned row/column 
partitions
+               //at the same federated site (which allows for purely federated 
operation)
+               boolean ret = true;
+               for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
+                       FederatedRange range = !transposed ? e.getKey() :
+                               new FederatedRange(e.getKey()).transpose();
+                       FederatedData dat2 = that._fedMap.get(range);
+                       ret &= e.getValue().equalAddress(dat2);
+               }
+               return ret;
+       }
+       
        public Future<FederatedResponse>[] execute(long tid, 
FederatedRequest... fr) {
                return execute(tid, false, fr);
        }
@@ -120,13 +153,9 @@ public class FederationMap
                
                // prepare results (future federated responses), with optional 
wait to ensure the 
                // order of requests without data dependencies (e.g., cleanup 
RPCs)
-               Future<FederatedResponse>[] ret2 = ret.toArray(new Future[0]);
-               if( wait ) {
-                       Arrays.stream(ret2).forEach(e -> {
-                               try {e.get();} catch(Exception ex) {throw new 
DMLRuntimeException(ex);}
-                       });
-               }
-               return ret2;
+               if( wait )
+                       waitFor(ret);
+               return ret.toArray(new Future[0]);
        }
        
        public List<Pair<FederatedRange, Future<FederatedResponse>>> 
requestFederatedData() {
@@ -145,8 +174,21 @@ public class FederationMap
                FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
                        
VariableCPInstruction.prepareRemoveInstruction(id).toString());
                request.setTID(tid);
+               List<Future<FederatedResponse>> tmp = new ArrayList<>();
                for(FederatedData fd : _fedMap.values())
-                       fd.executeFederatedOperation(request);
+                       tmp.add(fd.executeFederatedOperation(request));
+               //wait to avoid interference w/ following requests
+               waitFor(tmp);
+       }
+       
+       private static void waitFor(List<Future<FederatedResponse>> responses) {
+               try {
+                       for(Future<FederatedResponse> fr : responses)
+                               fr.get();
+               }
+               catch(Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
        }
        
        private static FederatedRequest[] addAll(FederatedRequest a, 
FederatedRequest[] b) {
@@ -164,7 +206,7 @@ public class FederationMap
                //TODO handling of file path, but no danger as never written
                for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() )
                        map.put(new FederatedRange(e.getKey()), new 
FederatedData(e.getValue(), id));
-               return new FederationMap(id, map);
+               return new FederationMap(id, map, _type);
        }
        
        public FederationMap copyWithNewID(long id, long clen) {
@@ -183,6 +225,24 @@ public class FederationMap
                }
                return this;
        }
+       
+       public FederationMap transpose() {
+               Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
+               _fedMap.clear();
+               for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
+                       _fedMap.put(
+                               new FederatedRange(e.getKey()).transpose(),
+                               new FederatedData(e.getValue(), _ID));
+               }
+               //derive output type
+               switch(_type) {
+                       case ROW: _type = FType.COL; break;
+                       case COL: _type = FType.ROW; break;
+                       default: _type = FType.OTHER;
+               }
+               return this;
+       }
+
 
        /**
         * Execute a function for each <code>FederatedRange</code> + 
<code>FederatedData</code> pair. The function should
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6fd6173..34caec2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -26,6 +26,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -61,7 +62,19 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                MatrixObject mo2 = ec.getMatrixObject(input2);
                
                //#1 federated matrix-vector multiplication
-               if(mo1.isFederated()) { // MV + MM
+               if(mo1.isFederated(FType.COL) && mo2.isFederated(FType.ROW)
+                       && mo1.getFedMapping().isAligned(mo2.getFedMapping(), 
true) ) {
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2},
+                               new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID()});
+                       FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+                       //execute federated operations and aggregate
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2);
+                       MatrixBlock ret = FederationUtils.aggAdd(tmp);
+                       mo2.getFedMapping().cleanup(getTID(), fr1.getID(), 
fr2.getID());
+                       ec.setMatrixOutput(output.getName(), ret);
+               }
+               else if(mo1.isFederated(FType.ROW)) { // MV + MM
                        //construct commands: broadcast rhs, fed mv, retrieve 
results
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
@@ -81,10 +94,11 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                                MatrixObject out = ec.getMatrixObject(output);
                                
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), 
(int)mo1.getBlocksize());
                                
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), 
mo2.getNumColumns()));
+                               out.getFedMapping().setType(FType.ROW);
                        }
                }
                //#2 vector - federated matrix multiplication
-               else if (mo2.isFederated()) {// VM + MM
+               else if (mo2.isFederated(FType.ROW)) {// VM + MM
                        //construct commands: broadcast rhs, fed mv, retrieve 
results
                        FederatedRequest[] fr1 = 
mo2.getFedMapping().broadcastSliced(mo1, true);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
@@ -98,7 +112,8 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                }
                else { //other combinations
                        throw new DMLRuntimeException("Federated 
AggregateBinary not supported with the "
-                               + "following federated objects: 
"+mo1.isFederated()+" "+mo2.isFederated());
+                               + "following federated objects: 
"+mo1.isFederated()+":"+mo1.getFedMapping()
+                               +" "+mo2.isFederated()+":"+mo2.getFedMapping());
                }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 985d117..d17b7b5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -76,7 +77,7 @@ public class AppendFEDInstruction extends 
BinaryFEDInstruction {
                                + " vs " + mo2.getNumColumns());
                }
                
-               if( mo1.isFederated() && _cbind ) {
+               if( mo1.isFederated(FType.ROW) && _cbind ) {
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
@@ -87,7 +88,7 @@ public class AppendFEDInstruction extends 
BinaryFEDInstruction {
                                dc1.getBlocksize(), 
dc1.getNonZeros()+dc2.getNonZeros());
                        
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
                }
-               else if( mo1.isFederated() && mo2.isFederated() && !_cbind ) {
+               else if( mo1.isFederated(FType.ROW) && 
mo2.isFederated(FType.ROW) && !_cbind ) {
                        MatrixObject out = ec.getMatrixObject(output);
                        
out.getDataCharacteristics().set(dc1.getRows()+dc2.getRows(), dc1.getCols(),
                                dc1.getBlocksize(), 
dc1.getNonZeros()+dc2.getNonZeros());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 77dedfd..292702e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -36,6 +36,7 @@ public abstract class FEDInstruction extends Instruction {
                ParameterizedBuiltin,
                Tsmm,
                MMChain,
+               Reorg,
        }
        
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index bbdaa8e..c8cf729 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.cp.*;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
@@ -40,7 +41,7 @@ public class FEDInstructionUtils {
                        if( instruction.input1.isMatrix() && 
instruction.input2.isMatrix() ) {
                                MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
                                MatrixObject mo2 = 
ec.getMatrixObject(instruction.input2);
-                               if (mo1.isFederated() || mo2.isFederated()) {
+                               if (mo1.isFederated(FType.ROW) || 
mo2.isFederated(FType.ROW)) {
                                        fedinst = 
AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                                }
                        }
@@ -89,6 +90,12 @@ public class FEDInstructionUtils {
                                }
                        }
                }
+               else if(inst instanceof ReorgCPInstruction && 
inst.getOpcode().equals("r'")) {
+                       ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
+                       CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+                       if( mo.isFederated() )
+                               fedinst = 
ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
+               }
                
                //set thread id for federated context management
                if( fedinst != null ) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 8d050b3..9ae5014 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -33,6 +33,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
@@ -175,8 +176,8 @@ public class InitFEDInstruction extends FEDInstruction {
                        String ipRegex = 
"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$";
                        if (host.matches("^\\d+\\.\\d+\\.\\d+\\.\\d+$") && 
!host.matches(ipRegex))
                                throw new IllegalArgumentException("Input Host 
address looks like an IP address but is outside range");
-                       String port = Integer.toString(address.getPort());
-                       if (port.equals("-1"))
+                       int port = address.getPort();
+                       if (port == -1)
                                port = DMLConfig.DEFAULT_FEDERATED_PORT;
                        String filePath = address.getPath();
                        if (filePath.length() <= 1)
@@ -193,7 +194,7 @@ public class InitFEDInstruction extends FEDInstruction {
                        if (address.getRef() != null)
                                throw new IllegalArgumentException("Reference 
is not supported");
                        
-                       return new String[] { host, port, filePath };
+                       return new String[] { host, String.valueOf(port), 
filePath };
                }
                catch (MalformedURLException e) {
                        throw new IllegalArgumentException("federated address 
`" + input
@@ -208,6 +209,8 @@ public class InitFEDInstruction extends FEDInstruction {
                }
                List<Pair<FederatedData, Future<FederatedResponse>>> 
idResponses = new ArrayList<>();
                long id = FederationUtils.getNextFedDataID();
+               boolean rowPartitioned = true;
+               boolean colPartitioned = true;
                for (Map.Entry<FederatedRange, FederatedData> entry : 
fedMapping.entrySet()) {
                        FederatedRange range = entry.getKey();
                        FederatedData value = entry.getValue();
@@ -215,25 +218,24 @@ public class InitFEDInstruction extends FEDInstruction {
                                long[] beginDims = range.getBeginDims();
                                long[] endDims = range.getEndDims();
                                long[] dims = 
output.getDataCharacteristics().getDims();
-                               for (int i = 0; i < dims.length; i++) {
+                               for (int i = 0; i < dims.length; i++)
                                        dims[i] = endDims[i] - beginDims[i];
-                               }
-                               // TODO check if all matrices have the same 
DataType (currently only double is supported)
                                idResponses.add(new ImmutablePair<>(value, 
value.initFederatedData(id)));
                        }
+                       rowPartitioned &= (range.getSize(1) == 
output.getNumColumns());
+                       colPartitioned &= (range.getSize(0) == 
output.getNumRows()); 
                }
                try {
-                       for (Pair<FederatedData, Future<FederatedResponse>> 
idResponse : idResponses) {
-                               FederatedResponse response = 
idResponse.getRight().get();
-                               idResponse.getLeft().setVarID((Long) 
response.getData()[0]);
-                       }
+                       for (Pair<FederatedData, Future<FederatedResponse>> 
idResponse : idResponses)
+                               idResponse.getRight().get(); //wait for 
initialization
                }
                catch (Exception e) {
                        throw new DMLRuntimeException("Federation 
initialization failed", e);
                }
-               
output.getDataCharacteristics().setNonZeros(output.getNumColumns() * 
output.getNumRows());
+               output.getDataCharacteristics().setNonZeros(-1);
                
output.getDataCharacteristics().setBlocksize(ConfigurationManager.getBlocksize());
                output.setFedMapping(new FederationMap(id, fedMapping));
+               output.getFedMapping().setType(rowPartitioned ? FType.ROW : 
colPartitioned ? FType.COL : FType.OTHER);
        }
        
        public void federateFrame(FrameObject output, List<Pair<FederatedRange, 
FederatedData>> workers) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
new file mode 100644
index 0000000..a4b604b
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+
+public class ReorgFEDInstruction extends UnaryFEDInstruction {
+       
+       /**
+        * Creates a federated reorg instruction. No operator is passed to the
+        * parent; the opcode identifies the reorg operation.
+        *
+        * @param in1    input operand (expected to reference a federated matrix)
+        * @param out    output operand
+        * @param opcode reorg opcode (parseInstruction only accepts r')
+        * @param istr   original instruction string
+        */
+       public ReorgFEDInstruction(CPOperand in1, CPOperand out, String opcode, String istr) {
+               super(FEDType.Reorg, null, in1, out, opcode, istr);
+       }
+
+       /**
+        * Parses an instruction string into a federated reorg instruction.
+        * Currently only transpose (opcode r') is supported.
+        *
+        * @param str instruction string
+        * @return the parsed federated reorg instruction
+        * @throws DMLRuntimeException if the opcode is not r'
+        */
+       public static ReorgFEDInstruction parseInstruction ( String str ) {
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+               if ( opcode.equalsIgnoreCase("r'") ) {
+                       InstructionUtils.checkNumFields(str, 2, 3);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand out = new CPOperand(parts[2]);
+                       return new ReorgFEDInstruction(in, out, opcode, str);
+               }
+               else {
+                       throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode);
+               }
+       }
+       
+       /**
+        * Executes the reorg (transpose) at the federated site(s) of the input and
+        * registers the output as a federated matrix whose mapping is a transposed
+        * copy of the input mapping under the new request ID.
+        *
+        * @param ec execution context holding the input and output variables
+        * @throws DMLRuntimeException if the input matrix is not federated
+        */
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               
+               // report which operand was not federated; the previous message appended
+               // mo1.isFederated(), which is always false inside this branch
+               if( !mo1.isFederated() )
+                       throw new DMLRuntimeException("Federated Reorg: "
+                               + "Federated input expected, but invoked w/ non-federated input '"
+                               + input1.getName() + "'.");
+       
+               //execute transpose at federated site(s) of the input
+               FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+                       new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
+               mo1.getFedMapping().execute(getTID(), true, fr1);
+               
+               //derive output federated mapping: swapped dims, same blocksize and nnz,
+               //and a transposed copy of the input mapping bound to the request's ID
+               MatrixObject out = ec.getMatrixObject(output);
+               out.getDataCharacteristics().set(mo1.getNumColumns(),
+                       mo1.getNumRows(), (int)mo1.getBlocksize(), mo1.getNnz());
+               out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
index 3a38c13..edcc477 100644
--- 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedUrlParserTest.java
@@ -161,7 +161,7 @@ public class FederatedUrlParserTest
 
        @Test
        public void checkDefaultPortIsValid() {
-               int defaultPort = 
Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT);
+               int defaultPort = DMLConfig.DEFAULT_FEDERATED_PORT;
                // The highest port number allowed.
                int IANA_limit = 49152;
                assertTrue(defaultPort <= IANA_limit);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
index aa88027..8125bfe 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
@@ -124,6 +124,7 @@ public class FederatedConstructionTest extends 
AutomatedTestBase {
 
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
+               setOutputBuffering(false);
 
                // we need the reference file to not be written to hdfs, so we 
get the correct format
                rtplatform = Types.ExecMode.SINGLE_NODE;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index 6991797..7da40ab 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -62,7 +62,9 @@ public class FederatedKmeansTest extends AutomatedTestBase {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
                        {10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
-                       {10000, 10, 4}, {2000, 50, 4}, {1000, 100, 4}, 
//concurrent requests
+                       {10000, 10, 2}, {2000, 50, 2}, {1000, 100, 2}, 
//concurrent requests
+                       //TODO more runs e.g., 16 -> but requires rework RPC 
framework first
+                       //(e.g., see paramserv?)
                });
        }
 
@@ -127,6 +129,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
                Assert.assertTrue(heavyHittersContainsString("fed_+"));
                Assert.assertTrue(heavyHittersContainsString("fed_<="));
                Assert.assertTrue(heavyHittersContainsString("fed_/"));
+               Assert.assertTrue(heavyHittersContainsString("fed_r'"));
                
                //check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
index e55cfc9..4cfc70e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedL2SVMTest.java
@@ -103,7 +103,8 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
 
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
-
+               setOutputBuffering(false);
+               
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), expected("Z")};
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 53eac1e..906b124 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -99,10 +99,11 @@ public class FederatedPCATest extends AutomatedTestBase {
 
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
+               setOutputBuffering(false);
                
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("X1"), input("X2"),
+               programArgs = new String[] {"-stats", "-args", input("X1"), 
input("X2"),
                        String.valueOf(scaleAndShift).toUpperCase(), 
expected("Z")};
                runTest(true, false, null, -1);
 
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml 
b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
index 95f136c..13e89ea 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTest.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -21,5 +21,5 @@
 
 X = federated(addresses=list($in_X1, $in_X2),
     ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
-[C,Y] = kmeans(X=X, k=4, runs=$runs)
+[C,Y] = kmeans(X=X, k=4, runs=$runs, max_iter=150)
 write(C, $out)
diff --git 
a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml 
b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
index da32c8b..e72c9b5 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 X = rbind(read($1), read($2))
-[C,Y] = kmeans(X=X, k=4, runs=$3)
+[C,Y] = kmeans(X=X, k=4, runs=$3, max_iter=150)
 write(C, $4)

Reply via email to