This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a8cb78  [SYSTEMDS-2730] Modified fed removeEmpty
2a8cb78 is described below

commit 2a8cb78827daed00fe016f6af22ab24f154be40c
Author: Olga <[email protected]>
AuthorDate: Tue Nov 17 21:21:55 2020 +0100

    [SYSTEMDS-2730] Modified fed removeEmpty
    
    This commit changes the federated removeEmpty command to, among
    other things, improve the performance of the split function.
    
    Closes #1109
---
 scripts/builtin/split.dml                          |  13 +-
 .../controlprogram/federated/FederationMap.java    | 204 +++++++++---------
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 234 +++++++++------------
 .../primitives/FederatedRemoveEmptyTest.java       |  10 +-
 .../federated/primitives/FederatedSplitTest.java   |   8 +-
 .../functions/federated/FederatedSplitTest.dml     |   3 +-
 .../federated/FederatedSplitTestReference.dml      |   2 +-
 7 files changed, 217 insertions(+), 257 deletions(-)

diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
index 5e6f1c5..c5c1066 100644
--- a/scripts/builtin/split.dml
+++ b/scripts/builtin/split.dml
@@ -53,12 +53,13 @@ m_split = function(Matrix[Double] X, Matrix[Double] Y, 
Double f=0.7, Boolean con
   }
   # sampled train/test splits
   else {
+    # create random select vector according to f and then
+    # extract tuples via permutation (selection) matrix multiply
+    # or directly via removeEmpty by selection vector
     I = rand(rows=nrow(X), cols=1, seed=seed) <= f;
-    P1 = removeEmpty(target=diag(I), margin="rows", select=I);
-    P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
-    Xtrain = P1 %*% X;
-    Ytrain = P1 %*% Y;
-    Xtest = P2 %*% X;
-    Ytest = P2 %*% Y;
+    Xtrain = removeEmpty(target=X, margin="rows", select=I);
+    Ytrain = removeEmpty(target=Y, margin="rows", select=I);
+    Xtest = removeEmpty(target=X, margin="rows", select=(I==0));
+    Ytest = removeEmpty(target=Y, margin="rows", select=(I==0));
   }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 037ce8c..2ce3cb7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -43,12 +43,11 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.IndexRange;
 
-public class FederationMap
-{
+public class FederationMap {
        public enum FType {
-               ROW, //row partitioned, groups of rows
-               COL, //column partitioned, groups of columns
-               FULL,  // Meaning both Row and Column indicating a single 
federated location and a full matrix
+               ROW, // row partitioned, groups of rows
+               COL, // column partitioned, groups of columns
+               FULL, // Meaning both Row and Column indicating a single 
federated location and a full matrix
                OTHER;
 
                public boolean isRowPartitioned() {
@@ -56,11 +55,11 @@ public class FederationMap
                }
 
                public boolean isColPartitioned() {
-                       return this == ROW || this == FULL;
+                       return this == COL || this == FULL;
                }
 
-               public boolean isType(FType t){
-                       switch (t) {
+               public boolean isType(FType t) {
+                       switch(t) {
                                case ROW:
                                        return isRowPartitioned();
                                case COL:
@@ -72,161 +71,159 @@ public class FederationMap
                        }
                }
        }
-       
+
        private long _ID = -1;
        private final Map<FederatedRange, FederatedData> _fedMap;
        private FType _type;
-       
+
        public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
                this(-1, fedMap);
        }
-       
+
        public FederationMap(long ID, Map<FederatedRange, FederatedData> 
fedMap) {
                this(ID, fedMap, FType.OTHER);
        }
-       
+
        public FederationMap(long ID, Map<FederatedRange, FederatedData> 
fedMap, FType type) {
                _ID = ID;
                _fedMap = fedMap;
                _type = type;
        }
-       
+
        public long getID() {
                return _ID;
        }
-       
+
        public FType getType() {
                return _type;
        }
-       
+
        public boolean isInitialized() {
                return _ID >= 0;
        }
-       
+
        public void setType(FType type) {
                _type = type;
        }
-       
+
        public int getSize() {
                return _fedMap.size();
        }
-       
+
        public FederatedRange[] getFederatedRanges() {
                return _fedMap.keySet().toArray(new FederatedRange[0]);
        }
 
-       public Map<FederatedRange, FederatedData> getFedMapping(){
+       public Map<FederatedRange, FederatedData> getFedMapping() {
                return _fedMap;
        }
-       
+
        public FederatedRequest broadcast(CacheableData<?> data) {
-               //prepare single request for all federated data
+               // prepare single request for all federated data
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
                return new FederatedRequest(RequestType.PUT_VAR, id, cb);
        }
-       
+
        public FederatedRequest broadcast(ScalarObject scalar) {
-               //prepare single request for all federated data
+               // prepare single request for all federated data
                long id = FederationUtils.getNextFedDataID();
                return new FederatedRequest(RequestType.PUT_VAR, id, scalar);
        }
-       
+
        /**
-        * Creates separate slices of an input data object according
-        * to the index ranges of federated data. Theses slices are then
-        * wrapped in separate federated requests for broadcasting.
+        * Creates separate slices of an input data object according to the 
index ranges of federated data. Theses slices
+        * are then wrapped in separate federated requests for broadcasting.
         * 
-        * @param data input data object (matrix, tensor, frame)
-        * @param transposed false: slice according to federated data,
-        *                   true: slice according to transposed federated data
+        * @param data       input data object (matrix, tensor, frame)
+        * @param transposed false: slice according to federated data, true: 
slice according to transposed federated data
         * @return array of federated requests corresponding to federated data
         */
        public FederatedRequest[] broadcastSliced(CacheableData<?> data, 
boolean transposed) {
-               //prepare broadcast id and pin input
+               // prepare broadcast id and pin input
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
-               
-               //prepare indexing ranges
+
+               // prepare indexing ranges
                int[][] ix = new int[_fedMap.size()][];
                int pos = 0;
                for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
-                       int rl = transposed ? 0 : 
e.getKey().getBeginDimsInt()[0];
-                       int ru = transposed ? cb.getNumRows()-1 : 
e.getKey().getEndDimsInt()[0]-1;
-                       int cl = transposed ? e.getKey().getBeginDimsInt()[0] : 
0;
-                       int cu = transposed ? e.getKey().getEndDimsInt()[0]-1 : 
cb.getNumColumns()-1;
+                       int rl, ru, cl, cu;
+                       // TODO Handle different cases than ROW aligned 
Matrices.
+                       rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
+                       ru = transposed ? cb.getNumRows() - 1 : 
e.getKey().getEndDimsInt()[0] - 1;
+                       cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
+                       cu = transposed ? e.getKey().getEndDimsInt()[0] - 1 : 
cb.getNumColumns() - 1;
                        ix[pos++] = new int[] {rl, ru, cl, cu};
                }
-               
-               //multi-threaded block slicing and federation request creation
+
+               // multi-threaded block slicing and federation request creation
                FederatedRequest[] ret = new FederatedRequest[ix.length];
-               Arrays.parallelSetAll(ret, i ->
-                       new FederatedRequest(RequestType.PUT_VAR, id,
-                       cb.slice(ix[i][0], ix[i][1], ix[i][2], ix[i][3], new 
MatrixBlock())));
+               Arrays.parallelSetAll(ret,
+                       i -> new FederatedRequest(RequestType.PUT_VAR, id,
+                               cb.slice(ix[i][0], ix[i][1], ix[i][2], 
ix[i][3], new MatrixBlock())));
                return ret;
        }
-       
+
        public boolean isAligned(FederationMap that, boolean transposed) {
-               //determines if the two federated data are aligned row/column 
partitions
-               //at the same federated site (which allows for purely federated 
operation)
+               // determines if the two federated data are aligned row/column 
partitions
+               // at the same federated site (which allows for purely 
federated operation)
                boolean ret = true;
                for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
-                       FederatedRange range = !transposed ? e.getKey() :
-                               new FederatedRange(e.getKey()).transpose();
+                       FederatedRange range = !transposed ? e.getKey() : new 
FederatedRange(e.getKey()).transpose();
                        FederatedData dat2 = that._fedMap.get(range);
                        ret &= e.getValue().equalAddress(dat2);
                }
                return ret;
        }
-       
+
        public Future<FederatedResponse>[] execute(long tid, 
FederatedRequest... fr) {
                return execute(tid, false, fr);
        }
-       
+
        public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest... fr) {
                return execute(tid, wait, null, fr);
        }
-       
+
        public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] 
frSlices, FederatedRequest... fr) {
                return execute(tid, false, frSlices, fr);
        }
-       
+
        @SuppressWarnings("unchecked")
-       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest[] frSlices, FederatedRequest... fr) {
+       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest[] frSlices,
+               FederatedRequest... fr) {
                // executes step1[] - step 2 - ... step4 (only first step 
federated-data-specific)
                setThreadID(tid, frSlices, fr);
-               List<Future<FederatedResponse>> ret = new ArrayList<>(); 
+               List<Future<FederatedResponse>> ret = new ArrayList<>();
                int pos = 0;
                for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-                       ret.add(e.getValue().executeFederatedOperation(
-                               (frSlices!=null) ? addAll(frSlices[pos++], fr) 
: fr));
-               
-               // prepare results (future federated responses), with optional 
wait to ensure the 
+                       
ret.add(e.getValue().executeFederatedOperation((frSlices != null) ? 
addAll(frSlices[pos++], fr) : fr));
+
+               // prepare results (future federated responses), with optional 
wait to ensure the
                // order of requests without data dependencies (e.g., cleanup 
RPCs)
-               if( wait )
+               if(wait)
                        FederationUtils.waitFor(ret);
                return ret.toArray(new Future[0]);
        }
-       
+
        public List<Pair<FederatedRange, Future<FederatedResponse>>> 
requestFederatedData() {
-               if( !isInitialized() )
+               if(!isInitialized())
                        throw new DMLRuntimeException("Federated matrix read 
only supported on initialized FederatedData");
-               
+
                List<Pair<FederatedRange, Future<FederatedResponse>>> 
readResponses = new ArrayList<>();
                FederatedRequest request = new 
FederatedRequest(RequestType.GET_VAR, _ID);
                for(Map.Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet())
-                       readResponses.add(new ImmutablePair<>(e.getKey(), 
-                               
e.getValue().executeFederatedOperation(request)));
+                       readResponses.add(new ImmutablePair<>(e.getKey(), 
e.getValue().executeFederatedOperation(request)));
                return readResponses;
        }
-       
+
        public FederatedRequest cleanup(long tid, long... id) {
                FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
                        
VariableCPInstruction.prepareRemoveInstruction(id).toString());
                request.setTID(tid);
                return request;
        }
-       
+
        public void execCleanup(long tid, long... id) {
                FederatedRequest request = new 
FederatedRequest(RequestType.EXEC_INST, -1,
                        
VariableCPInstruction.prepareRemoveInstruction(id).toString());
@@ -234,16 +231,17 @@ public class FederationMap
                List<Future<FederatedResponse>> tmp = new ArrayList<>();
                for(FederatedData fd : _fedMap.values())
                        tmp.add(fd.executeFederatedOperation(request));
-               //wait to avoid interference w/ following requests
+               // wait to avoid interference w/ following requests
                FederationUtils.waitFor(tmp);
        }
-       
+
        private static FederatedRequest[] addAll(FederatedRequest a, 
FederatedRequest[] b) {
                FederatedRequest[] ret = new FederatedRequest[b.length + 1];
-               ret[0] = a; System.arraycopy(b, 0, ret, 1, b.length);
+               ret[0] = a;
+               System.arraycopy(b, 0, ret, 1, b.length);
                return ret;
        }
-       
+
        public FederationMap identCopy(long tid, long id) {
                Future<FederatedResponse>[] copyInstr = execute(tid,
                        new FederatedRequest(RequestType.EXEC_INST, _ID,
@@ -262,25 +260,25 @@ public class FederationMap
                copyFederationMap._type = _type;
                return copyFederationMap;
        }
-       
+
        public FederationMap copyWithNewID() {
                return copyWithNewID(FederationUtils.getNextFedDataID());
        }
-       
+
        public FederationMap copyWithNewID(long id) {
                Map<FederatedRange, FederatedData> map = new TreeMap<>();
-               //TODO handling of file path, but no danger as never written
-               for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() ) {
+               // TODO handling of file path, but no danger as never written
+               for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
                        if(e.getKey().getSize() != 0)
                                map.put(new FederatedRange(e.getKey()), 
e.getValue().copyWithNewID(id));
                }
                return new FederationMap(id, map, _type);
        }
-       
+
        public FederationMap copyWithNewID(long id, long clen) {
                Map<FederatedRange, FederatedData> map = new TreeMap<>();
-               //TODO handling of file path, but no danger as never written
-               for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() )
+               // TODO handling of file path, but no danger as never written
+               for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
                        map.put(new FederatedRange(e.getKey(), clen), 
e.getValue().copyWithNewID(id));
                return new FederationMap(id, map, _type);
        }
@@ -295,24 +293,28 @@ public class FederationMap
        public FederationMap transpose() {
                Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
                _fedMap.clear();
-               for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
+               for(Entry<FederatedRange, FederatedData> e : tmp.entrySet()) {
                        _fedMap.put(new FederatedRange(e.getKey()).transpose(), 
e.getValue().copyWithNewID(_ID));
                }
-               //derive output type
+               // derive output type
                switch(_type) {
-                       case FULL: _type = FType.FULL; break;
-                       case ROW: _type = FType.COL; break;
-                       case COL: _type = FType.ROW; break;
-                       default: _type = FType.OTHER;
+                       case FULL:
+                               _type = FType.FULL;
+                               break;
+                       case ROW:
+                               _type = FType.COL;
+                               break;
+                       case COL:
+                               _type = FType.ROW;
+                               break;
+                       default:
+                               _type = FType.OTHER;
                }
                return this;
        }
 
-       
        public long getMaxIndexInRange(int dim) {
-               return _fedMap.keySet().stream()
-                       .mapToLong(range -> range.getEndDims()[dim]).max()
-                       .orElse(-1L);
+               return _fedMap.keySet().stream().mapToLong(range -> 
range.getEndDims()[dim]).max().orElse(-1L);
        }
 
        /**
@@ -352,27 +354,27 @@ public class FederationMap
                fedMapCopy._ID = newVarID;
                return fedMapCopy;
        }
-       
+
        public FederationMap filter(IndexRange ixrange) {
-               FederationMap ret = this.clone(); //same ID
-               
+               FederationMap ret = this.clone(); // same ID
+
                Iterator<Entry<FederatedRange, FederatedData>> iter = 
ret._fedMap.entrySet().iterator();
-               while( iter.hasNext() ) {
+               while(iter.hasNext()) {
                        Entry<FederatedRange, FederatedData> e = iter.next();
                        FederatedRange range = e.getKey();
-                       long rs = range.getBeginDims()[0], re = 
range.getEndDims()[0],
-                               cs = range.getBeginDims()[1], ce = 
range.getEndDims()[1];
-                       boolean overlap = ((ixrange.colStart <= ce) && 
(ixrange.colEnd >= cs)
-                               && (ixrange.rowStart <= re) && (ixrange.rowEnd 
>= rs));
-                       if( !overlap )
+                       long rs = range.getBeginDims()[0], re = 
range.getEndDims()[0], cs = range.getBeginDims()[1],
+                               ce = range.getEndDims()[1];
+                       boolean overlap = ((ixrange.colStart <= ce) && 
(ixrange.colEnd >= cs) && (ixrange.rowStart <= re) &&
+                               (ixrange.rowEnd >= rs));
+                       if(!overlap)
                                iter.remove();
                }
                return ret;
        }
-       
+
        private static void setThreadID(long tid, FederatedRequest[]... frsets) 
{
-               for( FederatedRequest[] frset : frsets )
-                       if( frset != null )
+               for(FederatedRequest[] frset : frsets)
+                       if(frset != null)
                                Arrays.stream(frset).forEach(fr -> 
fr.setTID(tid));
        }
 
@@ -399,14 +401,14 @@ public class FederationMap
        }
 
        @Override
-       public String toString(){
+       public String toString() {
                StringBuilder sb = new StringBuilder();
                sb.append("Fed Map: " + _type);
                sb.append("\t ID:" + _ID);
-               sb.append("\n"+ _fedMap);
+               sb.append("\n" + _fedMap);
                return sb.toString();
        }
-       
+
        @Override
        public FederationMap clone() {
                return copyWithNewID(getID());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c50671e..6588909 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,7 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
-import java.util.AbstractMap;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -136,9 +135,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        
out.getDataCharacteristics().set(mo.getDataCharacteristics());
                        
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
-               else if(opcode.equals("rmempty")) {
+               else if(opcode.equals("rmempty"))
                        rmempty(ec);
-               }
                else if(opcode.equalsIgnoreCase("transformdecode"))
                        transformDecode(ec);
                else if(opcode.equalsIgnoreCase("transformapply"))
@@ -149,32 +147,33 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
        }
 
        private void rmempty(ExecutionContext ec) {
+               String margin = params.get("margin");
+               if( !(margin.equals("rows") || margin.equals("cols")) )
+                       throw new DMLRuntimeException("Unspupported margin 
identifier '"+margin+"'.");
+
                MatrixObject mo = (MatrixObject) getTarget(ec);
+               MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null;
                MatrixObject out = ec.getMatrixObject(output);
-               Map<FederatedRange, int[]> dcs;
-               if((instString.contains("margin=rows") && 
mo.isFederated(FederationMap.FType.ROW)) ||
-                       (instString.contains("margin=cols") && 
mo.isFederated(FederationMap.FType.COL))) {
-                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
-                               output,
-                               new CPOperand[] {getTargetOperand()},
-                               new long[] {mo.getFedMapping().getID()});
-                       mo.getFedMapping().execute(getTID(), true, fr1);
-                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
 
-                       // new ranges
-                       dcs = new HashMap<>();
-                       out.getFedMapping().forEachParallel((range, data) -> {
+               boolean marginRow = params.get("margin").equals("rows");
+               boolean k = ((marginRow && 
mo.getFedMapping().getType().isColPartitioned()) ||
+                       (!marginRow && 
mo.getFedMapping().getType().isRowPartitioned()));
+
+               MatrixBlock s = new MatrixBlock();
+               if(select == null && k) {
+                       List<MatrixBlock> colSums = new ArrayList<>();
+                       mo.getFedMapping().forEachParallel((range, data) -> {
                                try {
                                        FederatedResponse response = data
                                                .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                                       new 
GetDataCharacteristics(data.getVarID())))
+                                                       new 
GetVector(data.getVarID(), margin.equals("rows"))))
                                                .get();
 
                                        if(!response.isSuccessful())
                                                
response.throwExceptionFromResponse();
-                                       int[] subRangeCharacteristics = (int[]) 
response.getData()[0];
-                                       synchronized(dcs) {
-                                               dcs.put(range, 
subRangeCharacteristics);
+                                       MatrixBlock vector = (MatrixBlock) 
response.getData()[0];
+                                       synchronized(colSums) {
+                                               colSums.add(vector);
                                        }
                                }
                                catch(Exception e) {
@@ -182,53 +181,75 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                                }
                                return null;
                        });
+                       // find empty in matrix
+                       BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
+                       BinaryOperator greater = 
InstructionUtils.parseBinaryOperator(">");
+                       s = colSums.get(0);
+                       for(int i = 1; i < colSums.size(); i++)
+                               s = s.binaryOperationsInPlace(plus, 
colSums.get(i));
+                       s = s.binaryOperationsInPlace(greater, new 
MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
+                       select = ExecutionContext.createMatrixObject(s);
+
+                       long varID = FederationUtils.getNextFedDataID();
+                       ec.setVariable(String.valueOf(varID), select);
+                       params.put("select", String.valueOf(varID));
+                       // construct new string
+                       String[] oldString = 
InstructionUtils.getInstructionParts(instString);
+                       String[] newString = new String[oldString.length+1];
+                       newString[2] = "select="+varID;
+                       System.arraycopy(oldString, 0, newString, 0,2);
+                       System.arraycopy(oldString,2, newString, 3, 
newString.length-3);
+                       instString = 
instString.replace(InstructionUtils.concatOperands(oldString), 
InstructionUtils.concatOperands(newString));
                }
-               else {
-                       Map.Entry<FederationMap, Map<FederatedRange, int[]>> 
entry = rmemptyC(ec, mo);
-                       out.setFedMapping(entry.getKey());
-                       dcs = entry.getValue();
-               }
-               out.getDataCharacteristics().set(mo.getDataCharacteristics());
-               for(int i = 0; i < 
mo.getFedMapping().getFederatedRanges().length; i++) {
-                       int[] newRange = 
dcs.get(out.getFedMapping().getFederatedRanges()[i]);
-
-                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
-                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
-                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
-
-                       out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
-                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
-
-                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
-                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
-                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
 
-                       out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
-                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+               if (select == null) {
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[] {getTargetOperand()},
+                               new long[] {mo.getFedMapping().getID()});
+                       mo.getFedMapping().execute(getTID(), true, fr1);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
+               else if (!k) {
+                       //construct commands: broadcast , fed rmempty, clean 
broadcast
+                       FederatedRequest[] fr1 = 
mo.getFedMapping().broadcastSliced(select, !marginRow);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand(), new 
CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+                               new long[] {mo.getFedMapping().getID(), 
fr1[0].getID()});
+                       FederatedRequest fr3 = 
mo.getFedMapping().cleanup(getTID(), fr1[0].getID());
+
+                       //execute federated operations and set output
+                       mo.getFedMapping().execute(getTID(), true, fr1, fr2, 
fr3);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+               } else {
+                       //construct commands: broadcast , fed rmempty, clean 
broadcast
+                       FederatedRequest fr1 = 
mo.getFedMapping().broadcast(select);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand(), new 
CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+                               new long[] {mo.getFedMapping().getID(), 
fr1.getID()});
+                       FederatedRequest fr3 = 
mo.getFedMapping().cleanup(getTID(), fr1.getID());
 
-               
out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
-                       out.getFedMapping().getMaxIndexInRange(1),
-                       (int) mo.getBlocksize());
-       }
-
-       private Map.Entry<FederationMap, Map<FederatedRange, int[]>> 
rmemptyC(ExecutionContext ec, MatrixObject mo) {
-               boolean marginRow = instString.contains("margin=rows");
+                       //execute federated operations and set output
+                       mo.getFedMapping().execute(getTID(), true, fr1, fr2, 
fr3);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+               }
 
-               // find empty in ranges
-               List<MatrixBlock> colSums = new ArrayList<>();
-               mo.getFedMapping().forEachParallel((range, data) -> {
+               // new ranges
+               Map<FederatedRange, int[]> dcs = new HashMap<>();
+               Map<FederatedRange, int[]> finalDcs1 = dcs;
+               out.getFedMapping().forEachParallel((range, data) -> {
                        try {
                                FederatedResponse response = data
                                        .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new GetVector(data.getVarID(), 
marginRow)))
+                                               new 
GetDataCharacteristics(data.getVarID())))
                                        .get();
 
                                if(!response.isSuccessful())
                                        response.throwExceptionFromResponse();
-                               MatrixBlock vector = (MatrixBlock) 
response.getData()[0];
-                               synchronized(colSums) {
-                                       colSums.add(vector);
+                               int[] subRangeCharacteristics = (int[]) 
response.getData()[0];
+                               synchronized(finalDcs1) {
+                                       finalDcs1.put(range, 
subRangeCharacteristics);
                                }
                        }
                        catch(Exception e) {
@@ -236,46 +257,28 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        }
                        return null;
                });
+               dcs = finalDcs1;
+               out.getDataCharacteristics().set(mo.getDataCharacteristics());
+               for(int i = 0; i < 
mo.getFedMapping().getFederatedRanges().length; i++) {
+                       int[] newRange = 
dcs.get(out.getFedMapping().getFederatedRanges()[i]);
 
-               // find empty in matrix
-               BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
-               BinaryOperator greater = 
InstructionUtils.parseBinaryOperator(">");
-               MatrixBlock tmp1 = colSums.get(0);
-               for(int i = 1; i < colSums.size(); i++)
-                       tmp1 = tmp1.binaryOperationsInPlace(plus, 
colSums.get(i));
-               tmp1 = tmp1.binaryOperationsInPlace(greater, new 
MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
 
-               // remove empty from matrix
-               Map<FederatedRange, int[]> dcs = new HashMap<>();
-               long varID = FederationUtils.getNextFedDataID();
-               MatrixBlock finalTmp = new MatrixBlock(tmp1);
-               FederationMap resMapping;
-               if(tmp1.sum() == (marginRow ? tmp1.getNumColumns() : 
tmp1.getNumRows())) {
-                       resMapping = mo.getFedMapping();
-               }
-               else {
-                       resMapping = mo.getFedMapping().mapParallel(varID, 
(range, data) -> {
-                               try {
-                                       FederatedResponse response = data
-                                               .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                                       new 
ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID, finalTmp,
-                                                               
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null,
-                                                               
Boolean.parseBoolean(params.get("empty.return").toLowerCase()), marginRow)))
-                                               .get();
-                                       if(!response.isSuccessful())
-                                               
response.throwExceptionFromResponse();
-                                       int[] subRangeCharacteristics = (int[]) 
response.getData()[0];
-                                       synchronized(dcs) {
-                                               dcs.put(range, 
subRangeCharacteristics);
-                                       }
-                               }
-                               catch(Exception e) {
-                                       throw new DMLRuntimeException(e);
-                               }
-                               return null;
-                       });
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
                }
-               return new AbstractMap.SimpleEntry<>(resMapping, dcs);
+
+               
out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+                       out.getFedMapping().getMaxIndexInRange(1), (int) 
mo.getBlocksize());
        }
 
        private void transformDecode(ExecutionContext ec) {
@@ -506,52 +509,9 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                @Override
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
                        MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
-                       return new FederatedResponse(ResponseType.SUCCESS, new 
int[] {mb.getNumRows(), mb.getNumColumns()});
-               }
-       }
-
-       private static class RemoveEmpty extends FederatedUDF {
-
-               private static final long serialVersionUID = 12341521331L;
-               private final MatrixBlock _vector;
-               private final long _outputID;
-               private MatrixBlock _select;
-               private boolean _emptyReturn;
-               private final boolean _marginRow;
-
-               public RemoveEmpty(long varID, long outputID, MatrixBlock 
vector, MatrixBlock select, boolean emptyReturn,
-                       boolean marginRow) {
-                       super(new long[] {varID});
-                       _vector = vector;
-                       _outputID = outputID;
-                       _select = select;
-                       _emptyReturn = emptyReturn;
-                       _marginRow = marginRow;
-               }
-
-               @Override
-               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
-                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
-
-                       BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
-                       BinaryOperator minus = 
InstructionUtils.parseBinaryOperator("-");
-
-                       mb = mb.binaryOperationsInPlace(plus, new 
MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1.0));
-                       for(int i = 0; i < mb.getNumRows(); i++)
-                               for(int j = 0; j < mb.getNumColumns(); j++)
-                                       if(_marginRow)
-                                               mb.setValue(i, j, 
_vector.getValue(i, 0) * mb.getValue(i, j));
-                                       else
-                                               mb.setValue(i, j, 
_vector.getValue(0, j) * mb.getValue(i, j));
-
-                       MatrixBlock res = mb.removeEmptyOperations(new 
MatrixBlock(), _marginRow, _emptyReturn, _select);
-                       res = res.binaryOperationsInPlace(minus, new 
MatrixBlock(res.getNumRows(), res.getNumColumns(), 1.0));
-
-                       MatrixObject mout = 
ExecutionContext.createMatrixObject(res);
-                       ec.setVariable(String.valueOf(_outputID), mout);
-
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
-                               new int[] {res.getNumRows(), 
res.getNumColumns()});
+                       int r = mb.getDenseBlockValues() != null ? 
mb.getNumRows() : 0;
+                       int c = mb.getDenseBlockValues() != null ? 
mb.getNumColumns(): 0;
+                       return new FederatedResponse(ResponseType.SUCCESS, new 
int[] {r, c});
                }
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index a629270..10a6711 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -55,7 +55,10 @@ public class FederatedRemoveEmptyTest extends 
AutomatedTestBase {
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, 
false}});
+               return Arrays.asList(new Object[][] {
+                       {20, 12, true},
+                       {20, 12, false}
+               });
        }
 
        @Override
@@ -94,11 +97,6 @@ public class FederatedRemoveEmptyTest extends 
AutomatedTestBase {
 
                for(int k : new int[] {1, 2, 3}) {
                        Arrays.fill(X3[k], 0);
-                       if(!rowPartitioned) {
-                               Arrays.fill(X1[k], 0);
-                               Arrays.fill(X2[k], 0);
-                               Arrays.fill(X4[k], 0);
-                       }
                }
 
                MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index 3e640c0..afd2ffe 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -54,7 +54,9 @@ public class FederatedSplitTest extends AutomatedTestBase {
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 
11, "FALSE"}});
+               return Arrays.asList(new Object[][] {
+                       // {152, 12, "TRUE"}, 
+                       {132, 11, "FALSE"}});
        }
 
        @Override
@@ -125,9 +127,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
                if(cont.equals("TRUE"))
                        
Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
                else {
-                       
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-                       // TODO add federated diag operator.
-                       // 
Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+                       
Assert.assertTrue(heavyHittersContainsString("fed_rmempty"));
                }
                
                TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/scripts/functions/federated/FederatedSplitTest.dml 
b/src/test/scripts/functions/federated/FederatedSplitTest.dml
index 44c59a9..e1fc647 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTest.dml
@@ -24,7 +24,6 @@ X = federated(addresses=list($X1, $X2),
 Y = federated(addresses=list($Y1, $Y2),
     ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
 
-
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y,f=0.95, cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
 write(Xte, $Z)
 print(toString(Xte))
diff --git 
a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml 
b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
index 4db8e1f..962dd84 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
@@ -21,6 +21,6 @@
 
 X = rbind(read($X1), read($X2))
 Y = rbind(read($Y1), read($Y2))
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y, f=0.95 ,cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
 write(Xte, $Z)
 print(toString(Xte))

Reply via email to