This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 1434954  [SYSTEMDS-2904] Federated ternary instruction
1434954 is described below

commit 143495400759c21e030c1809e800e5e6fc68b2da
Author: sebwrede <[email protected]>
AuthorDate: Fri Apr 2 11:04:06 2021 +0200

    [SYSTEMDS-2904] Federated ternary instruction
    
    This commit adds TernaryFEDInstruction to the system and extends the
    processing of other federated instructions (aggregate binary, aggregate
    ternary, binary matrix-matrix, and unary matrix).
    This commit closes PR #1193.
---
 scripts/builtin/lmPredict.dml                      |   2 +-
 .../controlprogram/federated/FederationMap.java    |  20 ++
 .../fed/AggregateBinaryFEDInstruction.java         |   2 +-
 .../fed/AggregateTernaryFEDInstruction.java        |  27 ++-
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  26 ++-
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |   9 +
 .../instructions/fed/TernaryFEDInstruction.java    | 226 +++++++++++++++++++++
 .../fed/UnaryMatrixFEDInstruction.java             |  16 +-
 .../runtime/matrix/operators/BinaryOperator.java   |   3 +-
 .../federated/primitives/FederatedIfelseTest.java  | 224 ++++++++++++++++++++
 .../federated/FederatedIfelseAlignedTest.dml       |  42 ++++
 .../FederatedIfelseAlignedTestReference.dml        |  31 ++-
 .../functions/federated/FederatedIfelseTest.dml    |  34 ++--
 .../federated/FederatedIfelseTestReference.dml     |  29 +--
 15 files changed, 624 insertions(+), 68 deletions(-)

diff --git a/scripts/builtin/lmPredict.dml b/scripts/builtin/lmPredict.dml
index 1bdabf3..52ea656 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/scripts/builtin/lmPredict.dml
@@ -27,7 +27,7 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
   yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
 
   if( verbose ) {
-    y_residual = ytest - yhat;
+    y_residual =  ytest - yhat;
     avg_res = sum(y_residual) / nrow(ytest);
     ss_res = sum(y_residual^2);
     ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7278123..b066d1e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -212,6 +212,26 @@ public class FederationMap {
                return ret.toArray(new Future[0]);
        }
 
+       @SuppressWarnings("unchecked")
+       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRequest[] frSlices1, FederatedRequest[] frSlices2, FederatedRequest... 
fr) {
+               // executes step1[] - step 2 - ... step4 (only first step 
federated-data-specific)
+               setThreadID(tid, frSlices1, fr);
+               setThreadID(tid, frSlices2, fr);
+               List<Future<FederatedResponse>> ret = new ArrayList<>();
+               int pos = 0;
+               for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
+                       FederatedRequest[] newFr = (frSlices1!=null) ?
+                               ((frSlices2!=null)? (addAll(frSlices2[pos], 
addAll(frSlices1[pos++], fr))) : addAll(frSlices1[pos++], fr)) : fr;
+                       ret.add(e.getValue().executeFederatedOperation(newFr));
+               }
+
+               // prepare results (future federated responses), with optional 
wait to ensure the
+               // order of requests without data dependencies (e.g., cleanup 
RPCs)
+               if( wait )
+                       FederationUtils.waitFor(ret);
+               return ret.toArray(new Future[0]);
+       }
+       
        public List<Pair<FederatedRange, Future<FederatedResponse>>> 
requestFederatedData() {
                if(!isInitialized())
                        throw new DMLRuntimeException("Federated matrix read 
only supported on initialized FederatedData");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 12616ed..9822bef 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -80,7 +80,7 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
-                       if( mo2.getNumColumns() == 1 ) { //MV
+                       if( mo2.getNumColumns() == 1 && mo2.getNumRows() != 
mo1.getNumColumns()) { //MV
                                FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
                                FederatedRequest fr4 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
                                //execute federated operations and aggregate
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 2bc187e..42a6e0e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -52,7 +52,6 @@ public class AggregateTernaryFEDInstruction extends FEDInstruction {
                MatrixObject mo1 = ec.getMatrixObject(_ins.input1);
                MatrixObject mo2 = ec.getMatrixObject(_ins.input2);
                MatrixObject mo3 = _ins.input3.isLiteral() ? null : 
ec.getMatrixObject(_ins.input3);
-
                if(mo1.isFederated() && mo2.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) &&
                        mo3 == null) {
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(ec.getScalarInput(_ins.input3));
@@ -79,6 +78,32 @@ public class AggregateTernaryFEDInstruction extends FEDInstruction {
                        else {
                                throw new DMLRuntimeException("Not Implemented 
Federated Ternary Variation");
                        }
+               } else if(mo1.isFederated() && _ins.input3.isMatrix() && mo3 != 
null) {
+                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
+                       FederatedRequest[] fr2 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+                       FederatedRequest fr3 = 
FederationUtils.callInstruction(_ins.getInstructionString(),
+                               _ins.getOutput(),
+                               new CPOperand[] {_ins.input1, _ins.input2, 
_ins.input3},
+                               new long[] {mo1.getFedMapping().getID(), 
fr2[0].getID(), fr1[0].getID()});
+                       FederatedRequest fr4 = new 
FederatedRequest(RequestType.GET_VAR, fr3.getID());
+                       FederatedRequest fr5 = 
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2[0].getID());
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4, fr5);
+
+                       if(_ins.output.getDataType().isScalar()) {
+                               double sum = 0;
+                               for(Future<FederatedResponse> fr : tmp)
+                                       try {
+                                               sum += ((ScalarObject) 
fr.get().getData()[0]).getDoubleValue();
+                                       }
+                                       catch(Exception e) {
+                                               throw new 
DMLRuntimeException("Federated Get data failed with exception on 
TernaryFedInstruction", e);
+                                       }
+
+                               ec.setScalarOutput(_ins.output.getName(), new 
DoubleObject(sum));
+                       }
+                       else {
+                               throw new DMLRuntimeException("Not Implemented 
Federated Ternary Variation");
+                       }
                }
                else {
                        if(mo3 == null)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 6f7dcc9..b6d0227 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -58,6 +58,12 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
                                        new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID()});
                                mo1.getFedMapping().execute(getTID(), true, 
fr2);
                        }
+                       else if ( !mo1.isFederated() ){
+                               FederatedRequest[] fr1 = 
mo2.getFedMapping().broadcastSliced(mo1, false);
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
+                                       new long[]{fr1[0].getID(), 
mo2.getFedMapping().getID()});
+                               mo2.getFedMapping().execute(getTID(), true, 
fr1, fr2);
+                       }
                        else {
                                throw new DMLRuntimeException("Matrix-matrix 
binary operations with a "
                                        + "federated right input are only 
supported for special cases yet.");
@@ -103,10 +109,22 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
                        }
                }
 
-               // derive new fed mapping for output
-               MatrixObject out = ec.getMatrixObject(output);
+               if ( mo1.isFederated() )
+                       setOutputFedMapping(mo1, fr2.getID(), ec);
+               else if ( mo2.isFederated() )
+                       setOutputFedMapping(mo2, fr2.getID(), ec);
+               else throw new DMLRuntimeException("Input is not federated, so 
the output FedMapping cannot be set!");
+       }
 
-               out.getDataCharacteristics().set(mo1.getDataCharacteristics());
-               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+       /**
+        * Set data characteristics and fed mapping for output.
+        * @param moFederated federated matrix object from which data 
characteristics and fed mapping are derived
+        * @param outputFedmappingID ID for the fed mapping of output
+        * @param ec execution context
+        */
+       private void setOutputFedMapping(MatrixObject moFederated, long 
outputFedmappingID, ExecutionContext ec){
+               MatrixObject out = ec.getMatrixObject(output);
+               
out.getDataCharacteristics().set(moFederated.getDataCharacteristics());
+               
out.setFedMapping(moFederated.getFedMapping().copyWithNewID(outputFedmappingID));
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index dafd723..8ed9aba 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -40,6 +40,7 @@ public abstract class FEDInstruction extends Instruction {
                Reorg,
                Reshape,
                MatrixIndexing,
+       Ternary,
                Quaternary,
                QSort,
                QPick,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1417c90..a4c750c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinC
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -174,6 +175,14 @@ public class FEDInstructionUtils {
                                fedinst = 
IndexingFEDInstruction.parseInstruction(minst.getInstructionString());
                        }
                }
+               else if(inst instanceof TernaryCPInstruction) {
+                       TernaryCPInstruction tinst = (TernaryCPInstruction) 
inst;
+                       if((tinst.input1.isMatrix() && 
ec.getCacheableData(tinst.input1).isFederated())
+                               || (tinst.input2.isMatrix() && 
ec.getCacheableData(tinst.input2).isFederated())
+                               || (tinst.input3.isMatrix() && 
ec.getCacheableData(tinst.input3).isFederated())) {
+                               fedinst = 
TernaryFEDInstruction.parseInstruction(tinst.getInstructionString());
+                       }
+               }
                else if(inst instanceof VariableCPInstruction ){
                        VariableCPInstruction ins = (VariableCPInstruction) 
inst;
                        if(ins.getVariableOpcode() == 
VariableOperationCode.Write
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
new file mode 100644
index 0000000..957ee6e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+       private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+               super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, 
opcode, str);
+       }
+
+       public static TernaryFEDInstruction parseInstruction(String str) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+               CPOperand operand1 = new CPOperand(parts[1]);
+               CPOperand operand2 = new CPOperand(parts[2]);
+               CPOperand operand3 = new CPOperand(parts[3]);
+               CPOperand outOperand = new CPOperand(parts[4]);
+               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode);
+               return new TernaryFEDInstruction(op, operand1, operand2, 
operand3, outOperand, opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = input1.isMatrix() ? 
ec.getMatrixObject(input1.getName()) : null;
+               MatrixObject mo2 = input2.isMatrix() ? 
ec.getMatrixObject(input2.getName()) : null;
+               MatrixObject mo3 = input3 != null && input3.isMatrix() ? 
ec.getMatrixObject(input3.getName()) : null;
+
+               long matrixInputsCount = List.of(mo1, mo2, 
mo3).stream().filter(Objects::nonNull).count();
+
+               if(matrixInputsCount == 3)
+                       processMatrixInput(ec, mo1, mo2, mo3);
+               else if(matrixInputsCount == 1) {
+                       CPOperand in = mo1 == null ? mo2 == null ? input3 : 
input2 : input1;
+                       mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+                       processMatrixScalarInput(ec, mo1, in);
+               }
+               else {
+                       if(mo1 != null && mo2 != null) {
+                               if(input3 != null && !input3.isLiteral())
+                                       instString = 
InstructionUtils.replaceOperand(instString, 4,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(),
 Types.ValueType.FP64));
+                               process2MatrixScalarInput(ec, mo1, mo2, input1, 
input2);
+                       }
+                       else if(mo2 != null && mo3 != null) {
+                               if(!input1.isLiteral())
+                                       instString = 
InstructionUtils.replaceOperand(instString, 2,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(),
 Types.ValueType.FP64));
+                               process2MatrixScalarInput(ec, mo2, mo3, input2, 
input3);
+                       }
+                       else if(mo1 != null && mo3 != null) {
+                               if(!input2.isLiteral())
+                                       instString = 
InstructionUtils.replaceOperand(instString, 3,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(),
 Types.ValueType.FP64));
+                               process2MatrixScalarInput(ec, mo1, mo3, input1, 
input3);
+                       }
+               }
+       }
+
+       private void processMatrixScalarInput(ExecutionContext ec, MatrixObject 
mo1, CPOperand in) {
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {in}, new 
long[] {mo1.getFedMapping().getID()});
+               mo1.getFedMapping().execute(getTID(), true, fr1);
+
+               setOutputFedMapping(ec, mo1, fr1.getID());
+       }
+
+       private void process2MatrixScalarInput(ExecutionContext ec, 
MatrixObject mo1, MatrixObject mo2, CPOperand in1, CPOperand in2) {
+               FederatedRequest[] fr1 = null;
+               CPOperand[] varOldIn;
+               boolean cleanupIn = true;
+               long[] varNewIn;
+               varOldIn = new CPOperand[] {in1, in2};
+               if(mo1.isFederated()) {
+                       if(mo2.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false))
+                               varNewIn = new 
long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()};
+                       else {
+                               fr1 = mo1.getFedMapping().broadcastSliced(mo2, 
false);
+                               varNewIn = new 
long[]{mo1.getFedMapping().getID(), fr1[0].getID()};
+                       }
+               } else {
+                       cleanupIn = false;
+                       mo1 = ec.getMatrixObject(in2);
+                       fr1 = 
mo1.getFedMapping().broadcastSliced(ec.getMatrixObject(in1), false);
+                       varNewIn = new long[]{fr1[0].getID(), 
mo1.getFedMapping().getID()};
+               }
+               FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
+               FederatedRequest fr3;
+
+               // 2 aligned inputs
+               if(fr1 == null) {
+                       mo1.getFedMapping().execute(getTID(), true, fr2);
+               } else {
+                       if(cleanupIn) {
+                               fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID());
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
+                       } else
+                               mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2);
+               }
+               setOutputFedMapping(ec, mo1, fr2.getID());
+       }
+
+       private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3) {
+
+               // check aligned matrices
+               RetAlignedValues retAlignedValues = getAlignedInputs(ec, mo1, 
mo2, mo3);
+
+               FederatedRequest[] fr2;
+               FederatedRequest fr3, fr4;
+
+               // all 3 inputs fed aligned on the one worker
+               if(retAlignedValues._allAligned) {
+                       fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3},
+                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+                       mo1.getFedMapping().execute(getTID(), fr3);
+               }
+               // 2 fed aligned inputs
+               else if(retAlignedValues._twoAligned) {
+                       fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3}, retAlignedValues._vars);
+                       fr4 = mo1.getFedMapping().cleanup(getTID(), 
retAlignedValues._fr[0].getID());
+                       mo1.getFedMapping().execute(getTID(), true, 
retAlignedValues._fr, fr3, fr4);
+               }
+               // 1 fed input or not aligned
+               else {
+                       if(!mo1.isFederated())
+                               if(mo2.isFederated()) {
+                                       mo1 = mo2;
+                                       mo2 = ec.getMatrixObject(input1);
+                               }
+                               else {
+                                       mo1 = mo3;
+                                       mo3 = ec.getMatrixObject(input1);
+                               }
+
+                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+                       fr2 = mo1.getFedMapping().broadcastSliced(mo3, false);
+
+                       long[] vars = new long[] {mo1.getFedMapping().getID(), 
fr1[0].getID(), fr2[0].getID()};
+                       if(!ec.getMatrixObject(input1).isFederated())
+                               vars = ec.getMatrixObject(input2).isFederated() 
? new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr2[0].getID()} : 
new long[] {fr1[0].getID(), fr2[0].getID(),
+                                       mo1.getFedMapping().getID()};
+
+                       fr3 = FederationUtils.callInstruction(instString, 
output, new CPOperand[] {input1, input2, input3}, vars);
+                       fr4 = mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID(), fr2[0].getID());
+                       mo1.getFedMapping().execute(getTID(), true, fr1, fr2, 
fr3, fr4);
+               }
+
+               //derive new fed mapping for output
+               setOutputFedMapping(ec, mo1, fr3.getID());
+       }
+
+       // check aligned matrices and return vars
+       private RetAlignedValues getAlignedInputs(ExecutionContext ec, 
MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+               long[] vars = new long[0];
+               FederatedRequest[] fr = new FederatedRequest[0];
+               boolean twoAligned = false, allAligned = false;
+               if(mo1.isFederated() && mo2.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
+                       twoAligned = true;
+                       fr = mo1.getFedMapping().broadcastSliced(mo3, false);
+                       vars = new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr[0].getID()};
+               }
+               if(mo1.isFederated() && mo3.isFederated() && 
mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {
+                       allAligned = twoAligned;
+                       twoAligned = true;
+                       fr = mo1.getFedMapping().broadcastSliced(mo2, false);
+                       vars = new long[] {mo1.getFedMapping().getID(), 
fr[0].getID(), mo3.getFedMapping().getID()};
+               }
+               if(mo2.isFederated() && mo3.isFederated() && 
mo2.getFedMapping().isAligned(mo3.getFedMapping(), false) && !allAligned) {
+                       twoAligned = true;
+                       mo1 = mo2;
+                       mo2 = mo3;
+                       mo3 = ec.getMatrixObject(input1);
+                       fr = mo1.getFedMapping().broadcastSliced(mo3, false);
+                       vars = new long[] {fr[0].getID(), 
mo1.getFedMapping().getID(), mo2.getFedMapping().getID()};
+               }
+
+               return new RetAlignedValues(twoAligned, allAligned, vars, fr);
+       }
+
+       private static final class RetAlignedValues {
+               public boolean _twoAligned;
+               public boolean _allAligned;
+               public long[] _vars;
+               public FederatedRequest[] _fr;
+
+               public RetAlignedValues(boolean twoAligned, boolean allAligned, 
long[] vars, FederatedRequest[] fr) {
+                       _twoAligned = twoAligned;
+                       _allAligned = allAligned;
+                       _vars = vars;
+                       _fr = fr;
+               }
+       }
+
+       private void setOutputFedMapping(ExecutionContext ec, MatrixObject 
fedMapObj, long fedOutputID) {
+               MatrixObject out = ec.getMatrixObject(output);
+               
out.getDataCharacteristics().set(fedMapObj.getDataCharacteristics());
+               
out.setFedMapping(fedMapObj.getFedMapping().copyWithNewID(fedOutputID));
+       }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 8e5104c..c6e0b2c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -25,10 +25,13 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
 
 public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
        protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, 
CPOperand out, String opcode, String instr) {
@@ -43,7 +46,18 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
        public static UnaryMatrixFEDInstruction parseInstruction(String str) {
                CPOperand in = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
                CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               String opcode = parseUnaryInstruction(str, in, out);
+
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode;
+               opcode = parts[0];
+               if( opcode.equalsIgnoreCase("exp") && parts.length == 5) {
+                       in.split(parts[1]);
+                       out.split(parts[2]);
+                       ValueFunction func = Builtin.getBuiltinFnObject(opcode);
+                       return new UnaryMatrixFEDInstruction(new 
UnaryOperator(func,
+                               
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4])), in, out, opcode, 
str);
+               }
+               opcode = parseUnaryInstruction(str, in, out);
                return new 
UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, 
opcode, str);
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index 7579046..7cf201a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.functionobjects.IntegerDivide;
 import org.apache.sysds.runtime.functionobjects.LessThan;
 import org.apache.sysds.runtime.functionobjects.LessThanEquals;
 import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Minus1Multiply;
 import org.apache.sysds.runtime.functionobjects.MinusMultiply;
 import org.apache.sysds.runtime.functionobjects.MinusNz;
 import org.apache.sysds.runtime.functionobjects.Modulus;
@@ -68,7 +69,7 @@ public class BinaryOperator  extends Operator implements Serializable
                        || p instanceof BitwShiftL || p instanceof BitwShiftR);
                fn = p;
                commutative = p instanceof Plus || p instanceof Multiply 
-                       || p instanceof And || p instanceof Or || p instanceof 
Xor;
+                       || p instanceof And || p instanceof Or || p instanceof 
Xor || p instanceof Minus1Multiply;
        }
        
        public void setNumThreads(int k) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
new file mode 100644
index 0000000..c8733ff
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedIfelseTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "FederatedIfelseTest";
+       private final static String TEST_NAME2 = "FederatedIfelseAlignedTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedIfelseTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(
+                       new Object[][] {
+                               {64, 16, true},
+                               {64, 16, false},
+                       });
+       }
+
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+       }
+
+       @Test
+       public void testIfelseDiffWorkersCP() {
+               runTernaryTest(ExecMode.SINGLE_NODE, false);
+       }
+
+       @Test
+       public void testIfelseAlignedCP() {
+               runTernaryTest(ExecMode.SINGLE_NODE, true);
+       }
+
+       private void runTernaryTest(ExecMode execMode, boolean alligned) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               String TEST_NAME = alligned ? TEST_NAME2 : TEST_NAME1;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+               X1[0][0] = 0;
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               double[][] Y1 = getRandomMatrix(r, c, 10, 15, 1, 3);
+               double[][] Y2 = getRandomMatrix(r, c, 10, 15,1, 7);
+               double[][] Y3 = getRandomMatrix(r, c, 10, 15,1, 8);
+               double[][] Y4 = getRandomMatrix(r, c, 10, 15, 1, 9);
+               MatrixCharacteristics mc2 = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("Y1", Y1, false, mc2);
+               writeInputMatrixWithMTD("Y2", Y2, false, mc2);
+               writeInputMatrixWithMTD("Y3", Y3, false, mc2);
+               writeInputMatrixWithMTD("Y4", Y4, false, mc2);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1);
+               Thread t2 = startLocalFedWorkerThread(port2);
+               Thread t3 = startLocalFedWorkerThread(port3);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               if(alligned)
+                       runAlignedTernary(HOME, TEST_NAME, r, c, port1, port2, 
port3, port4);
+               else
+                       runTernary(HOME, TEST_NAME, port1, port2, port3, port4);
+
+               // compare via files
+               compareResults(1e-9);
+               Assert.assertTrue(heavyHittersContainsString("fed_ifelse"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               if(alligned) {
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y1")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y2")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y3")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y4")));
+               }
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+
+       private void runTernary(String HOME, String TEST_NAME, int port1, int 
port2, int port3, int port4) {
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                       expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase()};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                       "out_S=" + output("S")};
+               runTest(true, false, null, -1);
+       }
+
+       private void runAlignedTernary(String HOME, String TEST_NAME, int r, 
int c, int port1, int port2, int port3, int port4) {
+               double[][] Y1 = getRandomMatrix(r, c, 10, 15, 1, 3);
+               double[][] Y2 = getRandomMatrix(r, c, 10, 15,1, 7);
+               double[][] Y3 = getRandomMatrix(r, c, 10, 15,1, 8);
+               double[][] Y4 = getRandomMatrix(r, c, 10, 15, 1, 9);
+               MatrixCharacteristics mc2 = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("Y1", Y1, false, mc2);
+               writeInputMatrixWithMTD("Y2", Y2, false, mc2);
+               writeInputMatrixWithMTD("Y3", Y3, false, mc2);
+               writeInputMatrixWithMTD("Y4", Y4, false, mc2);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args",
+                       input("X1"), input("X2"), input("X3"), input("X4"),
+                       input("Y1"), input("Y2"), input("Y3"), input("Y4"),
+                       expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase()};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                       "in_Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
+                       "in_Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")),
+                       "in_Y3=" + TestUtils.federatedAddress(port3, 
input("Y3")),
+                       "in_Y4=" + TestUtils.federatedAddress(port4, 
input("Y4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                       "out_S=" + output("S")};
+               runTest(true, false, null, -1);
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml 
b/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml
new file mode 100644
index 0000000..fc73a80
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedIfelseAlignedTest.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+
+ B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+
+} else {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+             ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+  B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+             ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+}
+c1 = ifelse(A-3>0, B, A*2);
+c2 = ifelse(A-3>0, B, matrix(1, $rows, $cols));
+s = c1 + c2;
+write(s, $out_S);
diff --git a/scripts/builtin/lmPredict.dml 
b/src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to 
src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
index 1bdabf3..b899a9b 100644
--- a/scripts/builtin/lmPredict.dml
+++ 
b/src/test/scripts/functions/federated/FederatedIfelseAlignedTestReference.dml
@@ -19,24 +19,17 @@
 #
 #-------------------------------------------------------------
 
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B, 
-  Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose = 
FALSE) 
-  return (Matrix[Double] yhat)
-{
-  intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
-  yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
 
-  if( verbose ) {
-    y_residual = ytest - yhat;
-    avg_res = sum(y_residual) / nrow(ytest);
-    ss_res = sum(y_residual^2);
-    ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
-    R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * 
(sum(ytest)/nrow(ytest))^2);
-    print("\nAccuracy:" +
-          "\n--sum(ytest) = " + sum(ytest) +
-          "\n--sum(yhat) = " + sum(yhat) +
-          "\n--AVG_RES_Y: " + avg_res +
-          "\n--SS_AVG_RES_Y: " + ss_avg_res +
-          "\n--R2: " + R2 );
-  }
+if($10) {
+  A = rbind(read($1), read($2), read($3), read($4));
+  B = rbind(read($5), read($6), read($7), read($8));
 }
+else {
+  A = cbind(read($1), read($2), read($3), read($4));
+  B = cbind(read($5), read($6), read($7), read($8));
+}
+
+c1 = ifelse(A-3>0, B, A*2);
+c2 = ifelse(A-3>0, B, matrix(1, nrow(A), ncol(A)));
+s = c1 + c2;
+write(s, $9);
diff --git a/scripts/builtin/lmPredict.dml 
b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/federated/FederatedIfelseTest.dml
index 1bdabf3..da650a3 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
@@ -19,24 +19,18 @@
 #
 #-------------------------------------------------------------
 
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B, 
-  Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose = 
FALSE) 
-  return (Matrix[Double] yhat)
-{
-  intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
-  yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
-
-  if( verbose ) {
-    y_residual = ytest - yhat;
-    avg_res = sum(y_residual) / nrow(ytest);
-    ss_res = sum(y_residual^2);
-    ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
-    R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * 
(sum(ytest)/nrow(ytest))^2);
-    print("\nAccuracy:" +
-          "\n--sum(ytest) = " + sum(ytest) +
-          "\n--sum(yhat) = " + sum(yhat) +
-          "\n--AVG_RES_Y: " + avg_res +
-          "\n--SS_AVG_RES_Y: " + ss_avg_res +
-          "\n--R2: " + R2 );
-  }
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
 }
+c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2);
+c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3);
+c3 = ifelse(1, matrix(1, $rows, $cols), 3);
+s = c2 + c3;
+s = s + 10*c1;
+write(s, $out_S);
diff --git a/scripts/builtin/lmPredict.dml 
b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
similarity index 53%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
index 1bdabf3..e232cdf 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
@@ -19,24 +19,13 @@
 #
 #-------------------------------------------------------------
 
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B, 
-  Matrix[Double] ytest = matrix(0,1,1), Integer icpt = 0, Boolean verbose = 
FALSE) 
-  return (Matrix[Double] yhat)
-{
-  intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
-  yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
 
-  if( verbose ) {
-    y_residual = ytest - yhat;
-    avg_res = sum(y_residual) / nrow(ytest);
-    ss_res = sum(y_residual^2);
-    ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
-    R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * 
(sum(ytest)/nrow(ytest))^2);
-    print("\nAccuracy:" +
-          "\n--sum(ytest) = " + sum(ytest) +
-          "\n--sum(yhat) = " + sum(yhat) +
-          "\n--AVG_RES_Y: " + avg_res +
-          "\n--SS_AVG_RES_Y: " + ss_avg_res +
-          "\n--R2: " + R2 );
-  }
-}
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
+c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
+c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
+s = c2 + c3;
+s = s + 10*c1;
+write(s, $5);

Reply via email to