This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 1434954 [SYSTEMDS-2904] Federated ternary instruction
1434954 is described below
commit 143495400759c21e030c1809e800e5e6fc68b2da
Author: sebwrede <[email protected]>
AuthorDate: Fri Apr 2 11:04:06 2021 +0200
[SYSTEMDS-2904] Federated ternary instruction
This commit adds TernaryFEDInstruction to the system and extends processing
of other federated instructions.
This commit closes PR #1193.
---
scripts/builtin/lmPredict.dml | 2 +-
.../controlprogram/federated/FederationMap.java | 20 ++
.../fed/AggregateBinaryFEDInstruction.java | 2 +-
.../fed/AggregateTernaryFEDInstruction.java | 27 ++-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 26 ++-
.../runtime/instructions/fed/FEDInstruction.java | 1 +
.../instructions/fed/FEDInstructionUtils.java | 9 +
.../instructions/fed/TernaryFEDInstruction.java | 226 +++++++++++++++++++++
.../fed/UnaryMatrixFEDInstruction.java | 16 +-
.../runtime/matrix/operators/BinaryOperator.java | 3 +-
.../federated/primitives/FederatedIfelseTest.java | 224 ++++++++++++++++++++
.../federated/FederatedIfelseAlignedTest.dml | 42 ++++
.../FederatedIfelseAlignedTestReference.dml | 31 ++-
.../functions/federated/FederatedIfelseTest.dml | 34 ++--
.../federated/FederatedIfelseTestReference.dml | 29 +--
15 files changed, 624 insertions(+), 68 deletions(-)
diff --git a/scripts/builtin/lmPredict.dml b/scripts/builtin/lmPredict.dml
index 1bdabf3..52ea656 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/scripts/builtin/lmPredict.dml
@@ -27,7 +27,7 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
if( verbose ) {
- y_residual = ytest - yhat;
+ y_residual = ytest - yhat;
avg_res = sum(y_residual) / nrow(ytest);
ss_res = sum(y_residual^2);
ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7278123..b066d1e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -212,6 +212,41 @@ public class FederationMap {
 		return ret.toArray(new Future[0]);
 	}
 
+	/**
+	 * Executes the given requests on all federated workers, prepending per-worker slices.
+	 * For the i-th worker, the sent sequence is {@code frSlices2[i], frSlices1[i], fr...},
+	 * where a null slice array is omitted.
+	 *
+	 * @param tid thread ID, set on all requests via setThreadID
+	 * @param wait if true, wait for all responses before returning
+	 * @param frSlices1 per-worker requests (may be null)
+	 * @param frSlices2 per-worker requests (may be null)
+	 * @param fr requests sent to every worker after the slices
+	 * @return futures of the federated responses, one per worker
+	 */
+	@SuppressWarnings("unchecked")
+	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
+		// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
+		setThreadID(tid, frSlices1, fr);
+		setThreadID(tid, frSlices2, fr);
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+		int pos = 0;
+		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
+			// both slice arrays are handled independently, so frSlices2 is not dropped if frSlices1 is null
+			FederatedRequest[] newFr = (frSlices1 != null) ? addAll(frSlices1[pos], fr) : fr;
+			if(frSlices2 != null)
+				newFr = addAll(frSlices2[pos], newFr);
+			ret.add(e.getValue().executeFederatedOperation(newFr));
+			pos++;
+		}
+
+		// prepare results (future federated responses), with optional wait to ensure the
+		// order of requests without data dependencies (e.g., cleanup RPCs)
+		if( wait )
+			FederationUtils.waitFor(ret);
+		return ret.toArray(new Future[0]);
+	}
+
public List<Pair<FederatedRange, Future<FederatedResponse>>>
requestFederatedData() {
if(!isInitialized())
throw new DMLRuntimeException("Federated matrix read
only supported on initialized FederatedData");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 12616ed..9822bef 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -80,7 +80,7 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
- if( mo2.getNumColumns() == 1 ) { //MV
+ if( mo2.getNumColumns() == 1 && mo2.getNumRows() !=
mo1.getNumColumns()) { //MV
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
FederatedRequest fr4 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 2bc187e..42a6e0e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -52,7 +52,6 @@ public class AggregateTernaryFEDInstruction extends
FEDInstruction {
MatrixObject mo1 = ec.getMatrixObject(_ins.input1);
MatrixObject mo2 = ec.getMatrixObject(_ins.input2);
MatrixObject mo3 = _ins.input3.isLiteral() ? null :
ec.getMatrixObject(_ins.input3);
-
if(mo1.isFederated() && mo2.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) &&
mo3 == null) {
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(ec.getScalarInput(_ins.input3));
@@ -79,6 +78,35 @@ public class AggregateTernaryFEDInstruction extends FEDInstruction {
 			else {
 				throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");
 			}
+		} else if(mo1.isFederated() && _ins.input3.isMatrix() && mo3 != null) {
+			// slice mo3 and mo2 according to mo1's fed mapping; the slice arrays are passed as a
+			// whole so that every worker receives its own slice (not the slice of worker 0)
+			FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo3, false);
+			FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo2, false);
+			FederatedRequest fr3 = FederationUtils.callInstruction(_ins.getInstructionString(),
+				_ins.getOutput(),
+				new CPOperand[] {_ins.input1, _ins.input2, _ins.input3},
+				new long[] {mo1.getFedMapping().getID(), fr2[0].getID(), fr1[0].getID()});
+			FederatedRequest fr4 = new FederatedRequest(RequestType.GET_VAR, fr3.getID());
+			// cleanup via mo1's fed mapping, since mo2 is not necessarily federated in this branch
+			FederatedRequest fr5 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2[0].getID());
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), false, fr1, fr2, fr3, fr4, fr5);
+
+			if(_ins.output.getDataType().isScalar()) {
+				double sum = 0;
+				for(Future<FederatedResponse> fr : tmp)
+					try {
+						sum += ((ScalarObject) fr.get().getData()[0]).getDoubleValue();
+					}
+					catch(Exception e) {
+						throw new DMLRuntimeException("Federated Get data failed with exception on TernaryFedInstruction", e);
+					}
+
+				ec.setScalarOutput(_ins.output.getName(), new DoubleObject(sum));
+			}
+			else {
+				throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");
+			}
 		}
else {
if(mo3 == null)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 6f7dcc9..b6d0227 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -58,6 +58,12 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()});
mo1.getFedMapping().execute(getTID(), true,
fr2);
}
+ else if ( !mo1.isFederated() ){
+ FederatedRequest[] fr1 =
mo2.getFedMapping().broadcastSliced(mo1, false);
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
+ new long[]{fr1[0].getID(),
mo2.getFedMapping().getID()});
+ mo2.getFedMapping().execute(getTID(), true,
fr1, fr2);
+ }
else {
throw new DMLRuntimeException("Matrix-matrix
binary operations with a "
+ "federated right input are only
supported for special cases yet.");
@@ -103,10 +109,22 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
}
}
- // derive new fed mapping for output
- MatrixObject out = ec.getMatrixObject(output);
+ if ( mo1.isFederated() )
+ setOutputFedMapping(mo1, fr2.getID(), ec);
+ else if ( mo2.isFederated() )
+ setOutputFedMapping(mo2, fr2.getID(), ec);
+ else throw new DMLRuntimeException("Input is not federated, so
the output FedMapping cannot be set!");
+ }
- out.getDataCharacteristics().set(mo1.getDataCharacteristics());
-
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+	/**
+	 * Copies data characteristics and fed mapping from a federated input to the output.
+	 * @param moFederated federated matrix object whose data characteristics and fed mapping are copied
+	 * @param outputFedmappingID ID assigned to the output's copy of the fed mapping
+	 * @param ec execution context providing the output matrix object
+	 */
+	private void setOutputFedMapping(MatrixObject moFederated, long outputFedmappingID, ExecutionContext ec){
+		MatrixObject outMo = ec.getMatrixObject(output);
+		outMo.getDataCharacteristics().set(moFederated.getDataCharacteristics());
+		outMo.setFedMapping(moFederated.getFedMapping().copyWithNewID(outputFedmappingID));
 	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index dafd723..8ed9aba 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -40,6 +40,7 @@ public abstract class FEDInstruction extends Instruction {
Reorg,
Reshape,
MatrixIndexing,
+ Ternary,
Quaternary,
QSort,
QPick,
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1417c90..a4c750c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -38,6 +38,7 @@ import
org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinC
import
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -174,6 +175,14 @@ public class FEDInstructionUtils {
fedinst =
IndexingFEDInstruction.parseInstruction(minst.getInstructionString());
}
}
+ else if(inst instanceof TernaryCPInstruction) {
+ TernaryCPInstruction tinst = (TernaryCPInstruction)
inst;
+ if((tinst.input1.isMatrix() &&
ec.getCacheableData(tinst.input1).isFederated())
+ || (tinst.input2.isMatrix() &&
ec.getCacheableData(tinst.input2).isFederated())
+ || (tinst.input3.isMatrix() &&
ec.getCacheableData(tinst.input3).isFederated())) {
+ fedinst =
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
+ }
+ }
else if(inst instanceof VariableCPInstruction ){
VariableCPInstruction ins = (VariableCPInstruction)
inst;
if(ins.getVariableOpcode() ==
VariableOperationCode.Write
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
new file mode 100644
index 0000000..957ee6e
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+import java.util.stream.Stream;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+/**
+ * Federated execution of ternary instructions (e.g., {@code ifelse}) with at least one
+ * federated matrix input. Matrix inputs that are not federated, or not aligned with the
+ * federated input chosen for execution, are broadcast in slices to that input's workers;
+ * scalar variables are inlined into the instruction string as literals.
+ */
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	/**
+	 * Parses a ternary instruction string (with value types) into a federated instruction.
+	 *
+	 * @param str instruction string with opcode, three input operands and the output operand
+	 * @return parsed federated ternary instruction
+	 */
+	public static TernaryFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3 != null && input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = Stream.of(mo1, mo2, mo3).filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else {
+			// the federated requests only transfer matrix inputs, so scalar variables are inlined as literals
+			inlineScalarOperands(ec);
+			if(matrixInputsCount == 1) {
+				CPOperand in = mo1 != null ? input1 : mo2 != null ? input2 : input3;
+				MatrixObject mo = mo1 != null ? mo1 : mo2 != null ? mo2 : mo3;
+				processMatrixScalarInput(ec, mo, in);
+			}
+			else if(mo1 != null && mo2 != null)
+				process2MatrixScalarInput(ec, mo1, mo2, input1, input2);
+			else if(mo2 != null && mo3 != null)
+				process2MatrixScalarInput(ec, mo2, mo3, input2, input3);
+			else if(mo1 != null && mo3 != null)
+				process2MatrixScalarInput(ec, mo1, mo3, input1, input3);
+		}
+	}
+
+	/**
+	 * Replaces non-literal scalar operands in the instruction string by literals holding their
+	 * current values (input1/2/3 are operands 2/3/4 of the instruction string).
+	 *
+	 * @param ec execution context holding the scalar values
+	 */
+	private void inlineScalarOperands(ExecutionContext ec) {
+		CPOperand[] inputs = {input1, input2, input3};
+		for(int i = 0; i < inputs.length; i++) {
+			CPOperand in = inputs[i];
+			if(in != null && !in.isMatrix() && !in.isLiteral())
+				instString = InstructionUtils.replaceOperand(instString, i + 2,
+					InstructionUtils.createLiteralOperand(ec.getScalarInput(in).getStringValue(), in.getValueType()));
+		}
+	}
+
+	/**
+	 * Executes the instruction on the workers of the single (federated) matrix input.
+	 */
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo.getFedMapping().getID()});
+		mo.getFedMapping().execute(getTID(), true, fr1);
+		setOutputFedMapping(ec, mo, fr1.getID());
+	}
+
+	/**
+	 * Executes the instruction for two matrix inputs (at least one federated) and one scalar.
+	 * If both inputs are federated and aligned, no data is shipped; otherwise the other input
+	 * is broadcast in slices to the workers of the federated input and removed afterwards.
+	 */
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, CPOperand in1, CPOperand in2) {
+		boolean firstFed = mo1.isFederated();
+		MatrixObject fedMo = firstFed ? mo1 : mo2;
+		MatrixObject otherMo = firstFed ? mo2 : mo1;
+		CPOperand[] varOldIn = new CPOperand[] {in1, in2};
+
+		// 2 aligned inputs: no data needs to be shipped
+		if(isFedAligned(fedMo, otherMo)) {
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, varOldIn,
+				new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
+			fedMo.getFedMapping().execute(getTID(), true, fr2);
+			setOutputFedMapping(ec, fedMo, fr2.getID());
+			return;
+		}
+
+		FederatedRequest[] fr1 = fedMo.getFedMapping().broadcastSliced(otherMo, false);
+		long fedID = fedMo.getFedMapping().getID();
+		long[] varNewIn = firstFed ? new long[] {fedID, fr1[0].getID()} : new long[] {fr1[0].getID(), fedID};
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
+		// the broadcast slices are removed from the workers in both orders of the inputs
+		FederatedRequest fr3 = fedMo.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		fedMo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+		setOutputFedMapping(ec, fedMo, fr2.getID());
+	}
+
+	/**
+	 * Executes the instruction for three matrix inputs, at least one of which is federated.
+	 */
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		RetAlignedValues aligned = getAlignedInputs(mo1, mo2, mo3);
+		MatrixObject fedMo;
+		FederatedRequest fr3;
+
+		// all 3 inputs federated and aligned: no data needs to be shipped
+		if(aligned._allAligned) {
+			fedMo = mo1;
+			fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+				new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+			fedMo.getFedMapping().execute(getTID(), true, fr3);
+		}
+		// 2 federated aligned inputs: the third input is sliced to their workers
+		else if(aligned._twoAligned) {
+			fedMo = aligned._fedMo;
+			fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3}, aligned._vars);
+			FederatedRequest fr4 = fedMo.getFedMapping().cleanup(getTID(), aligned._fr[0].getID());
+			fedMo.getFedMapping().execute(getTID(), true, aligned._fr, fr3, fr4);
+		}
+		// 1 federated input or no aligned pair: the first federated input determines the
+		// workers, the two other inputs are sliced to them (keeping their operand positions)
+		else {
+			MatrixObject[] inputs = {mo1, mo2, mo3};
+			int fedIx = mo1.isFederated() ? 0 : mo2.isFederated() ? 1 : 2;
+			fedMo = inputs[fedIx];
+			FederatedRequest[][] slices = new FederatedRequest[2][];
+			long[] vars = new long[inputs.length];
+			for(int i = 0, k = 0; i < inputs.length; i++) {
+				if(i == fedIx)
+					vars[i] = fedMo.getFedMapping().getID();
+				else {
+					slices[k] = fedMo.getFedMapping().broadcastSliced(inputs[i], false);
+					vars[i] = slices[k++][0].getID();
+				}
+			}
+			fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3}, vars);
+			FederatedRequest fr4 = fedMo.getFedMapping().cleanup(getTID(), slices[0][0].getID(), slices[1][0].getID());
+			fedMo.getFedMapping().execute(getTID(), true, slices[0], slices[1], fr3, fr4);
+		}
+
+		// derive new fed mapping for output
+		setOutputFedMapping(ec, fedMo, fr3.getID());
+	}
+
+	/**
+	 * Checks which federated inputs are aligned and, if exactly two are, prepares the sliced
+	 * broadcast of the third input to the workers of the aligned pair.
+	 */
+	private static RetAlignedValues getAlignedInputs(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		boolean aligned12 = isFedAligned(mo1, mo2);
+		boolean aligned13 = isFedAligned(mo1, mo3);
+		if(aligned12 && aligned13)
+			return new RetAlignedValues(true, true, new long[0], new FederatedRequest[0], mo1);
+		if(aligned12) {
+			FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo3, false);
+			long[] vars = {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr[0].getID()};
+			return new RetAlignedValues(true, false, vars, fr, mo1);
+		}
+		if(aligned13) {
+			FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo2, false);
+			long[] vars = {mo1.getFedMapping().getID(), fr[0].getID(), mo3.getFedMapping().getID()};
+			return new RetAlignedValues(true, false, vars, fr, mo1);
+		}
+		if(isFedAligned(mo2, mo3)) {
+			// execute on the workers of mo2/mo3, not on those of the (possibly non-federated) mo1
+			FederatedRequest[] fr = mo2.getFedMapping().broadcastSliced(mo1, false);
+			long[] vars = {fr[0].getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()};
+			return new RetAlignedValues(true, false, vars, fr, mo2);
+		}
+		return new RetAlignedValues(false, false, new long[0], new FederatedRequest[0], null);
+	}
+
+	/** Returns true if both matrices are federated and their fed mappings are aligned. */
+	private static boolean isFedAligned(MatrixObject a, MatrixObject b) {
+		return a.isFederated() && b.isFederated() && a.getFedMapping().isAligned(b.getFedMapping(), false);
+	}
+
+	/** Result of {@code getAlignedInputs}: alignment flags and the prepared requests. */
+	private static final class RetAlignedValues {
+		public final boolean _twoAligned;
+		public final boolean _allAligned;
+		public final long[] _vars;
+		public final FederatedRequest[] _fr;
+		// federated input whose workers execute the instruction (null if no two inputs are aligned)
+		public final MatrixObject _fedMo;
+
+		public RetAlignedValues(boolean twoAligned, boolean allAligned, long[] vars, FederatedRequest[] fr, MatrixObject fedMo) {
+			_twoAligned = twoAligned;
+			_allAligned = allAligned;
+			_vars = vars;
+			_fr = fr;
+			_fedMo = fedMo;
+		}
+	}
+
+	/** Sets the output's data characteristics and a copy of fedMapObj's fed mapping with the given ID. */
+	private void setOutputFedMapping(ExecutionContext ec, MatrixObject fedMapObj, long fedOutputID) {
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(fedMapObj.getDataCharacteristics());
+		out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 8e5104c..c6e0b2c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -25,10 +25,13 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
protected UnaryMatrixFEDInstruction(Operator op, CPOperand in,
CPOperand out, String opcode, String instr) {
@@ -43,7 +46,20 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
 	public static UnaryMatrixFEDInstruction parseInstruction(String str) {
 		CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
 		CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
-		String opcode = parseUnaryInstruction(str, in, out);
+
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		// special case: exp with two additional operands, parsed as an int and a boolean and
+		// passed to the UnaryOperator; all other opcodes use the generic unary parsing below
+		boolean isExpWithArgs = opcode.equalsIgnoreCase("exp") && parts.length == 5;
+		if( isExpWithArgs ) {
+			in.split(parts[1]);
+			out.split(parts[2]);
+			ValueFunction fn = Builtin.getBuiltinFnObject(opcode);
+			UnaryOperator uop = new UnaryOperator(fn, Integer.parseInt(parts[3]), Boolean.parseBoolean(parts[4]));
+			return new UnaryMatrixFEDInstruction(uop, in, out, opcode, str);
+		}
+		opcode = parseUnaryInstruction(str, in, out);
 		return new UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, opcode, str);
 	}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index 7579046..7cf201a 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.functionobjects.IntegerDivide;
import org.apache.sysds.runtime.functionobjects.LessThan;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Minus1Multiply;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.MinusNz;
import org.apache.sysds.runtime.functionobjects.Modulus;
@@ -68,7 +69,7 @@ public class BinaryOperator extends Operator implements
Serializable
|| p instanceof BitwShiftL || p instanceof BitwShiftR);
fn = p;
commutative = p instanceof Plus || p instanceof Multiply
- || p instanceof And || p instanceof Or || p instanceof
Xor;
+ || p instanceof And || p instanceof Or || p instanceof
Xor || p instanceof Minus1Multiply;
}
public void setNumThreads(int k) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
new file mode 100644
index 0000000..c8733ff
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedIfelseTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "FederatedIfelseTest";
+	private final static String TEST_NAME2 = "FederatedIfelseAlignedTest";
+
+	private final static String TEST_DIR = "functions/federated/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedIfelseTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public boolean rowPartitioned;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(
+			new Object[][] {
+				{64, 16, true},
+				{64, 16, false},
+			});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+	}
+
+	@Test
+	public void testIfelseDiffWorkersCP() {
+		runTernaryTest(ExecMode.SINGLE_NODE, false);
+	}
+
+	@Test
+	public void testIfelseAlignedCP() {
+		runTernaryTest(ExecMode.SINGLE_NODE, true);
+	}
+
+	/**
+	 * Runs the reference and the federated ifelse script and compares their results.
+	 *
+	 * @param execMode execution mode used while running the scripts
+	 * @param aligned if true, run the aligned test, which additionally reads Y1-Y4 from the workers of X1-X4
+	 */
+	private void runTernaryTest(ExecMode execMode, boolean aligned) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		ExecMode platformOld = rtplatform;
+
+		if(rtplatform == ExecMode.SPARK)
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		String TEST_NAME = aligned ? TEST_NAME2 : TEST_NAME1;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices, each partition holds a quarter of the rows or columns
+		int r = rowPartitioned ? rows / 4 : rows;
+		int c = rowPartitioned ? cols : cols / 4;
+
+		double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+		X1[0][0] = 0;
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+		// the Y inputs of the aligned test are written (once) in runAlignedTernary
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1);
+		Thread t2 = startLocalFedWorkerThread(port2);
+		Thread t3 = startLocalFedWorkerThread(port3);
+		Thread t4 = startLocalFedWorkerThread(port4);
+
+		try {
+			rtplatform = execMode;
+			if(rtplatform == ExecMode.SPARK)
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			if(aligned)
+				runAlignedTernary(HOME, TEST_NAME, r, c, port1, port2, port3, port4);
+			else
+				runTernary(HOME, TEST_NAME, port1, port2, port3, port4);
+
+			// compare via files
+			compareResults(1e-9);
+			Assert.assertTrue(heavyHittersContainsString("fed_ifelse"));
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+			if(aligned) {
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y1")));
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y2")));
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y3")));
+				Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y4")));
+			}
+		}
+		finally {
+			// always stop the workers and restore the global configuration, also on assertion failures
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	private void runTernary(String HOME, String TEST_NAME, int port1, int port2, int port3, int port4) {
+		// Run reference dml script with normal matrix
+		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+		programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+			expected("S"), Boolean.toString(rowPartitioned).toUpperCase()};
+		runTest(true, false, null, -1);
+
+		// Run actual dml script with federated matrix
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats", "100", "-nvargs",
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+			"rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+			"out_S=" + output("S")};
+		runTest(true, false, null, -1);
+	}
+
+	private void runAlignedTernary(String HOME, String TEST_NAME, int r, int c, int port1, int port2, int port3, int port4) {
+		double[][] Y1 = getRandomMatrix(r, c, 10, 15, 1, 3);
+		double[][] Y2 = getRandomMatrix(r, c, 10, 15, 1, 7);
+		double[][] Y3 = getRandomMatrix(r, c, 10, 15, 1, 8);
+		double[][] Y4 = getRandomMatrix(r, c, 10, 15, 1, 9);
+		MatrixCharacteristics mc2 = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("Y1", Y1, false, mc2);
+		writeInputMatrixWithMTD("Y2", Y2, false, mc2);
+		writeInputMatrixWithMTD("Y3", Y3, false, mc2);
+		writeInputMatrixWithMTD("Y4", Y4, false, mc2);
+
+		// Run reference dml script with normal matrix
+		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+		programArgs = new String[] {"-stats", "100", "-args",
+			input("X1"), input("X2"), input("X3"), input("X4"),
+			input("Y1"), input("Y2"), input("Y3"), input("Y4"),
+			expected("S"), Boolean.toString(rowPartitioned).toUpperCase()};
+		runTest(true, false, null, -1);
+
+		// Run actual dml script with federated matrix
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats", "100", "-nvargs",
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+			"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+			"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+			"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
+			"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")), "rows=" + rows, "cols=" + cols,
+			"rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+			"out_S=" + output("S")};
+		runTest(true, false, null, -1);
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml b/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml
new file mode 100644
index 0000000..fc73a80
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) { # row partitioning: each of the four sites holds a ($rows/4) x $cols slice
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+
+ B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4), # same ranges as A, so A and B are aligned
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+
+} else { # column partitioning: each site holds a $rows x ($cols/4) slice
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4), # same ranges as A, so A and B are aligned
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+}
+c1 = ifelse(A-3>0, B, A*2); # cell-wise: B where A > 3, otherwise 2*A
+c2 = ifelse(A-3>0, B, matrix(1, $rows, $cols)); # cell-wise: B where A > 3, otherwise 1
+s = c1 + c2;
+write(s, $out_S);
diff --git a/scripts/builtin/lmPredict.dml b/src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
index 1bdabf3..b899a9b 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
@@ -19,24 +19,17 @@
#
#-------------------------------------------------------------
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
- Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose =
FALSE)
- return (Matrix[Double] yhat)
-{
- intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
- yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
- if( verbose ) {
- y_residual = ytest - yhat;
- avg_res = sum(y_residual) / nrow(ytest);
- ss_res = sum(y_residual^2);
- ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
- R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) *
(sum(ytest)/nrow(ytest))^2);
- print("\nAccuracy:" +
- "\n--sum(ytest) = " + sum(ytest) +
- "\n--sum(yhat) = " + sum(yhat) +
- "\n--AVG_RES_Y: " + avg_res +
- "\n--SS_AVG_RES_Y: " + ss_avg_res +
- "\n--R2: " + R2 );
- }
+if($10) { # $10: row-partitioned flag -> stack the four blocks vertically
+ A = rbind(read($1), read($2), read($3), read($4)); # $1..$4: X1..X4
+ B = rbind(read($5), read($6), read($7), read($8)); # $5..$8: Y1..Y4
}
+else { # column-partitioned -> stack the four blocks horizontally
+ A = cbind(read($1), read($2), read($3), read($4));
+ B = cbind(read($5), read($6), read($7), read($8));
+}
+
+c1 = ifelse(A-3>0, B, A*2); # cell-wise: B where A > 3, otherwise 2*A
+c2 = ifelse(A-3>0, B, matrix(1, nrow(A), ncol(A))); # cell-wise: B where A > 3, otherwise 1
+s = c1 + c2;
+write(s, $9); # $9: expected-output path
diff --git a/scripts/builtin/lmPredict.dml b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/federated/FederatedIfelseTest.dml
index 1bdabf3..da650a3 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
@@ -19,24 +19,18 @@
#
#-------------------------------------------------------------
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
- Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose =
FALSE)
- return (Matrix[Double] yhat)
-{
- intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
- yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
-
- if( verbose ) {
- y_residual = ytest - yhat;
- avg_res = sum(y_residual) / nrow(ytest);
- ss_res = sum(y_residual^2);
- ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
- R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) *
(sum(ytest)/nrow(ytest))^2);
- print("\nAccuracy:" +
- "\n--sum(ytest) = " + sum(ytest) +
- "\n--sum(yhat) = " + sum(yhat) +
- "\n--AVG_RES_Y: " + avg_res +
- "\n--SS_AVG_RES_Y: " + ss_avg_res +
- "\n--R2: " + R2 );
- }
+if ($rP) { # row partitioning: four ($rows/4) x $cols slices
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else { # column partitioning: four $rows x ($cols/4) slices
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
}
+c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2); # cell-wise: A+1 where A > 0, otherwise 2*A
+c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3); # cell-wise: A+1 where A > 3, otherwise the scalar 3
+c3 = ifelse(1, matrix(1, $rows, $cols), 3); # non-zero scalar condition, so every cell takes the value 1
+s = c2 + c3;
+s = s + 10*c1;
+write(s, $out_S);
diff --git a/scripts/builtin/lmPredict.dml b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
index 1bdabf3..e232cdf 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
@@ -19,24 +19,13 @@
#
#-------------------------------------------------------------
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
- Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose =
FALSE)
- return (Matrix[Double] yhat)
-{
- intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
- yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
- if( verbose ) {
- y_residual = ytest - yhat;
- avg_res = sum(y_residual) / nrow(ytest);
- ss_res = sum(y_residual^2);
- ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
- R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) *
(sum(ytest)/nrow(ytest))^2);
- print("\nAccuracy:" +
- "\n--sum(ytest) = " + sum(ytest) +
- "\n--sum(yhat) = " + sum(yhat) +
- "\n--AVG_RES_Y: " + avg_res +
- "\n--SS_AVG_RES_Y: " + ss_avg_res +
- "\n--R2: " + R2 );
- }
-}
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); } # $1..$4: X1..X4; $6: row-partitioned flag -> stack vertically
+else { A = cbind(read($1), read($2), read($3), read($4)); } # column-partitioned -> stack horizontally
+
+c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2); # cell-wise: A+1 where A > 0, otherwise 2*A
+c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3); # cell-wise: A+1 where A > 3, otherwise the scalar 3
+c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3); # non-zero scalar condition, so every cell takes the value 1
+s = c2 + c3;
+s = s + 10*c1;
+write(s, $5); # $5: expected-output path