This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new 4d8ec5d [SYSTEMDS-2732] Federated remove empty 4d8ec5d is described below commit 4d8ec5dd64922441f0acc452eee3e49bee0653cd Author: Olga <ovcharenko.fo...@gmail.com> AuthorDate: Sat Nov 14 17:42:33 2020 +0100 [SYSTEMDS-2732] Federated remove empty closes #1104 --- .../instructions/fed/FEDInstructionUtils.java | 2 +- .../instructions/fed/InitFEDInstruction.java | 2 +- .../fed/ParameterizedBuiltinFEDInstruction.java | 262 +++++++++++++++++++-- .../primitives/FederatedRemoveEmptyTest.java | 161 +++++++++++++ .../federated/FederatedRemoveEmptyTest.dml | 33 +++ .../FederatedRemoveEmptyTestReference.dml | 26 ++ 6 files changed, 468 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index f4b19bf..d8af245 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -129,7 +129,7 @@ public class FEDInstructionUtils { } else if( inst instanceof ParameterizedBuiltinCPInstruction ) { ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst; - if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) { + if((pinst.getOpcode().equals("replace") || pinst.getOpcode().equals("rmempty")) && pinst.getTarget(ec).isFederated()) { fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString()); } else if((pinst.getOpcode().equals("transformdecode") || pinst.getOpcode().equals("transformapply")) && diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java index 8821a71..ce7f3b4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java @@ -234,7 +234,7 @@ public class InitFEDInstruction extends FEDInstruction { } try { int timeout = ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT); - LOG.error("Federated Initialization with timeout: " + timeout); + LOG.debug("Federated Initialization with timeout: " + timeout); for (Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses) idResponse.getRight().get(timeout,TimeUnit.SECONDS); //wait for initialization } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java index f549dca..c50671e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java @@ -19,11 +19,14 @@ package org.apache.sysds.runtime.instructions.fed; +import java.util.AbstractMap; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; - import java.util.List; +import java.util.Map; + import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; @@ -34,6 +37,7 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType; @@ -47,6 +51,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.SimpleOperator; import org.apache.sysds.runtime.meta.MatrixCharacteristics; @@ -100,7 +105,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio LinkedHashMap<String, String> paramsMap = constructParameterMap(parts); // determine the appropriate value function - if( opcode.equalsIgnoreCase("replace") ) { + if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty")) { ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str); } @@ -120,8 +125,10 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio // similar to unary federated instructions, get federated input // execute instruction, and derive federated output matrix MatrixObject mo = (MatrixObject) getTarget(ec); - FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, - new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()}); + FederatedRequest fr1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {getTargetOperand()}, + new long[] {mo.getFedMapping().getID()}); mo.getFedMapping().execute(getTID(), true, fr1); // derive new fed mapping for output @@ -129,6 +136,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio out.getDataCharacteristics().set(mo.getDataCharacteristics()); out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID())); } + else if(opcode.equals("rmempty")) { + rmempty(ec); + } else if(opcode.equalsIgnoreCase("transformdecode")) transformDecode(ec); else if(opcode.equalsIgnoreCase("transformapply")) @@ -138,6 +148,136 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio } } + private void rmempty(ExecutionContext ec) { + MatrixObject mo = (MatrixObject) getTarget(ec); + MatrixObject out = ec.getMatrixObject(output); + Map<FederatedRange, int[]> dcs; + if((instString.contains("margin=rows") && mo.isFederated(FederationMap.FType.ROW)) || + (instString.contains("margin=cols") && mo.isFederated(FederationMap.FType.COL))) { + FederatedRequest fr1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {getTargetOperand()}, + new long[] {mo.getFedMapping().getID()}); + mo.getFedMapping().execute(getTID(), true, fr1); + out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID())); + + // new ranges + dcs = new HashMap<>(); + out.getFedMapping().forEachParallel((range, data) -> { + try { + FederatedResponse response = data + .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new GetDataCharacteristics(data.getVarID()))) + .get(); + + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + int[] subRangeCharacteristics = (int[]) response.getData()[0]; + synchronized(dcs) { + dcs.put(range, subRangeCharacteristics); + } + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + } + else { + Map.Entry<FederationMap, Map<FederatedRange, int[]>> entry = rmemptyC(ec, mo); + out.setFedMapping(entry.getKey()); + dcs = entry.getValue(); + } + out.getDataCharacteristics().set(mo.getDataCharacteristics()); + for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) { + int[] newRange = dcs.get(out.getFedMapping().getFederatedRanges()[i]); + + out.getFedMapping().getFederatedRanges()[i].setBeginDim(0, + (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 || + i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]); + + out.getFedMapping().getFederatedRanges()[i].setEndDim(0, + out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]); + + out.getFedMapping().getFederatedRanges()[i].setBeginDim(1, + (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 || + i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]); + + out.getFedMapping().getFederatedRanges()[i].setEndDim(1, + out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]); + } + + out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0), + out.getFedMapping().getMaxIndexInRange(1), + (int) mo.getBlocksize()); + } + + private Map.Entry<FederationMap, Map<FederatedRange, int[]>> rmemptyC(ExecutionContext ec, MatrixObject mo) { + boolean marginRow = instString.contains("margin=rows"); + + // find empty in ranges + List<MatrixBlock> colSums = new ArrayList<>(); + mo.getFedMapping().forEachParallel((range, data) -> { + try { + FederatedResponse response = data + .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new GetVector(data.getVarID(), marginRow))) + .get(); + + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + MatrixBlock vector = (MatrixBlock) response.getData()[0]; + synchronized(colSums) { + colSums.add(vector); + } + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + + // find empty in matrix + BinaryOperator plus = InstructionUtils.parseBinaryOperator("+"); + BinaryOperator greater = InstructionUtils.parseBinaryOperator(">"); + MatrixBlock tmp1 = colSums.get(0); + for(int i = 1; i < colSums.size(); i++) + tmp1 = tmp1.binaryOperationsInPlace(plus, colSums.get(i)); + tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0)); + + // remove empty from matrix + Map<FederatedRange, int[]> dcs = new HashMap<>(); + long varID = FederationUtils.getNextFedDataID(); + MatrixBlock finalTmp = new MatrixBlock(tmp1); + FederationMap resMapping; + if(tmp1.sum() == (marginRow ? tmp1.getNumColumns() : tmp1.getNumRows())) { + resMapping = mo.getFedMapping(); + } + else { + resMapping = mo.getFedMapping().mapParallel(varID, (range, data) -> { + try { + FederatedResponse response = data + .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID, finalTmp, + params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null, + Boolean.parseBoolean(params.get("empty.return").toLowerCase()), marginRow))) + .get(); + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + int[] subRangeCharacteristics = (int[]) response.getData()[0]; + synchronized(dcs) { + dcs.put(range, subRangeCharacteristics); + } + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + } + return new AbstractMap.SimpleEntry<>(resMapping, dcs); + } + private void transformDecode(ExecutionContext ec) { // acquire locks MatrixObject mo = ec.getMatrixObject(params.get("target")); @@ -155,14 +295,14 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio long[] beginDims = range.getBeginDims(); long[] endDims = range.getEndDims(); int colStartBefore = (int) beginDims[1]; - + // update begin end dims (column part) considering columns added by dummycoding globalDecoder.updateIndexRanges(beginDims, endDims); - + // get the decoder segment that is relevant for this federated worker Decoder decoder = globalDecoder - .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); - + .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); + FrameBlock metaSlice = new FrameBlock(); synchronized(meta) { meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice); @@ -170,9 +310,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio FederatedResponse response; try { - response = data.executeFederatedOperation( - new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get(); + response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + -1, new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); @@ -217,7 +356,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio try { FederatedResponse response = data .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new GetColumnNames(data.getVarID()))).get(); + new GetColumnNames(data.getVarID()))) + .get(); // no synchronization necessary since names should anyway match String[] subRangeColNames = (String[]) response.getData()[0]; @@ -261,7 +401,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio EncoderOmit subRangeEncoder = (EncoderOmit) omitEncoder.subRangeEncoder(range.asIndexRange().add(1)); FederatedResponse response = data .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get(); + new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))) + .get(); // no synchronization necessary since names should anyway match Encoder builtEncoder = (Encoder) response.getData()[0]; @@ -283,7 +424,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio private CPOperand getTargetOperand() { return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX); } - + public static class DecodeMatrix extends FederatedUDF { private static final long serialVersionUID = 2376756757742169692L; private final long _outputID; @@ -330,7 +471,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { - FrameBlock fb = ((FrameObject)data[0]).acquireReadAndRelease(); + FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); // return column names return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fb.getColumnNames()}); } @@ -348,9 +489,98 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { - FrameBlock fb = ((FrameObject)data[0]).acquireReadAndRelease(); + FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); _encoder.build(fb); return new FederatedResponse(ResponseType.SUCCESS, new Object[] {_encoder}); } } + + private static class GetDataCharacteristics extends FederatedUDF { + + private static final long serialVersionUID = 578461386177730925L; + + public GetDataCharacteristics(long varID) { + super(new long[] {varID}); + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + return new FederatedResponse(ResponseType.SUCCESS, new int[] {mb.getNumRows(), mb.getNumColumns()}); + } + } + + private static class RemoveEmpty extends FederatedUDF { + + private static final long serialVersionUID = 12341521331L; + private final MatrixBlock _vector; + private final long _outputID; + private MatrixBlock _select; + private boolean _emptyReturn; + private final boolean _marginRow; + + public RemoveEmpty(long varID, long outputID, MatrixBlock vector, MatrixBlock select, boolean emptyReturn, + boolean marginRow) { + super(new long[] {varID}); + _vector = vector; + _outputID = outputID; + _select = select; + _emptyReturn = emptyReturn; + _marginRow = marginRow; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + + BinaryOperator plus = InstructionUtils.parseBinaryOperator("+"); + BinaryOperator minus = InstructionUtils.parseBinaryOperator("-"); + + mb = mb.binaryOperationsInPlace(plus, new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1.0)); + for(int i = 0; i < mb.getNumRows(); i++) + for(int j = 0; j < mb.getNumColumns(); j++) + if(_marginRow) + mb.setValue(i, j, _vector.getValue(i, 0) * mb.getValue(i, j)); + else + mb.setValue(i, j, _vector.getValue(0, j) * mb.getValue(i, j)); + + MatrixBlock res = mb.removeEmptyOperations(new MatrixBlock(), _marginRow, _emptyReturn, _select); + res = res.binaryOperationsInPlace(minus, new MatrixBlock(res.getNumRows(), res.getNumColumns(), 1.0)); + + MatrixObject mout = ExecutionContext.createMatrixObject(res); + ec.setVariable(String.valueOf(_outputID), mout); + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, + new int[] {res.getNumRows(), res.getNumColumns()}); + } + } + + private static class GetVector extends FederatedUDF { + + private static final long serialVersionUID = -1003061862215703768L; + private final boolean _marginRow; + + public GetVector(long varID, boolean marginRow) { + super(new long[] {varID}); + _marginRow = marginRow; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + + BinaryOperator plus = InstructionUtils.parseBinaryOperator("+"); + BinaryOperator greater = InstructionUtils.parseBinaryOperator(">"); + int len = _marginRow ? mb.getNumColumns() : mb.getNumRows(); + MatrixBlock tmp1 = _marginRow ? mb.slice(0, mb.getNumRows() - 1, 0, 0, new MatrixBlock()) : mb + .slice(0, 0, 0, mb.getNumColumns() - 1, new MatrixBlock()); + for(int i = 1; i < len; i++) { + MatrixBlock tmp2 = _marginRow ? mb.slice(0, mb.getNumRows() - 1, i, i, new MatrixBlock()) : mb + .slice(i, i, 0, mb.getNumColumns() - 1, new MatrixBlock()); + tmp1 = tmp1.binaryOperationsInPlace(plus, tmp2); + } + tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0)); + return new FederatedResponse(ResponseType.SUCCESS, tmp1); + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java new file mode 100644 index 0000000..de1e6d5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedRemoveEmptyTest extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); + + private final static String TEST_NAME = "FederatedRemoveEmptyTest"; + + private final static String TEST_DIR = "functions/federated/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRemoveEmptyTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameter(2) + public boolean rowPartitioned; + + @Parameterized.Parameters + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, false}}); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); + } + + @Test + public void testRemoveEmptyCP() { + runAggregateOperationTest(ExecMode.SINGLE_NODE); + } + + private void runAggregateOperationTest(ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if(rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int r = rows; + int c = cols / 4; + if(rowPartitioned) { + r = rows / 4; + c = cols; + } + + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + for(int k : new int[] {1, 2, 3}) { + Arrays.fill(X3[k], 0); + if(!rowPartitioned) { + Arrays.fill(X1[k], 0); + Arrays.fill(X2[k], 0); + Arrays.fill(X4[k], 0); + } + } + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); + Thread t4 = startLocalFedWorkerThread(port4); + + rtplatform = execMode; + if(rtplatform == ExecMode.SPARK) { + System.out.println(7); + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; + + runTest(null); + + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, + "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; + + runTest(null); + + // compare via files + compareResults(1e-9); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + TestUtils.shutdownThreads(t1, t2, t3, t4); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + + } +} diff --git a/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml new file mode 100644 index 0000000..0c6b77b --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTest.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = removeEmpty(target=A, margin="cols"); +write(s, $out_S); diff --git a/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml new file mode 100644 index 0000000..c4b2dc9 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRemoveEmptyTestReference.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +if($5) { A = rbind(read($1), read($2), read($3), read($4)); } +else { A = cbind(read($1), read($2), read($3), read($4)); } + +s = removeEmpty(target=A, margin="cols"); +write(s, $6);