This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d8ec5d  [SYSTEMDS-2732] Federated remove empty
4d8ec5d is described below

commit 4d8ec5dd64922441f0acc452eee3e49bee0653cd
Author: Olga <ovcharenko.fo...@gmail.com>
AuthorDate: Sat Nov 14 17:42:33 2020 +0100

    [SYSTEMDS-2732] Federated remove empty
    
    closes #1104
---
 .../instructions/fed/FEDInstructionUtils.java      |   2 +-
 .../instructions/fed/InitFEDInstruction.java       |   2 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 262 +++++++++++++++++++--
 .../primitives/FederatedRemoveEmptyTest.java       | 161 +++++++++++++
 .../federated/FederatedRemoveEmptyTest.dml         |  33 +++
 .../FederatedRemoveEmptyTestReference.dml          |  26 ++
 6 files changed, 468 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index f4b19bf..d8af245 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -129,7 +129,7 @@ public class FEDInstructionUtils {
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
                        ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction) inst;
-                       if(pinst.getOpcode().equals("replace") && 
pinst.getTarget(ec).isFederated()) {
+                       if((pinst.getOpcode().equals("replace") || 
pinst.getOpcode().equals("rmempty")) && pinst.getTarget(ec).isFederated()) {
                                fedinst = 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
                        }
                        else if((pinst.getOpcode().equals("transformdecode") || 
pinst.getOpcode().equals("transformapply")) &&
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 8821a71..ce7f3b4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -234,7 +234,7 @@ public class InitFEDInstruction extends FEDInstruction {
                }
                try {
                        int timeout = 
ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT);
-                       LOG.error("Federated Initialization with timeout: " + 
timeout);
+                       LOG.debug("Federated Initialization with timeout: " + 
timeout);
                        for (Pair<FederatedData, Future<FederatedResponse>> 
idResponse : idResponses)
                                
idResponse.getRight().get(timeout,TimeUnit.SECONDS); //wait for initialization
                }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index f549dca..c50671e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,11 +19,14 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.AbstractMap;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
-
 import java.util.List;
+import java.util.Map;
+
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
@@ -34,6 +37,7 @@ import 
org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
@@ -47,6 +51,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -100,7 +105,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                LinkedHashMap<String, String> paramsMap = 
constructParameterMap(parts);
 
                // determine the appropriate value function
-               if( opcode.equalsIgnoreCase("replace") ) {
+               if(opcode.equalsIgnoreCase("replace") || 
opcode.equalsIgnoreCase("rmempty")) {
                        ValueFunction func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
@@ -120,8 +125,10 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        // similar to unary federated instructions, get 
federated input
                        // execute instruction, and derive federated output 
matrix
                        MatrixObject mo = (MatrixObject) getTarget(ec);
-                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
-                               new CPOperand[] {getTargetOperand()}, new 
long[] {mo.getFedMapping().getID()});
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand()},
+                               new long[] {mo.getFedMapping().getID()});
                        mo.getFedMapping().execute(getTID(), true, fr1);
 
                        // derive new fed mapping for output
@@ -129,6 +136,9 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        
out.getDataCharacteristics().set(mo.getDataCharacteristics());
                        
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
+               else if(opcode.equals("rmempty")) {
+                       rmempty(ec);
+               }
                else if(opcode.equalsIgnoreCase("transformdecode"))
                        transformDecode(ec);
                else if(opcode.equalsIgnoreCase("transformapply"))
@@ -138,6 +148,136 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                }
        }
 
+       private void rmempty(ExecutionContext ec) {
+               MatrixObject mo = (MatrixObject) getTarget(ec);
+               MatrixObject out = ec.getMatrixObject(output);
+               Map<FederatedRange, int[]> dcs;
+               if((instString.contains("margin=rows") && 
mo.isFederated(FederationMap.FType.ROW)) ||
+                       (instString.contains("margin=cols") && 
mo.isFederated(FederationMap.FType.COL))) {
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand()},
+                               new long[] {mo.getFedMapping().getID()});
+                       mo.getFedMapping().execute(getTID(), true, fr1);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+
+                       // new ranges
+                       dcs = new HashMap<>();
+                       out.getFedMapping().forEachParallel((range, data) -> {
+                               try {
+                                       FederatedResponse response = data
+                                               .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                       new 
GetDataCharacteristics(data.getVarID())))
+                                               .get();
+
+                                       if(!response.isSuccessful())
+                                               
response.throwExceptionFromResponse();
+                                       int[] subRangeCharacteristics = (int[]) 
response.getData()[0];
+                                       synchronized(dcs) {
+                                               dcs.put(range, 
subRangeCharacteristics);
+                                       }
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                               return null;
+                       });
+               }
+               else {
+                       Map.Entry<FederationMap, Map<FederatedRange, int[]>> 
entry = rmemptyC(ec, mo);
+                       out.setFedMapping(entry.getKey());
+                       dcs = entry.getValue();
+               }
+               out.getDataCharacteristics().set(mo.getDataCharacteristics());
+               for(int i = 0; i < 
mo.getFedMapping().getFederatedRanges().length; i++) {
+                       int[] newRange = 
dcs.get(out.getFedMapping().getFederatedRanges()[i]);
+
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
+
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+               }
+
+               
out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+                       out.getFedMapping().getMaxIndexInRange(1),
+                       (int) mo.getBlocksize());
+       }
+
+       private Map.Entry<FederationMap, Map<FederatedRange, int[]>> 
rmemptyC(ExecutionContext ec, MatrixObject mo) {
+               boolean marginRow = instString.contains("margin=rows");
+
+               // find empty in ranges
+               List<MatrixBlock> colSums = new ArrayList<>();
+               mo.getFedMapping().forEachParallel((range, data) -> {
+                       try {
+                               FederatedResponse response = data
+                                       .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                               new GetVector(data.getVarID(), 
marginRow)))
+                                       .get();
+
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               MatrixBlock vector = (MatrixBlock) 
response.getData()[0];
+                               synchronized(colSums) {
+                                       colSums.add(vector);
+                               }
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+
+               // find empty in matrix
+               BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+               BinaryOperator greater = 
InstructionUtils.parseBinaryOperator(">");
+               MatrixBlock tmp1 = colSums.get(0);
+               for(int i = 1; i < colSums.size(); i++)
+                       tmp1 = tmp1.binaryOperationsInPlace(plus, 
colSums.get(i));
+               tmp1 = tmp1.binaryOperationsInPlace(greater, new 
MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+
+               // remove empty from matrix
+               Map<FederatedRange, int[]> dcs = new HashMap<>();
+               long varID = FederationUtils.getNextFedDataID();
+               MatrixBlock finalTmp = new MatrixBlock(tmp1);
+               FederationMap resMapping;
+               if(tmp1.sum() == (marginRow ? tmp1.getNumColumns() : 
tmp1.getNumRows())) {
+                       resMapping = mo.getFedMapping();
+               }
+               else {
+                       resMapping = mo.getFedMapping().mapParallel(varID, 
(range, data) -> {
+                               try {
+                                       FederatedResponse response = data
+                                               .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                       new 
ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID, finalTmp,
+                                                               
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null,
+                                                               
Boolean.parseBoolean(params.get("empty.return").toLowerCase()), marginRow)))
+                                               .get();
+                                       if(!response.isSuccessful())
+                                               
response.throwExceptionFromResponse();
+                                       int[] subRangeCharacteristics = (int[]) 
response.getData()[0];
+                                       synchronized(dcs) {
+                                               dcs.put(range, 
subRangeCharacteristics);
+                                       }
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                               return null;
+                       });
+               }
+               return new AbstractMap.SimpleEntry<>(resMapping, dcs);
+       }
+
        private void transformDecode(ExecutionContext ec) {
                // acquire locks
                MatrixObject mo = ec.getMatrixObject(params.get("target"));
@@ -155,14 +295,14 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        long[] beginDims = range.getBeginDims();
                        long[] endDims = range.getEndDims();
                        int colStartBefore = (int) beginDims[1];
-                       
+
                        // update begin end dims (column part) considering 
columns added by dummycoding
                        globalDecoder.updateIndexRanges(beginDims, endDims);
-                       
+
                        // get the decoder segment that is relevant for this 
federated worker
                        Decoder decoder = globalDecoder
-                                       .subRangeDecoder((int) beginDims[1] + 
1, (int) endDims[1] + 1, colStartBefore);
-                       
+                               .subRangeDecoder((int) beginDims[1] + 1, (int) 
endDims[1] + 1, colStartBefore);
+
                        FrameBlock metaSlice = new FrameBlock();
                        synchronized(meta) {
                                meta.slice(0, meta.getNumRows() - 1, (int) 
beginDims[1], (int) endDims[1] - 1, metaSlice);
@@ -170,9 +310,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
 
                        FederatedResponse response;
                        try {
-                               response = data.executeFederatedOperation(
-                                       new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                       new DecodeMatrix(data.getVarID(), 
varID, metaSlice, decoder))).get();
+                               response = data.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       -1, new DecodeMatrix(data.getVarID(), 
varID, metaSlice, decoder))).get();
                                if(!response.isSuccessful())
                                        response.throwExceptionFromResponse();
 
@@ -217,7 +356,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        try {
                                FederatedResponse response = data
                                        .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new 
GetColumnNames(data.getVarID()))).get();
+                                               new 
GetColumnNames(data.getVarID())))
+                                       .get();
 
                                // no synchronization necessary since names 
should anyway match
                                String[] subRangeColNames = (String[]) 
response.getData()[0];
@@ -261,7 +401,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                                EncoderOmit subRangeEncoder = (EncoderOmit) 
omitEncoder.subRangeEncoder(range.asIndexRange().add(1));
                                FederatedResponse response = data
                                        .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new 
InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get();
+                                               new 
InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder)))
+                                       .get();
 
                                // no synchronization necessary since names 
should anyway match
                                Encoder builtEncoder = (Encoder) 
response.getData()[0];
@@ -283,7 +424,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
        private CPOperand getTargetOperand() {
                return new CPOperand(params.get("target"), ValueType.FP64, 
DataType.MATRIX);
        }
-       
+
        public static class DecodeMatrix extends FederatedUDF {
                private static final long serialVersionUID = 
2376756757742169692L;
                private final long _outputID;
@@ -330,7 +471,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
 
                @Override
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
-                       FrameBlock fb = 
((FrameObject)data[0]).acquireReadAndRelease();
+                       FrameBlock fb = ((FrameObject) 
data[0]).acquireReadAndRelease();
                        // return column names
                        return new FederatedResponse(ResponseType.SUCCESS, new 
Object[] {fb.getColumnNames()});
                }
@@ -348,9 +489,98 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
 
                @Override
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
-                       FrameBlock fb = 
((FrameObject)data[0]).acquireReadAndRelease();
+                       FrameBlock fb = ((FrameObject) 
data[0]).acquireReadAndRelease();
                        _encoder.build(fb);
                        return new FederatedResponse(ResponseType.SUCCESS, new 
Object[] {_encoder});
                }
        }
+
+       private static class GetDataCharacteristics extends FederatedUDF {
+
+               private static final long serialVersionUID = 
578461386177730925L;
+
+               public GetDataCharacteristics(long varID) {
+                       super(new long[] {varID});
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
+                       return new FederatedResponse(ResponseType.SUCCESS, new 
int[] {mb.getNumRows(), mb.getNumColumns()});
+               }
+       }
+
+       private static class RemoveEmpty extends FederatedUDF {
+
+               private static final long serialVersionUID = 12341521331L;
+               private final MatrixBlock _vector;
+               private final long _outputID;
+               private MatrixBlock _select;
+               private boolean _emptyReturn;
+               private final boolean _marginRow;
+
+               public RemoveEmpty(long varID, long outputID, MatrixBlock 
vector, MatrixBlock select, boolean emptyReturn,
+                       boolean marginRow) {
+                       super(new long[] {varID});
+                       _vector = vector;
+                       _outputID = outputID;
+                       _select = select;
+                       _emptyReturn = emptyReturn;
+                       _marginRow = marginRow;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
+
+                       BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
+                       BinaryOperator minus = 
InstructionUtils.parseBinaryOperator("-");
+
+                       mb = mb.binaryOperationsInPlace(plus, new 
MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1.0));
+                       for(int i = 0; i < mb.getNumRows(); i++)
+                               for(int j = 0; j < mb.getNumColumns(); j++)
+                                       if(_marginRow)
+                                               mb.setValue(i, j, 
_vector.getValue(i, 0) * mb.getValue(i, j));
+                                       else
+                                               mb.setValue(i, j, 
_vector.getValue(0, j) * mb.getValue(i, j));
+
+                       MatrixBlock res = mb.removeEmptyOperations(new 
MatrixBlock(), _marginRow, _emptyReturn, _select);
+                       res = res.binaryOperationsInPlace(minus, new 
MatrixBlock(res.getNumRows(), res.getNumColumns(), 1.0));
+
+                       MatrixObject mout = 
ExecutionContext.createMatrixObject(res);
+                       ec.setVariable(String.valueOf(_outputID), mout);
+
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+                               new int[] {res.getNumRows(), 
res.getNumColumns()});
+               }
+       }
+
+       private static class GetVector extends FederatedUDF {
+
+               private static final long serialVersionUID = 
-1003061862215703768L;
+               private final boolean _marginRow;
+
+               public GetVector(long varID, boolean marginRow) {
+                       super(new long[] {varID});
+                       _marginRow = marginRow;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
+
+                       BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
+                       BinaryOperator greater = 
InstructionUtils.parseBinaryOperator(">");
+                       int len = _marginRow ? mb.getNumColumns() : 
mb.getNumRows();
+                       MatrixBlock tmp1 = _marginRow ? mb.slice(0, 
mb.getNumRows() - 1, 0, 0, new MatrixBlock()) : mb
+                               .slice(0, 0, 0, mb.getNumColumns() - 1, new 
MatrixBlock());
+                       for(int i = 1; i < len; i++) {
+                               MatrixBlock tmp2 = _marginRow ? mb.slice(0, 
mb.getNumRows() - 1, i, i, new MatrixBlock()) : mb
+                                       .slice(i, i, 0, mb.getNumColumns() - 1, 
new MatrixBlock());
+                               tmp1 = tmp1.binaryOperationsInPlace(plus, tmp2);
+                       }
+                       tmp1 = tmp1.binaryOperationsInPlace(greater, new 
MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+                       return new FederatedResponse(ResponseType.SUCCESS, 
tmp1);
+               }
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
new file mode 100644
index 0000000..de1e6d5
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRemoveEmptyTest extends AutomatedTestBase {
+       // private static final Log LOG = 
LogFactory.getLog(FederatedRightIndexTest.class.getName());
+
+       private final static String TEST_NAME = "FederatedRemoveEmptyTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedRemoveEmptyTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, 
false}});
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       public void testRemoveEmptyCP() {
+               runAggregateOperationTest(ExecMode.SINGLE_NODE);
+       }
+
+       private void runAggregateOperationTest(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               for(int k : new int[] {1, 2, 3}) {
+                       Arrays.fill(X3[k], 0);
+                       if(!rowPartitioned) {
+                               Arrays.fill(X1[k], 0);
+                               Arrays.fill(X2[k], 0);
+                               Arrays.fill(X4[k], 0);
+                       }
+               }
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, 10);
+               Thread t2 = startLocalFedWorkerThread(port2, 10);
+               Thread t3 = startLocalFedWorkerThread(port3, 10);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK) {
+                       System.out.println(7);
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                       Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
+
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
+
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9);
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml
new file mode 100644
index 0000000..0c6b77b
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {  # $rP is the row-partitioned flag passed by the test: each worker range below covers $rows/4 rows and all $cols columns
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {  # column-partitioned case: each worker range below covers all $rows rows and one quarter of the columns
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+}
+
+s = removeEmpty(target=A, margin="cols");  # operation under test: removeEmpty with margin="cols" applied to the federated matrix A
+write(s, $out_S);  # persist the result at $out_S, which the Java test sets to output("S")
diff --git 
a/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml 
b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml
new file mode 100644
index 0000000..c4b2dc9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }  # $5 is the row-partitioned flag: stack the four inputs $1..$4 row-wise
+else { A = cbind(read($1), read($2), read($3), read($4)); }  # otherwise concatenate the four inputs column-wise
+
+s = removeEmpty(target=A, margin="cols");  # reference computation: removeEmpty with margin="cols" on the non-federated matrix A
+write(s, $6);  # persist the reference result at $6, which the Java test sets to expected("S")

Reply via email to