This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new fd178aa [SYSTEMDS-2738] Federated rdiag, rev and uppertri instructions fd178aa is described below commit fd178aa36f02daea0372e4aab6b8f77711bc7a98 Author: Olga <ovcharenko.fo...@gmail.com> AuthorDate: Fri Nov 20 14:55:34 2020 +0100 [SYSTEMDS-2738] Federated rdiag, rev and uppertri instructions Federated support for diagonal, reverse and upper triangle. There are some TODO added in this commit as well: - Add a sorting method for federated map to sort the federatedInstances. - Reverse use slice and allocate many matrices, optimize this by leveraging underlying Dense and Sparse Blocks. Closes #1112 --- .../controlprogram/federated/FederationMap.java | 17 ++ .../instructions/fed/FEDInstructionUtils.java | 7 +- .../fed/ParameterizedBuiltinFEDInstruction.java | 108 +++++++- .../instructions/fed/ReorgFEDInstruction.java | 272 +++++++++++++++++++-- .../federated/primitives/FederatedRdiagTest.java | 145 +++++++++++ .../federated/primitives/FederatedRevTest.java | 160 ++++++++++++ .../federated/primitives/FederatedTriTest.java | 149 +++++++++++ .../federated/FederatedLmPipelineReference.dml | 2 +- .../functions/federated/FederatedRdiagTest.dml | 27 ++ .../federated/FederatedRdiagTestReference.dml | 25 ++ .../functions/federated/FederatedRevTest.dml | 32 +++ .../federated/FederatedRevTestReference.dml | 26 ++ .../functions/federated/FederatedTriTest.dml | 33 +++ .../federated/FederatedTriTestReference.dml | 26 ++ 14 files changed, 1006 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 482ade7..e933979 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -391,6 +391,23 @@ public class FederationMap { 
Arrays.stream(frset).forEach(fr -> fr.setTID(tid)); } + public void reverseFedMap() { + // TODO: add a check if the map is sorted based on indexes before reversing. + // TODO: add a setup such that on construction the federated map is already sorted. + FederatedRange[] fedRanges = this.getFederatedRanges(); + + for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) { + FederatedData data1 = _fedMap.get(fedRanges[i]); + FederatedData data2 = _fedMap.get(fedRanges[fedRanges.length-1-i]); + + _fedMap.remove(fedRanges[i]); + _fedMap.remove(fedRanges[fedRanges.length-1-i]); + + _fedMap.put(fedRanges[i], data2); + _fedMap.put(fedRanges[fedRanges.length-1-i], data1); + } + } + private static class MappingTask implements Callable<Void> { private final FederatedRange _range; private final FederatedData _data; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index af17637..e6a64cb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -91,7 +91,8 @@ public class FEDInstructionUtils { } else if (inst instanceof UnaryCPInstruction && ! 
(inst instanceof IndexingCPInstruction)) { UnaryCPInstruction instruction = (UnaryCPInstruction) inst; - if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) { + if(inst instanceof ReorgCPInstruction && (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") + || inst.getOpcode().equals("rev"))) { ReorgCPInstruction rinst = (ReorgCPInstruction) inst; CacheableData<?> mo = ec.getCacheableData(rinst.input1); @@ -129,7 +130,9 @@ public class FEDInstructionUtils { } else if( inst instanceof ParameterizedBuiltinCPInstruction ) { ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst; - if((pinst.getOpcode().equals("replace") || pinst.getOpcode().equals("rmempty")) && pinst.getTarget(ec).isFederated()) { + if((pinst.getOpcode().equals("replace") || pinst.getOpcode().equals("rmempty") + || pinst.getOpcode().equals("lowertri") || pinst.getOpcode().equals("uppertri")) + && pinst.getTarget(ec).isFederated()) { fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString()); } else if((pinst.getOpcode().equals("transformdecode") || pinst.getOpcode().equals("transformapply")) && diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java index 6588909..a2b63e9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java @@ -104,7 +104,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio LinkedHashMap<String, String> paramsMap = constructParameterMap(parts); // determine the appropriate value function - if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty")) { + if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty") + || 
opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) { ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str); } @@ -137,6 +138,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio } else if(opcode.equals("rmempty")) rmempty(ec); + else if(opcode.equals("lowertri") || opcode.equals("uppertri")) + triangle(ec, opcode); else if(opcode.equalsIgnoreCase("transformdecode")) transformDecode(ec); else if(opcode.equalsIgnoreCase("transformapply")) @@ -145,22 +148,117 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio throw new DMLRuntimeException("Unknown opcode : " + opcode); } } + + private void triangle(ExecutionContext ec, String opcode) { + boolean lower = opcode.equals("lowertri"); + boolean diag = Boolean.parseBoolean(params.get("diag")); + boolean values = Boolean.parseBoolean(params.get("values")); + + MatrixObject mo = (MatrixObject) getTarget(ec); + + FederationMap fedMap = mo.getFedMapping(); + boolean rowFed = mo.isFederated(FederationMap.FType.ROW); + + long varID = FederationUtils.getNextFedDataID(); + FederationMap diagFedMap; + + diagFedMap = fedMap.mapParallel(varID, (range, data) -> { + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest( + FederatedRequest.RequestType.EXEC_UDF, -1, + new ParameterizedBuiltinFEDInstruction.Tri(data.getVarID(), varID, + rowFed ? 
(new int[] {range.getBeginDimsInt()[0], range.getEndDimsInt()[0]}) : + new int[] {range.getBeginDimsInt()[1], range.getEndDimsInt()[1]}, + rowFed, lower, diag, values))).get(); + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + return null; + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + }); + MatrixObject out = ec.getMatrixObject(output); + out.setFedMapping(diagFedMap); + } + + private static class Tri extends FederatedUDF { + private static final long serialVersionUID = 6254009025304038215L; + + private final long _outputID; + private final int[] _slice; + private final boolean _rowFed; + private final boolean _lower; + private final boolean _diag; + private final boolean _values; + + private Tri(long input, long outputID, int[] slice, boolean rowFed, boolean lower, boolean diag, boolean values) { + super(new long[] {input}); + _outputID = outputID; + _slice = slice; + _rowFed = rowFed; + _lower = lower; + _diag = diag; + _values = values; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock soresBlock, addBlock; + MatrixBlock ret; + + //slice + soresBlock = _rowFed ? + mb.slice(0, mb.getNumRows()-1, _slice[0], _slice[1]-1, new MatrixBlock()) : + mb.slice(_slice[0], _slice[1]-1); + + //triangle + MatrixBlock tri = soresBlock.extractTriangular(new MatrixBlock(), _lower, _diag, _values); + // todo: optimize to not allocate and slice all these matrix blocks, but leveraging underlying dense or sparse blocks. 
+ if(_rowFed) { + ret = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 0.0); + ret.copy(0, ret.getNumRows()-1, _slice[0], _slice[1]-1, tri, false); + if(_slice[1] <= mb.getNumColumns()-1 && !_lower) { + addBlock = mb.slice(0, mb.getNumRows()-1, _slice[1], mb.getNumColumns()-1, new MatrixBlock()); + ret.copy(0, ret.getNumRows()-1, _slice[1], ret.getNumColumns() - 1, addBlock, false); + } else if(_slice[0] > 0 && _lower) { + addBlock = mb.slice(0, mb.getNumRows()-1, 0, _slice[0]-1, new MatrixBlock()); + ret.copy(0, ret.getNumRows()-1, 0, _slice[0]-1, addBlock, false); + } + } else { + ret = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 0.0); + ret.copy(_slice[0], _slice[1]-1, 0, mb.getNumColumns() - 1, tri, false); + if(_slice[0] > 0 && !_lower) { + addBlock = mb.slice(0, _slice[0]-1,0, mb.getNumColumns()-1, new MatrixBlock()); + ret.copy(0, ret.getNumRows() - 1, _slice[1], ret.getNumColumns() - 1, addBlock, false); + } else if(_slice[1] <= mb.getNumRows() &&_lower) { + addBlock = mb.slice(_slice[1], ret.getNumRows()-1,0, mb.getNumColumns()-1, new MatrixBlock()); + ret.copy(_slice[1], ret.getNumRows() - 1, 0, mb.getNumColumns()-1, addBlock, false); + } + } + MatrixObject mout = ExecutionContext.createMatrixObject(ret); + ec.setVariable(String.valueOf(_outputID), mout); + + return new FederatedResponse(ResponseType.SUCCESS_EMPTY); + } + } private void rmempty(ExecutionContext ec) { String margin = params.get("margin"); if( !(margin.equals("rows") || margin.equals("cols")) ) - throw new DMLRuntimeException("Unspupported margin identifier '"+margin+"'."); + throw new DMLRuntimeException("Unsupported margin identifier '"+margin+"'."); MatrixObject mo = (MatrixObject) getTarget(ec); MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null; MatrixObject out = ec.getMatrixObject(output); boolean marginRow = params.get("margin").equals("rows"); - boolean k = ((marginRow && mo.getFedMapping().getType().isColPartitioned()) || + boolean isNotAligned = ((marginRow && mo.getFedMapping().getType().isColPartitioned()) || (!marginRow && mo.getFedMapping().getType().isRowPartitioned())); MatrixBlock s = new MatrixBlock(); - if(select == null && k) { + if(select == null && isNotAligned) { List<MatrixBlock> colSums = new ArrayList<>(); mo.getFedMapping().forEachParallel((range, data) -> { try { @@ -209,7 +307,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio mo.getFedMapping().execute(getTID(), true, fr1); out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID())); } - else if (!k) { + else if (!isNotAligned) { //construct commands: broadcast , fed rmempty, clean broadcast FederatedRequest[] fr1 = mo.getFedMapping().broadcastSliced(select, !marginRow); FederatedRequest fr2 = FederationUtils.callInstruction(instString, diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java index d9e9d97..b2f3a53 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java @@ -19,28 +19,60 @@ package org.apache.sysds.runtime.instructions.fed; +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.sysds.common.Types; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.functionobjects.DiagIndex; +import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.SortIndex; +import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.util.IndexRange; public class ReorgFEDInstruction extends UnaryFEDInstruction { - public ReorgFEDInstruction(CPOperand in1, CPOperand out, String opcode, String istr) { - super(FEDType.Reorg, null, in1, out, opcode, istr); + public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr) { + super(FEDType.Reorg, op, in1, out, opcode, istr); } public static ReorgFEDInstruction parseInstruction ( String str ) { + CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); + CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN); + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; if ( opcode.equalsIgnoreCase("r'") ) { InstructionUtils.checkNumFields(str, 2, 3); - CPOperand in = new CPOperand(parts[1]); - CPOperand out = new CPOperand(parts[2]); - return new ReorgFEDInstruction(in, out, 
opcode, str); + in.split(parts[1]); + out.split(parts[2]); + int k = Integer.parseInt(parts[3]); + return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str); + } + else if ( opcode.equalsIgnoreCase("rdiag") ) { + parseUnaryInstruction(str, in, out); //max 2 operands + return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str); + } + else if ( opcode.equalsIgnoreCase("rev") ) { + parseUnaryInstruction(str, in, out); //max 2 operands + return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); } else { throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode); @@ -50,20 +82,230 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction { @Override public void processInstruction(ExecutionContext ec) { MatrixObject mo1 = ec.getMatrixObject(input1); - + ReorgOperator r_op = (ReorgOperator) _optr; + if( !mo1.isFederated() ) throw new DMLRuntimeException("Federated Reorg: " + "Federated input expected, but invoked w/ "+mo1.isFederated()); - //execute transpose at federated site - FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, - new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()}); - mo1.getFedMapping().execute(getTID(), true, fr1); + if(instOpcode.equals("r'")) { + //execute transpose at federated site + FederatedRequest fr1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {input1}, + new long[] {mo1.getFedMapping().getID()}); + mo1.getFedMapping().execute(getTID(), true, fr1); + + //drive output federated mapping + MatrixObject out = ec.getMatrixObject(output); + out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz()); + out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose()); + } + else if(instOpcode.equalsIgnoreCase("rev")) { + //execute transpose at 
federated site + FederatedRequest fr1 = FederationUtils.callInstruction(instString, + output, + new CPOperand[] {input1}, + new long[] {mo1.getFedMapping().getID()}); + mo1.getFedMapping().execute(getTID(), true, fr1); + + if(mo1.isFederated(FederationMap.FType.ROW)) + mo1.getFedMapping().reverseFedMap(); + + //derive output federated mapping + MatrixObject out = ec.getMatrixObject(output); + out.getDataCharacteristics().set(mo1.getNumRows(), mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz()); + out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID())); + } + else if (instOpcode.equals("rdiag")) { + RdiagResult result; + // diag(diag(X)) + if (mo1.getNumColumns() == 1 && mo1.getNumRows() != 1) { + result = rdiagV2M(mo1, r_op); + } else { + result = rdiagM2V(mo1, r_op); + } + + FederationMap diagFedMap = result.getFedMap(); + Map<FederatedRange, int[]> dcs = result.getDcs(); + + //update fed ranges + for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) { + int[] newRange = dcs.get(diagFedMap.getFederatedRanges()[i]); + + diagFedMap.getFederatedRanges()[i].setBeginDim(0, + (diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 || + i == 0) ? 0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]); + diagFedMap.getFederatedRanges()[i].setEndDim(0, + diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]); + diagFedMap.getFederatedRanges()[i].setBeginDim(1, + (diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 || + i == 0) ? 
0 : diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]); + diagFedMap.getFederatedRanges()[i].setEndDim(1, + diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]); + } + + //update output mapping and data characteristics + MatrixObject rdiag = ec.getMatrixObject(output); + rdiag.getDataCharacteristics() + .set(diagFedMap.getMaxIndexInRange(0), diagFedMap.getMaxIndexInRange(1), + (int) mo1.getBlocksize()); + rdiag.setFedMapping(diagFedMap); + } + } + + private class RdiagResult { + FederationMap fedMap; + Map<FederatedRange, int[]> dcs; + + public RdiagResult(FederationMap fedMap, Map<FederatedRange, int[]> dcs) { + this.fedMap = fedMap; + this.dcs = dcs; + } + + public FederationMap getFedMap() { + return fedMap; + } + + public Map<FederatedRange, int[]> getDcs() { + return dcs; + } + } + + private RdiagResult rdiagV2M (MatrixObject mo1, ReorgOperator r_op) { + FederationMap fedMap = mo1.getFedMapping(); + boolean rowFed = mo1.isFederated(FederationMap.FType.ROW); - //drive output federated mapping - MatrixObject out = ec.getMatrixObject(output); - out.getDataCharacteristics().set(mo1.getNumColumns(), - mo1.getNumRows(), (int)mo1.getBlocksize(), mo1.getNnz()); - out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose()); + long varID = FederationUtils.getNextFedDataID(); + Map<FederatedRange, int[]> dcs = new HashMap<>(); + FederationMap diagFedMap; + + diagFedMap = fedMap.mapParallel(varID, (range, data) -> { + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest( + FederatedRequest.RequestType.EXEC_UDF, -1, + new ReorgFEDInstruction.DiagMatrix(data.getVarID(), + varID, r_op, + rowFed ? 
(new int[] {range.getBeginDimsInt()[0], range.getEndDimsInt()[0]}) : + new int[] {range.getBeginDimsInt()[1], range.getEndDimsInt()[1]}, + rowFed, (int) mo1.getNumRows()))).get(); + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + int[] subRangeCharacteristics = (int[]) response.getData()[0]; + synchronized(dcs) { + dcs.put(range, subRangeCharacteristics); + } + return null; + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + }); + return new RdiagResult(diagFedMap, dcs); + } + + private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) { + FederationMap fedMap = mo1.getFedMapping(); + boolean rowFed = mo1.isFederated(FederationMap.FType.ROW); + + long varID = FederationUtils.getNextFedDataID(); + Map<FederatedRange, int[]> dcs = new HashMap<>(); + FederationMap diagFedMap; + + diagFedMap = fedMap.mapParallel(varID, (range, data) -> { + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest( + FederatedRequest.RequestType.EXEC_UDF, -1, + new ReorgFEDInstruction.Rdiag(data.getVarID(), varID, r_op, + rowFed ? 
(new int[] {range.getBeginDimsInt()[0], range.getEndDimsInt()[0]}) : + new int[] {range.getBeginDimsInt()[1], range.getEndDimsInt()[1]}, + rowFed))).get(); + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + int[] subRangeCharacteristics = (int[]) response.getData()[0]; + synchronized(dcs) { + dcs.put(range, subRangeCharacteristics); + } + return null; + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + }); + return new RdiagResult(diagFedMap, dcs); + } + + private static class Rdiag extends FederatedUDF { + + private static final long serialVersionUID = -3466926635958851402L; + private final long _outputID; + private final ReorgOperator _r_op; + private final int[] _slice; + private final boolean _rowFed; + + private Rdiag(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed) { + super(new long[] {input}); + _outputID = outputID; + _r_op = r_op; + _slice = slice; + _rowFed = rowFed; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock soresBlock; + MatrixBlock res; + + soresBlock = _rowFed ? 
+ mb.slice(0, mb.getNumRows() - 1, _slice[0], _slice[1] - 1, new MatrixBlock()) : + mb.slice(_slice[0], _slice[1] - 1); + res = soresBlock.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0); + + MatrixObject mout = ExecutionContext.createMatrixObject(res); + mout.setDiag(true); + ec.setVariable(String.valueOf(_outputID), mout); + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()}); + } + } + + private static class DiagMatrix extends FederatedUDF { + + private static final long serialVersionUID = -3466926635958851402L; + private final long _outputID; + private final ReorgOperator _r_op; + private final int _len; + private final int[] _slice; + private final boolean _rowFed; + + private DiagMatrix(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed, int len) { + super(new long[] {input}); + _outputID = outputID; + _r_op = r_op; + _len = len; + _rowFed = rowFed; + _slice = slice; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock res; + + MatrixBlock tmp = mb.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0); + if(_rowFed) { + res = new MatrixBlock(mb.getNumRows(), _len, 0.0); + res.copy(0, res.getNumRows()-1, _slice[0], _slice[1]-1, tmp, false); + } else { + res = new MatrixBlock(_len, _slice[1], 0.0); + res.copy(_slice[0], _slice[1]-1, 0, mb.getNumColumns() - 1, tmp, false);; + } + MatrixObject mout = ExecutionContext.createMatrixObject(res); + mout.setDiag(true); + ec.setVariable(String.valueOf(_outputID), mout); + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()}); + } } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java new file mode 100644 index 0000000..eedbbbb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedRdiagTest extends AutomatedTestBase { + + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_NAME = "FederatedRdiagTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedRdiagTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameters + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {1000, 1000}, + {1000,1} + }); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"})); + } + + @Test + public void federatedRdiagCP() { federatedRdiag(Types.ExecMode.SINGLE_NODE); } + + @Test + @Ignore + public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); } + + public void federatedRdiag(Types.ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + int r = rows / 4; + + double[][] X1 = getRandomMatrix(r, cols, 1, 5, 1, 3); + double[][] X2 = 
getRandomMatrix(r, cols, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, cols, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, cols, 1, 5, 1, 9); + + MatrixCharacteristics mc = new MatrixCharacteristics(r, cols, blocksize, r*cols); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S); + Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S); + Thread t4 = startLocalFedWorkerThread(port4); + + // reference file should not be written to hdfs, so we set platform here + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + // Run reference dml script with normal matrix for Row/Col + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-stats", "100", "-args", + input("X1"), input("X2"), input("X3"), input("X4"), expected("S")}; + runTest(null); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "rows=" + rows, + "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + + // compare all sums via files + compareResults(0.01); 
+ + Assert.assertTrue(heavyHittersContainsString("fed_rdiag")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + TestUtils.shutdownThreads(t1, t2, t3, t4); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java new file mode 100644 index 0000000..36996db --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Tests federated {@code rev()} on a row- or column-partitioned matrix held by four local workers.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRevTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "FederatedRevTest";
+
+	private final static String TEST_DIR = "functions/federated/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRevTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Parameterized.Parameter(2)
+	public boolean rowPartitioned;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			{100, 12, true},
+			{100, 12, false}
+		});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+	}
+
+	@Test
+	public void testRevCP() {
+		runRevTest(ExecMode.SINGLE_NODE);
+	}
+
+	private void runRevTest(ExecMode execMode) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		ExecMode platformOld = rtplatform;
+
+		if(rtplatform == ExecMode.SPARK)
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices: four equally sized partitions along rows or columns
+		int r = rows;
+		int c = cols / 4;
+		if(rowPartitioned) {
+			r = rows / 4;
+			c = cols;
+		}
+
+		double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+		// zero out rows 1-3 of X3 (presumably to exercise all-zero rows in the reversal)
+		for(int k : new int[] {1, 2, 3}) {
+			Arrays.fill(X3[k], 0);
+		}
+
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+		Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+		Thread t4 = startLocalFedWorkerThread(port4);
+		try {
+			rtplatform = execMode;
+			if(rtplatform == ExecMode.SPARK) {
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			}
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+				Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+			runTest(null);
+
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats", "100", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+				"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+			runTest(null);
+
+			// compare via files
+			compareResults(0.01);
+			Assert.assertTrue(heavyHittersContainsString("fed_rev"));
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+		}
+		finally {
+			// always stop the workers and restore global settings, even if a step above failed
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
new file mode 100644
index 0000000..e7ae2d4
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Tests federated {@code lower.tri()} on a matrix split across four local federated workers. */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedTriTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "FederatedTriTest";
+
+	private final static String TEST_DIR = "functions/federated/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedTriTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Parameterized.Parameter(2)
+	public boolean rowPartitioned;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			{20, 20, true}
+		});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+	}
+
+	@Test
+	public void testTriCP() { runTriTest(ExecMode.SINGLE_NODE); }
+
+	private void runTriTest(ExecMode execMode) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		ExecMode platformOld = rtplatform;
+
+		if(rtplatform == ExecMode.SPARK)
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices: four equally sized partitions along rows or columns
+		int r = rows;
+		int c = cols / 4;
+		if(rowPartitioned) {
+			r = rows / 4;
+			c = cols;
+		}
+
+		double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+		Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+		Thread t4 = startLocalFedWorkerThread(port4);
+		try {
+			rtplatform = execMode;
+			if(rtplatform == ExecMode.SPARK) {
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			}
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+				Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+			runTest(null);
+
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats", "100", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+				"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+			runTest(null);
+
+			// compare via files
+			compareResults(1e-9);
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+		}
+		finally {
+			// always stop the workers and restore global settings, even if a step above failed
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
index 1934aae..ffdca07 100644
--- a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
@@ -19,8 +19,8 @@
 #
 #-------------------------------------------------------------
 
-Fin = rbind(read($1), read($2))
 
+Fin = rbind(read($1), read($2))
 y = read($5)
 
 # one hot encoding categorical, other passthrough
diff --git a/src/test/scripts/functions/federated/FederatedRdiagTest.dml b/src/test/scripts/functions/federated/FederatedRdiagTest.dml
new file mode 100644
index 0000000..a03ce4e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRdiagTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+s = diag(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml b/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml
new file mode 100644
index 0000000..92bdd8f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = rbind(read($1), read($2), read($3), read($4));
+
+s = diag(A);
+write(s, $5);
diff --git a/src/test/scripts/functions/federated/FederatedRevTest.dml b/src/test/scripts/functions/federated/FederatedRevTest.dml
new file mode 100644
index 0000000..d43edd1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRevTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+if ($rP) {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4), list(0, $cols/4), list($rows, $cols/2),
+    list(0, $cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = rev(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRevTestReference.dml b/src/test/scripts/functions/federated/FederatedRevTestReference.dml
new file mode 100644
index 0000000..e07591d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRevTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+s = rev(A);
+write(s, $6);
diff --git a/src/test/scripts/functions/federated/FederatedTriTest.dml b/src/test/scripts/functions/federated/FederatedTriTest.dml
new file mode 100644
index 0000000..a661a4a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTriTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4), list(0, $cols/4), list($rows, $cols/2),
+    list(0, $cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = lower.tri(target=A, diag=FALSE, values=TRUE);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedTriTestReference.dml b/src/test/scripts/functions/federated/FederatedTriTestReference.dml
new file mode 100644
index 0000000..bb5d98d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTriTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+s = lower.tri(target=A, diag=FALSE, values=TRUE);
+write(s, $6);