This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new bd6c642  [SYSTEMDS-2982] Federated codegen w/ aligned federated inputs
bd6c642 is described below

commit bd6c64265ff25e923141166a8f9cd4ecea16caf6
Author: ywcb00 <[email protected]>
AuthorDate: Sun May 30 01:29:00 2021 +0200

    [SYSTEMDS-2982] Federated codegen w/ aligned federated inputs
    
    Closes #1287.
---
 .../controlprogram/caching/CacheableData.java      |   2 +-
 .../controlprogram/federated/FederatedRange.java   |   4 +
 .../controlprogram/federated/FederationMap.java    |  37 ++-
 .../instructions/cp/SpoofCPInstruction.java        |  41 ++-
 .../instructions/fed/SpoofFEDInstruction.java      |  75 +++--
 .../instructions/spark/SpoofSPInstruction.java     |  43 ++-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  28 ++
 .../codegen/FederatedCodegenMultipleFedMOTest.java | 269 +++++++++++++++++
 .../codegen/FederatedCellwiseTmplTest.dml          |   4 +-
 .../codegen/FederatedCellwiseTmplTestReference.dml |   4 +-
 .../codegen/FederatedCodegenMultipleFedMOTest.dml  | 333 +++++++++++++++++++++
 .../FederatedCodegenMultipleFedMOTestReference.dml | 329 ++++++++++++++++++++
 .../FederatedOuterProductTmplTestReference.dml     |   2 +-
 13 files changed, 1108 insertions(+), 63 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f7970ff..c4d04d7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -373,7 +373,7 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
        }
        
        public boolean isFederated(FType type) {
-               return isFederated() && _fedMapping.getType().isType(type);
+               return isFederated() && (type == null || 
_fedMapping.getType().isType(type));
        }
        
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 4948d27..73636d8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -94,6 +94,10 @@ public class FederatedRange implements 
Comparable<FederatedRange> {
                                return -1;
                        if ( _beginDims[i] > o._beginDims[i])
                                return 1;
+                       if ( _endDims[i] < o._endDims[i])
+                               return -1;
+                       if ( _endDims[i] > o._endDims[i])
+                               return 1;
                }
                return 0;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 7a52b11..c77ff79 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -247,7 +247,42 @@ public class FederationMap {
                }
                return ret;
        }
-       
+
+       /**
+        * determines if the two federated data are aligned row/column 
partitions (depending on parameters equalRows/equalCols)
+        * at the same federated site (which often allows for purely federated 
operations)
+        * @param that FederationMap to check alignment with
+        * @param transposed true if that FederationMap should be transposed 
before checking alignment
+        * @param equalRows true to indicate that the row dimension should be 
checked for alignment
+        * @param equalCols true to indicate that the col dimension should be 
checked for alignment
+        * @return true if this and that FederationMap are aligned
+        */
+       public boolean isAligned(FederationMap that, boolean transposed, 
boolean equalRows, boolean equalCols) {
+               boolean ret = true;
+               final int ROW_IX = transposed ? 1 : 0; // swapping row and col 
dimension index of "that" if transposed
+               final int COL_IX = transposed ? 0 : 1;
+
+               for(Pair<FederatedRange, FederatedData> e : _fedMap) {
+                       boolean rangeFound = false; // to indicate if at least 
one matching range has been found
+                       for(FederatedRange r : that.getFederatedRanges()) {
+                               long[] rbd = r.getBeginDims();
+                               long[] red = r.getEndDims();
+                               long[] ebd = e.getKey().getBeginDims();
+                               long[] eed = e.getKey().getEndDims();
+                               // searching for the matching federated range 
of "that"
+                               if((!equalRows || (rbd[ROW_IX] == ebd[0] && 
red[ROW_IX] == eed[0]))
+                                       && (!equalCols || (rbd[COL_IX] == 
ebd[1] && red[COL_IX] == eed[1]))) {
+                                       rangeFound = true;
+                                       FederatedData dat2 = 
that.getFederatedData(r);
+                                       ret &= e.getValue().equalAddress(dat2); 
// both partitions must be located on the same fed worker
+                               }
+                       }
+                       if(!(ret &= rangeFound)) // setting ret to false if no 
matching range has been found
+                               break; // directly returning if not ret to skip 
further checks
+               }
+               return ret;
+       }
+
        public Future<FederatedResponse>[] execute(long tid, 
FederatedRequest... fr) {
                return execute(tid, false, fr);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
index e9bacd3..38fd8d7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
@@ -27,8 +27,11 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.codegen.CodegenUtils;
 import org.apache.sysds.runtime.codegen.SpoofOperator;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.lineage.LineageCodegenItem;
@@ -131,16 +134,38 @@ public class SpoofCPInstruction extends 
ComputationCPInstruction {
        }
 
        public boolean isFederated(ExecutionContext ec) {
-               for(CPOperand input : _in)
-                       if( ec.isFederated(input) )
-                               return true;
-               return false;
+               return isFederated(ec, null);
        }
        
        public boolean isFederated(ExecutionContext ec, FType type) {
-               for(CPOperand input : _in)
-                       if( ec.isFederated(input, type) )
-                               return true;
-               return false;
+               FederationMap fedMap = null;
+               boolean retVal = false;
+
+               // flags for alignment check
+               boolean equalRows = false;
+               boolean equalCols = false;
+               boolean transposed = false; // flag indicates to check for 
transposed alignment
+
+               for(CPOperand input : _in) {
+                       Data data = ec.getVariable(input);
+                       if(data instanceof MatrixObject && ((MatrixObject) 
data).isFederated(type)) {
+                               MatrixObject mo = ((MatrixObject) data);
+                               if(fedMap == null) { // first federated matrix
+                                       fedMap = mo.getFedMapping();
+                                       retVal = true;
+
+                                       // setting the constraints for 
alignment check on further federated matrices
+                                       equalRows = mo.isFederated(FType.ROW);
+                                       equalCols = mo.isFederated(FType.COL);
+                                       transposed = 
(getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
+                               }
+                               else if(!fedMap.isAligned(mo.getFedMapping(), 
false, equalRows, equalCols)
+                                       && (!transposed || 
!(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
+                                               || 
mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
+                                       retVal = false; // multiple federated 
matrices must be aligned
+                               }
+                       }
+               }
+               return retVal;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 13b1785..918ec00 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -94,60 +94,51 @@ public class SpoofFEDInstruction extends FEDInstruction
                        throw new DMLRuntimeException("Federated code 
generation only supported" +
                                " for cellwise, rowwise, multiaggregate, and 
outerproduct templates.");
 
-               ArrayList<CPOperand> inCpoMat = new ArrayList<>();
-               ArrayList<CPOperand> inCpoScal = new ArrayList<>();
-               ArrayList<MatrixObject> inMo = new ArrayList<>();
-               ArrayList<ScalarObject> inSo = new ArrayList<>();
+
                FederationMap fedMap = null;
-               for(CPOperand cpo : _inputs) {
+               for(CPOperand cpo : _inputs) { // searching for the first 
federated matrix to obtain the federation map
                        Data tmpData = ec.getVariable(cpo);
-                       if(tmpData instanceof MatrixObject) {
-                               MatrixObject tmp = (MatrixObject) tmpData;
-                               if(fedMap == null & tmp.isFederated()) { //take 
first
-                                       inCpoMat.add(0, cpo); // insert 
federated CPO at the beginning
-                                       fedMap = tmp.getFedMapping();
-                               }
-                               else {
-                                       inCpoMat.add(cpo);
-                                       inMo.add(tmp);
-                               }
-                       }
-                       else if(tmpData instanceof ScalarObject) {
-                               ScalarObject tmp = (ScalarObject) tmpData;
-                               inCpoScal.add(cpo);
-                               inSo.add(tmp);
+                       if(tmpData instanceof MatrixObject && 
((MatrixObject)tmpData).isFederated()) {
+                               fedMap = 
((MatrixObject)tmpData).getFedMapping();
+                               break;
                        }
                }
 
                ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
                ArrayList<FederatedRequest[]> frBroadcastSliced = new 
ArrayList<>();
-               long[] frIds = new long[1 + inMo.size() + inSo.size()];
+               long[] frIds = new long[_inputs.length];
                int index = 0;
-               frIds[index++] = fedMap.getID(); // insert federation map id at 
the beginning
-               for(MatrixObject mo : inMo) {
-                       if(spoofType.needsBroadcastSliced(fedMap, 
mo.getNumRows(), mo.getNumColumns(), index)) {
-                               FederatedRequest[] tmpFr = 
spoofType.broadcastSliced(mo, fedMap);
-                               frIds[index++] = tmpFr[0].getID();
-                               frBroadcastSliced.add(tmpFr);
+               
+               for(CPOperand cpo : _inputs) {
+                       Data tmpData = ec.getVariable(cpo);
+                       if(tmpData instanceof MatrixObject) {
+                               MatrixObject mo = (MatrixObject) tmpData;
+                               if(mo.isFederated()) {
+                                       frIds[index++] = 
mo.getFedMapping().getID();
+                               }
+                               else if(spoofType.needsBroadcastSliced(fedMap, 
mo.getNumRows(), mo.getNumColumns(), index)) {
+                                       FederatedRequest[] tmpFr = 
spoofType.broadcastSliced(mo, fedMap);
+                                       frIds[index++] = tmpFr[0].getID();
+                                       frBroadcastSliced.add(tmpFr);
+                               }
+                               else {
+                                       FederatedRequest tmpFr = 
fedMap.broadcast(mo);
+                                       frIds[index++] = tmpFr.getID();
+                                       frBroadcast.add(tmpFr);
+                               }
                        }
-                       else {
-                               FederatedRequest tmpFr = fedMap.broadcast(mo);
+                       else if(tmpData instanceof ScalarObject) {
+                               ScalarObject so = (ScalarObject) tmpData;
+                               FederatedRequest tmpFr = fedMap.broadcast(so);
                                frIds[index++] = tmpFr.getID();
                                frBroadcast.add(tmpFr);
                        }
                }
-               for(ScalarObject so : inSo) {
-                       FederatedRequest tmpFr = fedMap.broadcast(so);
-                       frIds[index++] = tmpFr.getID();
-                       frBroadcast.add(tmpFr);
-               }
 
                // change the is_literal flag from true to false because when 
broadcasted it is not a literal anymore
                instString = instString.replace("true", "false");
 
-               CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new 
CPOperand[0]),
-                       inCpoScal.toArray(new CPOperand[0]));
-               FederatedRequest frCompute = 
FederationUtils.callInstruction(instString, _output, inCpo, frIds);
+               FederatedRequest frCompute = 
FederationUtils.callInstruction(instString, _output, _inputs, frIds);
 
                // get partial results from federated workers
                FederatedRequest frGet = new 
FederatedRequest(RequestType.GET_VAR, frCompute.getID());
@@ -184,14 +175,18 @@ public class SpoofFEDInstruction extends FEDInstruction
 
                protected boolean needsBroadcastSliced(FederationMap fedMap, 
long rowNum, long colNum, int inputIndex) {
                        FType fedType = fedMap.getType();
+
                        boolean retVal = (rowNum == 
fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
                        if(fedType == FType.ROW)
-                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0) && (colNum == 1 || colNum == fedMap.getSize()));
+                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0) 
+                                       && (colNum == 1 || colNum == 
fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1));
                        else if(fedType == FType.COL)
-                               retVal |= ((rowNum == 1 || rowNum == 
fedMap.getSize()) && colNum == fedMap.getMaxIndexInRange(1));
-                       else
+                               retVal |= (colNum == 
fedMap.getMaxIndexInRange(1)
+                                       && (rowNum == 1 || rowNum == 
fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1));
+                       else {
                                throw new DMLRuntimeException("Only row 
partitioned or column" +
                                        " partitioned federated input supported 
yet.");
+                       }
                        return retVal;
                }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
index f76d74f..cc6d7bc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
@@ -39,14 +39,17 @@ import 
org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
 import org.apache.sysds.runtime.codegen.SpoofRowwise;
 import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -678,16 +681,40 @@ public class SpoofSPInstruction extends SPInstruction {
        }
 
        public boolean isFederated(ExecutionContext ec) {
-               for(CPOperand input : _in)
-                       if( ec.isFederated(input) )
-                               return true;
-               return false;
+               return isFederated(ec, null);
        }
        
        public boolean isFederated(ExecutionContext ec, FType type) {
-               for(CPOperand input : _in)
-                       if( ec.isFederated(input, type) )
-                               return true;
-               return false;
+               //FIXME remove redundancy with SpoofCPInstruction
+               
+               FederationMap fedMap = null;
+               boolean retVal = false;
+
+               // flags for alignment check
+               boolean equalRows = false;
+               boolean equalCols = false;
+               boolean transposed = false; // flag indicates to check for 
transposed alignment
+
+               for(CPOperand input : _in) {
+                       Data data = ec.getVariable(input);
+                       if(data instanceof MatrixObject && ((MatrixObject) 
data).isFederated(type)) {
+                               MatrixObject mo = ((MatrixObject) data);
+                               if(fedMap == null) { // first federated matrix
+                                       fedMap = mo.getFedMapping();
+                                       retVal = true;
+
+                                       // setting the constraints for 
alignment check on further federated matrices
+                                       equalRows = mo.isFederated(FType.ROW);
+                                       equalCols = mo.isFederated(FType.COL);
+                                       transposed = 
(getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
+                               }
+                               else if(!fedMap.isAligned(mo.getFedMapping(), 
false, equalRows, equalCols)
+                                       && (!transposed || 
!(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
+                                               || 
mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
+                                       retVal = false; // multiple federated 
matrices must be aligned
+                               }
+                       }
+               }
+               return retVal;
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c90892b..40cf34d 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -2080,6 +2080,20 @@ public abstract class AutomatedTestBase {
                return(count >= minCount);
        }
 
+       protected boolean heavyHittersContainsString(String str, int minCount, 
long minCallCount) {
+               int count = 0;
+               long callCount = Long.MAX_VALUE;
+               for(String opcode : Statistics.getCPHeavyHitterOpCodes()) {
+                       if(opcode.equals(str)) {
+                               count++;
+                               long tmpCallCount = 
Statistics.getCPHeavyHitterCount(opcode);
+                               if(tmpCallCount < callCount)
+                                       callCount = tmpCallCount;
+                       }
+               }
+               return (count >= minCount && callCount >= minCallCount);
+       }
+
        protected boolean heavyHittersContainsSubString(String... str) {
                for(String opcode : Statistics.getCPHeavyHitterOpCodes())
                        for(String s : str)
@@ -2095,6 +2109,20 @@ public abstract class AutomatedTestBase {
                return(count >= minCount);
        }
 
+       protected boolean heavyHittersContainsSubString(String str, int 
minCount, long minCallCount) {
+               int count = 0;
+               long callCount = Long.MAX_VALUE;
+               for(String opcode : Statistics.getCPHeavyHitterOpCodes()) {
+                       if(opcode.contains(str)) {
+                               count++;
+                               long tmpCallCount = 
Statistics.getCPHeavyHitterCount(opcode);
+                               if(tmpCallCount < callCount)
+                                       callCount = tmpCallCount;
+                       }
+               }
+               return (count >= minCount && callCount >= minCallCount);
+       }
+
        protected boolean checkedPrivacyConstraintsContains(PrivacyLevel... 
levels) {
                for(PrivacyLevel level : levels)
                        
if(!(CheckedConstraintsLog.getCheckedConstraints().containsKey(level)))
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
new file mode 100644
index 0000000..65f1728
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.codegen;
+
+import java.io.File;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = 
"FederatedCodegenMultipleFedMOTest";
+
+       private final static String TEST_DIR = "functions/federated/codegen/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedCodegenMultipleFedMOTest.class.getSimpleName() + "/";
+
+       private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+
+       private final static String OUTPUT_NAME = "Z";
+       private final static double TOLERANCE = 1e-7;
+       private final static int BLOCKSIZE = 1024;
+
+       @Parameterized.Parameter()
+       public int test_num;
+       @Parameterized.Parameter(1)
+       public int rows_x;
+       @Parameterized.Parameter(2)
+       public int cols_x;
+       @Parameterized.Parameter(3)
+       public int rows_y;
+       @Parameterized.Parameter(4)
+       public int cols_y;
+       @Parameterized.Parameter(5)
+       public boolean row_partitioned;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even for row partitioned X and Y
+               // cols must be even for col partitioned X and Y
+               return Arrays.asList(new Object[][] {
+                       // {test_num, rows_x, cols_x, rows_y, cols_y,
row_partitioned}
+
+                       // cellwise
+                       // row partitioned
+                       {1, 4, 4, 4, 4, true},
+                       // {2, 4, 4, 4, 1, true},
+                       {3, 4, 1, 4, 1, true},
+                       {4, 1000, 1, 1000, 1, true},
+                       // {5, 500, 2, 500, 2, true},
+                       {6, 2, 500, 2, 500, true},
+                       {7, 2, 4, 2, 4, true},
+                       // column partitioned
+                       // {1, 4, 4, 4, 4, false},
+                       {2, 4, 4, 1, 4, false},
+                       {5, 500, 2, 500, 2, false},
+                       // {6, 2, 500, 2, 500, false},
+                       {7, 2, 4, 2, 4, false},
+
+                       // rowwise
+                       // {101, 6, 2, 6, 2, true},
+                       {102, 6, 1, 6, 4, true},
+                       // {103, 6, 4, 6, 2, true},
+                       {104, 150, 10, 150, 10, true},
+
+                       // multi aggregate
+                       // row partitioned
+                       // {201, 6, 4, 6, 4, true},
+                       {202, 6, 4, 6, 4, true},
+                       // {203, 20, 1, 20, 1, true},
+                       // col partitioned
+                       {201, 6, 4, 6, 4, false},
+                       {202, 6, 4, 6, 4, false},
+
+                       // outer product
+                       // row partitioned
+                       // {301, 1500, 1500, 1500, 10, true},
+                       {303, 4000, 2000, 4000, 10, true},
+                       // {305, 4000, 2000, 4000, 10, true},
+                       // {307, 1000, 2000, 1000, 10, true},
+                       // {309, 1000, 2000, 1000, 10, true},
+                       // col partitioned
+                       // {302, 2000, 2000, 10, 2000, false},
+                       // {304, 4000, 2000, 10, 2000, false},
+                       // {306, 4000, 2000, 10, 2000, false},
+                       {308, 1000, 2000, 10, 2000, false},
+                       // {310, 1000, 2000, 10, 2000, false},
+                       // row and col partitioned
+                       // {311, 1000, 2000, 1000, 10, true}, // not working 
yet - ArrayIndexOutOfBoundsException in dotProduct
+                       {312, 1000, 2000, 10, 2000, false},
+                       // {313, 4000, 2000, 4000, 10, true}, // not working 
yet - ArrayIndexOutOfBoundsException in dotProduct
+                       {314, 4000, 2000, 10, 2000, false},
+
+                       // combined tests
+                       {401, 20, 10, 20, 6, true}, // cellwise, rowwise, 
multiaggregate
+                       {402, 2000, 2000, 2000, 10, true}, // outerproduct
+
+               });
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       @Ignore
+       public void federatedCodegenMultipleFedMOSingleNode() {
+               testFederatedCodegenMultipleFedMO(ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       @Ignore
+       public void federatedCodegenMultipleFedMOSpark() {
+               testFederatedCodegenMultipleFedMO(ExecMode.SPARK);
+       }
+       
+       @Test
+       public void federatedCodegenMultipleFedMOHybrid() {
+               testFederatedCodegenMultipleFedMO(ExecMode.HYBRID);
+       }
+
+       /**
+        * Shared test body for all execution modes. It writes the partition
+        * matrices X1, X2, Y1 and Y2, starts two local federated workers, runs the
+        * reference script and the federated script, and compares their outputs.
+        * Afterwards it checks the heavy hitters for the expected fed_spoof*
+        * instruction and that the federated input files still exist.
+        *
+        * The worker threads are shut down and the previous execution mode is
+        * restored in finally blocks, so a failing run or comparison does not
+        * leave them behind.
+        *
+        * @param exec_mode execution mode used for both script runs
+        */
+       private void testFederatedCodegenMultipleFedMO(ExecMode exec_mode) {
+               // store the previous platform config to restore it after the test
+               ExecMode platform_old = setExecMode(exec_mode);
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       // every worker gets one half of X and Y, split along the partitioned dimension
+                       int fed_rows_x = rows_x;
+                       int fed_cols_x = cols_x;
+                       int fed_rows_y = rows_y;
+                       int fed_cols_y = cols_y;
+                       if(row_partitioned) {
+                               fed_rows_x /= 2;
+                               fed_rows_y /= 2;
+                       }
+                       else {
+                               fed_cols_x /= 2;
+                               fed_cols_y /= 2;
+                       }
+
+                       // generate dataset
+                       // matrix handled by two federated workers
+                       double[][] X1 = getRandomMatrix(fed_rows_x, fed_cols_x, 0, 1, 0.1, 3);
+                       double[][] X2 = getRandomMatrix(fed_rows_x, fed_cols_x, 0, 1, 0.1, 23);
+                       // matrix handled by two federated workers
+                       double[][] Y1 = getRandomMatrix(fed_rows_y, fed_cols_y, 0, 1, 0.1, 64);
+                       double[][] Y2 = getRandomMatrix(fed_rows_y, fed_cols_y, 0, 1, 0.1, 135);
+
+                       writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows_x, fed_cols_x, BLOCKSIZE, fed_rows_x * fed_cols_x));
+                       writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows_x, fed_cols_x, BLOCKSIZE, fed_rows_x * fed_cols_x));
+                       writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(fed_rows_y, fed_cols_y, BLOCKSIZE, fed_rows_y * fed_cols_y));
+                       writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(fed_rows_y, fed_cols_y, BLOCKSIZE, fed_rows_y * fed_cols_y));
+
+                       // empty script name because we don't execute any script, just start the worker
+                       fullDMLScriptName = "";
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       Thread thread2 = startLocalFedWorkerThread(port2);
+
+                       try {
+                               getAndLoadTestConfiguration(TEST_NAME);
+
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                               programArgs = new String[] {"-stats", "-nvargs",
+                                       "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+                                       "in_Y1=" + input("Y1"), "in_Y2=" + input("Y2"),
+                                       "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+                                       "in_test_num=" + Integer.toString(test_num),
+                                       "out_Z=" + expected(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // Run actual dml script with federated matrix
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[] {"-stats", "-nvargs",
+                                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                                       "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+                                       "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+                                       "in_test_num=" + Integer.toString(test_num),
+                                       "rows_x=" + rows_x, "cols_x=" + cols_x,
+                                       "rows_y=" + rows_y, "cols_y=" + cols_y,
+                                       "out_Z=" + output(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // compare the results via files
+                               HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+                               HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+                       }
+                       finally {
+                               // stop the workers even if a script run or the comparison failed
+                               TestUtils.shutdownThreads(thread1, thread2);
+                       }
+
+                       // check for federated operations
+                       if(test_num >= 0 && test_num < 100)
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+                       else if(test_num < 200)
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
+                       else if(test_num < 300)
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA"));
+                       else if(test_num < 400)
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
+                       else if(test_num == 401) {
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA", exec_mode == ExecMode.SPARK ? 0 : 1));
+                       }
+                       else if(test_num == 402)
+                               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP", 3, exec_mode == ExecMode.SPARK? 1 :2));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y2")));
+               }
+               finally {
+                       // restore the platform config even if the test failed
+                       resetExecMode(platform_old);
+               }
+       }
+
+       /**
+        * Replaces the default configuration template with this test's own
+        * template (TEST_CONF located under SCRIPT_DIR + TEST_DIR), so that the
+        * scratch space and local temporary directory locations are taken from it.
+        *
+        * @return the custom configuration template file
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               return new File(SCRIPT_DIR + TEST_DIR, TEST_CONF);
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
index 3f91385..45f790b 100644
--- a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
+++ b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTest.dml
@@ -87,7 +87,7 @@ else if(test_num == 9) {
   Y = matrix(seq(6, 1005), 500, 2);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(log(U)))
+  Z = as.matrix(sum(log(U)));
 }
 else if(test_num == 10) {
   # X ... 500x2 matrix
@@ -106,7 +106,7 @@ else if(test_num == 12) {
   Y = matrix(seq(6, 1005), 2, 500);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(sqrt(U)))
+  Z = as.matrix(sum(sqrt(U)));
 }
 else if(test_num == 13) {
   # X ... 2x4 matrix
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
index 2c13e6a..e2e3b4b 100644
--- 
a/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
+++ 
b/src/test/scripts/functions/federated/codegen/FederatedCellwiseTmplTestReference.dml
@@ -85,7 +85,7 @@ else if(test_num == 9) {
   Y = matrix(seq(6, 1005), 500, 2);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(log(U)))
+  Z = as.matrix(sum(log(U)));
 }
 else if(test_num == 10) {
   while(FALSE){} #TODO
@@ -104,7 +104,7 @@ else if(test_num == 12) {
   Y = matrix(seq(6, 1005), 2, 500);
 
   U = X + 7 * Y;
-  Z = as.matrix(sum(sqrt(U)))
+  Z = as.matrix(sum(sqrt(U)));
 }
 else if(test_num == 13) {
   # X ... 2x4 matrix
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml
 
b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml
new file mode 100644
index 0000000..5e2b796
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.dml
@@ -0,0 +1,333 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated variant of the test script. Named arguments used below:
+#   $in_test_num                      selects the test case in the if / else-if chain
+#   $in_rp                            TRUE: inputs are split by rows, FALSE: split by columns
+#   $in_X1, $in_X2                    addresses of the two partitions of X
+#   $in_Y1, $in_Y2                    addresses of the two partitions of Y
+#   $rows_x, $cols_x, $rows_y, $cols_y  full dimensions of X and Y for the federated ranges
+#   $out_Z                            location the result Z is written to
+test_num = $in_test_num;
+row_part = $in_rp;
+
+# Two range lists are given per address.
+# NOTE(review): presumably the lower (inclusive) and upper (exclusive) corner of each partition - confirm.
+if(row_part) {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows_x / 2, $cols_x), list($rows_x / 2, 0), 
list($rows_x, $cols_x)));
+  Y = federated(addresses=list($in_Y1, $in_Y2),
+    ranges=list(list(0, 0), list($rows_y / 2, $cols_y), list($rows_y / 2, 0), 
list($rows_y, $cols_y)));
+}
+else {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows_x, $cols_x / 2), list(0, $cols_x / 2), 
list($rows_x, $cols_x)));
+  Y = federated(addresses=list($in_Y1, $in_Y2),
+    ranges=list(list(0, 0), list($rows_y, $cols_y / 2), list(0, $cols_y / 2), 
list($rows_y, $cols_y)));
+}
+
+if(test_num == 1) { # cellwise #4
+  # X ... 4x4 matrix
+  # Y ... 4x4 matrix
+  w = matrix(3, rows=4, cols=4);
+  Z = test1(X, Y, w);
+}
+else if(test_num == 2) { # cellwise #5
+  # X ... 4x4 matrix
+  # Y ... 4x1 / 1x4 vector
+  U = matrix( "1 2 3 4", rows=4, cols=1);
+  Z = test2(X, Y, U);
+}
+else if(test_num == 3) { # cellwise #6
+  # X ... 4x1 vector
+  # Y ... 4x1 vector
+  v = matrix("3 3 3 3", rows=4, cols=1);
+  Z = test3(X, Y, v);
+}
+else if(test_num == 4) { # cellwise #7
+  # X ... 1000x1 vector
+  # Y ... 1000x1 vector
+  Z = test4(X, Y);
+}
+else if(test_num == 5) { # cellwise #9
+  # X ... 500x2 matrix
+  # Y ... 500x2 matrix
+  Z = test5(X, Y);
+}
+else if(test_num == 6) { # cellwise #12
+  # X ... 2x500 matrix
+  # Y ... 2x500 matrix
+  Z = test6(X, Y);
+}
+else if(test_num == 7) { # cellwise #13
+  # X ... 2x4 matrix
+  # Y ... 2x4 matrix
+  w = matrix(seq(1,8), rows=2, cols=4);
+  Z = test1(X, Y, w);
+}
+else if(test_num == 101) { # rowwise #2
+  # X ... 6x2 matrix
+  # Y ... 6x2 matrix
+  U = matrix(1, rows=2, cols=1);
+  Z = test101(X, Y, U);
+}
+else if(test_num == 102) { # rowwise #3
+  # X ... 6x1 vector
+  # Y ... 6x4 matrix
+  U = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+  V = matrix(1,rows=4,cols=1);
+  Z = test102(X, Y, U, V);
+}
+else if(test_num == 103) { # rowwise #4
+  # X ... 6x4 matrix
+  # Y ... 6x2 matrix
+  Z = test103(X, Y);
+}
+else if(test_num == 104) { # rowwise #10
+  # X ... 150x10 matrix
+  # Y ... 150x10 matrix
+  Z = test104(X, Y);
+}
+else if(test_num == 201) { # multiagg #4
+  # X ... 6x4 matrix
+  # Y ... 6x4 matrix
+  Z = test201(X, Y);
+}
+else if(test_num == 202) { # multiagg #5
+  # X ... 6x4 matrix
+  # Y ... 6x4 matrix
+  U = matrix(seq(0,23), rows=6, cols=4);
+  V = matrix(seq(2,25), rows=6, cols=4);
+  Z = test202(X, Y, U, V);
+}
+else if(test_num == 203) { # multiagg #7
+  # X ... 20x1 vector
+  # Y ... 20x1 vector
+  Z = test203(X, Y);
+}
+else if(test_num == 301) { # outerproduct #1
+  # X ... 1500x1500 matrix
+  # Y ... 1500x10 matrix
+  V = matrix(seq(1,15000), rows=1500, cols=10);
+  Z = test301(X, Y, V);
+}
+else if(test_num == 302) { # outerproduct #1
+  # X ... 2000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1,20000), rows=2000, cols=10);
+  Z = test301(X, U, t(Y));
+}
+else if(test_num == 303) { # outerproduct #2
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  Z = test303(X, Y, V);
+}
+else if(test_num == 304) { # outerproduct #2
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(51, 40050), rows=4000, cols=10);
+  Z = test303(X, U, t(Y));
+}
+else if(test_num == 305) { # outerproduct #6
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  V = matrix(seq(-1, 19998), rows=2000, cols=10);
+  Z = test305(X, Y, V);
+}
+else if(test_num == 306) { # outerproduct #6
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  Z = test305(X, U, t(Y));
+}
+else if(test_num == 307) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test307(X, Y, V);
+}
+else if(test_num == 308) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test307(X, U, t(Y));
+}
+else if(test_num == 309) { # outerproduct #9
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test309(X, Y, V);
+}
+else if(test_num == 310) { # outerproduct #9
+  # X ... 1000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test309(X, U, t(Y));
+}
+else if(test_num == 311) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  Y = t(Y); # col partitioned Y
+  # NOTE(review): the empty while(FALSE) loops presumably split the program into separate statement blocks - confirm
+  while(FALSE) { }
+  # Y ... 10x1000 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test307(X, t(Y), V);
+}
+else if(test_num == 312) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  Y = t(Y); # row partitioned Y
+  while(FALSE) { }
+  # Y ... 2000x10 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test307(X, U, Y);
+}
+else if(test_num == 313) {
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  Y = t(Y); # col partitioned Y
+  while(FALSE) { }
+  # Y ... 10x4000 matrix
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  Z = test303(X, t(Y), V);
+}
+else if(test_num == 314) {
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  Y = t(Y); # row partitioned Y
+  while(FALSE) { }
+  # Y ... 2000x10 matrix
+  U = matrix(seq(51, 40050), rows=4000, cols=10);
+  Z = test303(X, U, Y);
+}
+else if(test_num == 401) { # combined tests
+  # X ... 20x10 matrix
+  # Y ... 20x6 matrix
+  
+  A = test103(X, Y); # not federated output
+  B = test2(X, Y[, 1], t(cbind(A, A)));
+  while(FALSE) { }
+  U = X[6:13, 7:10];
+  V = B[6:13, 3:6];
+  while(FALSE) { }
+  C = test201(U, V);
+  while(FALSE) { }
+  Z = B - C;
+}
+else if(test_num == 402) { # combined outerproduct tests
+  # X ... 2000x2000 matrix
+  # Y ... 2000x10 matrix
+  
+  V = matrix(seq(1,20000), rows=2000, cols=10);
+  A = test301(X, Y, V);
+  while(FALSE) { }
+  B = test305(X, Y, V);
+  while(FALSE) { }
+  C = test309(X, Y, V);
+  while(FALSE) { }
+  X = t(X); # col partitioned X and Y
+  Y = t(Y);
+  while(FALSE) { }
+  U = matrix(seq(1, 20000), rows=2000, cols=10);
+  D = test301(X, U, t(Y));
+  while(FALSE) { }
+  E = test305(X, U, t(Y));
+  while(FALSE) { }
+  F = test309(X, U, t(Y));
+  while(FALSE) { }
+  Z = as.scalar(A) - B + C - as.scalar(D) + E - F;
+}
+
+# persist the result of the selected test case
+write(Z, $out_Z);
+
+# ************** Tests defined in functions for reusability **************
+# Cell-wise chain: Z = 10 + floor(round(abs((X + w) * Y))), where * is element-wise.
+test1 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] w) 
return(Matrix[Double] Z) {
+  Z = 10 + floor(round(abs((X + w) * Y)));
+}
+# Adds two cell-wise expressions over X: G = abs(exp(X)) and
+# V = 10 + floor(round(abs((X / Y) + U))); returns G + V.
+test2 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) 
return(Matrix[Double] Z) {
+  G = abs(exp(X));
+  V = 10 + floor(round(abs((X / Y) + U)));
+  Z = G + V;
+}
+# Full sum over the element-wise product X * Y * v, returned as a matrix via as.matrix.
+test3 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] v) 
return(Matrix[Double] Z) {
+  Z = as.matrix(sum(X * Y * v));
+}
+# Builds U = X + Y - 7 + abs(X) cell-wise and returns t(U) %*% U.
+test4 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + Y - 7 + abs(X);
+  Z = t(U) %*% U;
+}
+# Full sum of log(X + 7 * Y), returned as a matrix via as.matrix.
+test5 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + 7 * Y;
+  Z = as.matrix(sum(log(U)));
+}
+# Full sum of sqrt(X + 7 * Y), returned as a matrix via as.matrix.
+test6 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + 7 * Y;
+  Z = as.matrix(sum(sqrt(U)));
+}
+
+# Scales X %*% U by the scalar lambda = sum(Y) and left-multiplies the result with t(X).
+test101 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) 
return(Matrix[Double] Z) {
+  lambda = sum(Y);
+  Z = t(X) %*% (lambda * (X %*% U));
+}
+# Returns t(Y) %*% (U + (2 - (X * (Y %*% V)))); X is multiplied element-wise with Y %*% V.
+test102 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, 
Matrix[Double] V) return(Matrix[Double] Z) {
+  Z = t(Y) %*% (U + (2 - (X * (Y %*% V))));
+}
+# Divides X by the row sums of Y and returns the column sums of the quotient.
+test103 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  Z = colSums(X / rowSums(Y));
+}
+# Adds the comparison result (X <= rowMins(X)) to Y, divides Y by its row sums,
+# and returns the column sums of the normalized matrix.
+test104 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  Y = Y + (X <= rowMins(X));
+  U = (Y / rowSums(Y));
+  Z = colSums(U);
+}
+
+# Adds three full aggregates, sum(X * Y), sum(X ^ 2) and sum(Y ^ 2), and
+# returns the total as a matrix via as.matrix.
+test201 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  #disjoint partitions with partial shared reads
+  r1 = sum(X * Y);
+  r2 = sum(X ^ 2);
+  r3 = sum(Y ^ 2);
+  Z = as.matrix(r1 + r2 + r3);
+}
+# Adds three full aggregates, sum(X * U), sum(V * Y) and sum(X * V * Y), and
+# returns the total as a matrix via as.matrix.
+test202 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, 
Matrix[Double] V) return(Matrix[Double] Z) {
+  #disjoint partitions with transitive partial shared reads
+  r1 = sum(X * U);
+  r2 = sum(V * Y);
+  r3 = sum(X * V * Y);
+  Z = as.matrix(r1 + r2 + r3);
+}
+# Returns the sum of the three products t(X) %*% X, t(X) %*% Y and t(Y) %*% Y.
+test203 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  r1 = t(X) %*% X;
+  r2 = t(X) %*% Y;
+  r3 = t(Y) %*% Y;
+  Z = r1 + r2 + r3;
+}
+
+# Full sum of X * log(U %*% t(V) + eps) with eps = 0.1, returned as a matrix via as.matrix.
+test301 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+# Returns t(t(U) %*% (X / (U %*% t(V) + eps))) with eps = 0.1.
+test303 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+# Returns (X / ((U %*% t(V)) + eps)) %*% V with eps = 0.1.
+test305 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = (X / ((U %*% t(V)) + eps)) %*% V;
+}
+# Multiplies X element-wise with 1 / (1 + exp(-(U %*% t(V)))).
+# NOTE(review): eps is assigned but never used in this function.
+test307 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+# Returns t(t(U) %*% (X / (U %*% t(V) + eps))) with eps = 0.4
+# (same expression as test303, which uses eps = 0.1).
+test309 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.4;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml
 
b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml
new file mode 100644
index 0000000..6279e3d
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/codegen/FederatedCodegenMultipleFedMOTestReference.dml
@@ -0,0 +1,329 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference (non-federated) variant of the test script. Named arguments used below:
+#   $in_test_num      selects the test case in the if / else-if chain
+#   $in_rp            TRUE: partitions are bound by rows, FALSE: bound by columns
+#   $in_X1, $in_X2    files holding the two partitions of X
+#   $in_Y1, $in_Y2    files holding the two partitions of Y
+#   $out_Z            location the result Z is written to
+test_num = $in_test_num;
+row_part = $in_rp;
+
+# read both partitions and bind them into one regular matrix
+if(row_part) {
+  X = rbind(read($in_X1), read($in_X2));
+  Y = rbind(read($in_Y1), read($in_Y2));
+}
+else {
+  X = cbind(read($in_X1), read($in_X2));
+  Y = cbind(read($in_Y1), read($in_Y2));
+}
+
+if(test_num == 1) { # cellwise #4
+  # X ... 4x4 matrix
+  # Y ... 4x4 matrix
+  w = matrix(3, rows=4, cols=4);
+  Z = test1(X, Y, w);
+}
+else if(test_num == 2) { # cellwise #5
+  # X ... 4x4 matrix
+  # Y ... 4x1 / 1x4 vector
+  U = matrix( "1 2 3 4", rows=4, cols=1);
+  Z = test2(X, Y, U);
+}
+else if(test_num == 3) { # cellwise #6
+  # X ... 4x1 vector
+  # Y ... 4x1 vector
+  v = matrix("3 3 3 3", rows=4, cols=1);
+  Z = test3(X, Y, v);
+}
+else if(test_num == 4) { # cellwise #7
+  # X ... 1000x1 vector
+  # Y ... 1000x1 vector
+  Z = test4(X, Y);
+}
+else if(test_num == 5) { # cellwise #9
+  # X ... 500x2 matrix
+  # Y ... 500x2 matrix
+  Z = test5(X, Y);
+}
+else if(test_num == 6) { # cellwise #12
+  # X ... 2x500 matrix
+  # Y ... 2x500 matrix
+  Z = test6(X, Y);
+}
+else if(test_num == 7) { # cellwise #13
+  # X ... 2x4 matrix
+  # Y ... 2x4 matrix
+  w = matrix(seq(1,8), rows=2, cols=4);
+  Z = test1(X, Y, w);
+}
+else if(test_num == 101) { # rowwise #2
+  # X ... 6x2 matrix
+  # Y ... 6x2 matrix
+  U = matrix(1, rows=2, cols=1);
+  Z = test101(X, Y, U);
+}
+else if(test_num == 102) { # rowwise #3
+  # X ... 6x1 vector
+  # Y ... 6x4 matrix
+  U = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+  V = matrix(1,rows=4,cols=1);
+  Z = test102(X, Y, U, V);
+}
+else if(test_num == 103) { # rowwise #4
+  # X ... 6x4 matrix
+  # Y ... 6x2 matrix
+  Z = test103(X, Y);
+}
+else if(test_num == 104) { # rowwise #10
+  # X ... 150x10 matrix
+  # Y ... 150x10 matrix
+  Z = test104(X, Y);
+}
+else if(test_num == 201) { # multiagg #4
+  # X ... 6x4 matrix
+  # Y ... 6x4 matrix
+  Z = test201(X, Y);
+}
+else if(test_num == 202) { # multiagg #5
+  # X ... 6x4 matrix
+  # Y ... 6x4 matrix
+  U = matrix(seq(0,23), rows=6, cols=4);
+  V = matrix(seq(2,25), rows=6, cols=4);
+  Z = test202(X, Y, U, V);
+}
+else if(test_num == 203) { # multiagg #7
+  # X ... 20x1 vector
+  # Y ... 20x1 vector
+  Z = test203(X, Y);
+}
+else if(test_num == 301) { # outerproduct #1
+  # X ... 1500x1500 matrix
+  # Y ... 1500x10 matrix
+  V = matrix(seq(1,15000), rows=1500, cols=10);
+  Z = test301(X, Y, V);
+}
+else if(test_num == 302) { # outerproduct #1
+  # X ... 2000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1,20000), rows=2000, cols=10);
+  Z = test301(X, U, t(Y));
+}
+else if(test_num == 303) { # outerproduct #2
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  Z = test303(X, Y, V);
+}
+else if(test_num == 304) { # outerproduct #2
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(51, 40050), rows=4000, cols=10);
+  Z = test303(X, U, t(Y));
+}
+else if(test_num == 305) { # outerproduct #6
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  V = matrix(seq(-1, 19998), rows=2000, cols=10);
+  Z = test305(X, Y, V);
+}
+else if(test_num == 306) { # outerproduct #6
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  Z = test305(X, U, t(Y));
+}
+else if(test_num == 307) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test307(X, Y, V);
+}
+else if(test_num == 308) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test307(X, U, t(Y));
+}
+else if(test_num == 309) { # outerproduct #9
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test309(X, Y, V);
+}
+else if(test_num == 310) { # outerproduct #9
+  # X ... 1000x2000 matrix
+  # Y ... 10x2000 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test309(X, U, t(Y));
+}
+else if(test_num == 311) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  # Y ... 1000x10 matrix
+  Y = t(Y); # col partitioned Y
+  # NOTE(review): the empty while(FALSE) loops presumably split the program into separate statement blocks - confirm
+  while(FALSE) { }
+  # Y ... 10x1000 matrix
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  Z = test307(X, t(Y), V);
+}
+else if(test_num == 312) { # outerproduct #8
+  # X ... 1000x2000 matrix
+  Y = t(Y); # row partitioned Y
+  while(FALSE) { }
+  # Y ... 2000x10 matrix
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  Z = test307(X, U, Y);
+}
+else if(test_num == 313) {
+  # X ... 4000x2000 matrix
+  # Y ... 4000x10 matrix
+  Y = t(Y); # col partitioned Y
+  while(FALSE) { }
+  # Y ... 10x4000 matrix
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  Z = test303(X, t(Y), V);
+}
+else if(test_num == 314) {
+  # X ... 4000x2000 matrix
+  # Y ... 10x2000 matrix
+  Y = t(Y); # row partitioned Y
+  while(FALSE) { }
+  # Y ... 2000x10 matrix
+  U = matrix(seq(51, 40050), rows=4000, cols=10);
+  Z = test303(X, U, Y);
+}
+else if(test_num == 401) { # combined tests
+  # X ... 20x10 matrix
+  # Y ... 20x6 matrix
+  
+  A = test103(X, Y); # not federated output
+  B = test2(X, Y[, 1], t(cbind(A, A)));
+  while(FALSE) { }
+  U = X[6:13, 7:10];
+  V = B[6:13, 3:6];
+  while(FALSE) { }
+  C = test201(U, V);
+  while(FALSE) { }
+  Z = B - C;
+}
+else if(test_num == 402) { # combined outerproduct tests
+  # X ... 2000x2000 matrix
+  # Y ... 2000x10 matrix
+  
+  V = matrix(seq(1,20000), rows=2000, cols=10);
+  A = test301(X, Y, V);
+  while(FALSE) { }
+  B = test305(X, Y, V);
+  while(FALSE) { }
+  C = test309(X, Y, V);
+  while(FALSE) { }
+  X = t(X); # col partitioned X and Y
+  Y = t(Y);
+  while(FALSE) { }
+  U = matrix(seq(1, 20000), rows=2000, cols=10);
+  D = test301(X, U, t(Y));
+  while(FALSE) { }
+  E = test305(X, U, t(Y));
+  while(FALSE) { }
+  F = test309(X, U, t(Y));
+  while(FALSE) { }
+  Z = as.scalar(A) - B + C - as.scalar(D) + E - F;
+}
+
+# persist the result of the selected test case
+write(Z, $out_Z);
+
+# ************** Tests defined in functions for reusability **************
+# Cell-wise chain: Z = 10 + floor(round(abs((X + w) * Y))), where * is element-wise.
+test1 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] w) 
return(Matrix[Double] Z) {
+  Z = 10 + floor(round(abs((X + w) * Y)));
+}
+# Adds two cell-wise expressions over X: G = abs(exp(X)) and
+# V = 10 + floor(round(abs((X / Y) + U))); returns G + V.
+test2 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) 
return(Matrix[Double] Z) {
+  G = abs(exp(X));
+  V = 10 + floor(round(abs((X / Y) + U)));
+  Z = G + V;
+}
+# Full sum over the element-wise product X * Y * v, returned as a matrix via as.matrix.
+test3 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] v) 
return(Matrix[Double] Z) {
+  Z = as.matrix(sum(X * Y * v));
+}
+# Builds U = X + Y - 7 + abs(X) cell-wise and returns t(U) %*% U.
+test4 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + Y - 7 + abs(X);
+  Z = t(U) %*% U;
+}
+# Full sum of log(X + 7 * Y), returned as a matrix via as.matrix.
+test5 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + 7 * Y;
+  Z = as.matrix(sum(log(U)));
+}
+# Full sum of sqrt(X + 7 * Y), returned as a matrix via as.matrix.
+test6 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] Z) {
+  U = X + 7 * Y;
+  Z = as.matrix(sum(sqrt(U)));
+}
+
+# Scales X %*% U by the scalar lambda = sum(Y) and left-multiplies the result with t(X).
+test101 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U) 
return(Matrix[Double] Z) {
+  lambda = sum(Y);
+  Z = t(X) %*% (lambda * (X %*% U));
+}
+# Returns t(Y) %*% (U + (2 - (X * (Y %*% V)))); X is multiplied element-wise with Y %*% V.
+test102 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, 
Matrix[Double] V) return(Matrix[Double] Z) {
+  Z = t(Y) %*% (U + (2 - (X * (Y %*% V))));
+}
+# Divides X by the row sums of Y and returns the column sums of the quotient.
+test103 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  Z = colSums(X / rowSums(Y));
+}
+# Adds the comparison result (X <= rowMins(X)) to Y, divides Y by its row sums,
+# and returns the column sums of the normalized matrix.
+test104 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  Y = Y + (X <= rowMins(X));
+  U = (Y / rowSums(Y));
+  Z = colSums(U);
+}
+
+# Adds three full aggregates, sum(X * Y), sum(X ^ 2) and sum(Y ^ 2), and
+# returns the total as a matrix via as.matrix.
+test201 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  #disjoint partitions with partial shared reads
+  r1 = sum(X * Y);
+  r2 = sum(X ^ 2);
+  r3 = sum(Y ^ 2);
+  Z = as.matrix(r1 + r2 + r3);
+}
+# Adds three full aggregates, sum(X * U), sum(V * Y) and sum(X * V * Y), and
+# returns the total as a matrix via as.matrix.
+test202 = function(Matrix[Double] X, Matrix[Double] Y, Matrix[Double] U, 
Matrix[Double] V) return(Matrix[Double] Z) {
+  #disjoint partitions with transitive partial shared reads
+  r1 = sum(X * U);
+  r2 = sum(V * Y);
+  r3 = sum(X * V * Y);
+  Z = as.matrix(r1 + r2 + r3);
+}
+# Returns the sum of the three products t(X) %*% X, t(X) %*% Y and t(Y) %*% Y.
+test203 = function(Matrix[Double] X, Matrix[Double] Y) return(Matrix[Double] 
Z) {
+  r1 = t(X) %*% X;
+  r2 = t(X) %*% Y;
+  r3 = t(Y) %*% Y;
+  Z = r1 + r2 + r3;
+}
+
+# Full sum of X * log(U %*% t(V) + eps) with eps = 0.1, returned as a matrix via as.matrix.
+test301 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+# Returns t(t(U) %*% (X / (U %*% t(V) + eps))) with eps = 0.1.
+test303 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+# Returns (X / ((U %*% t(V)) + eps)) %*% V with eps = 0.1.
+test305 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = (X / ((U %*% t(V)) + eps)) %*% V;
+}
+# Multiplies X element-wise with 1 / (1 + exp(-(U %*% t(V)))).
+# NOTE(review): eps is assigned but never used in this function.
+test307 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.1;
+  Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+# Returns t(t(U) %*% (X / (U %*% t(V) + eps))) with eps = 0.4
+# (same expression as test303, which uses eps = 0.1).
+test309 = function(Matrix[Double] X, Matrix[Double] U, Matrix[Double] V) 
return(Matrix[Double] Z) {
+  eps = 0.4;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
diff --git 
a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
 
b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
index e592dcb..242305c 100644
--- 
a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
+++ 
b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
@@ -31,7 +31,7 @@ else {
 
 if(test_num == 1) { # wcemm
   # X ... 2000x2000 matrix
-  
+
   U = matrix(seq(1, 20000), rows=2000, cols=10);
   V = matrix(seq(20001, 40000), rows=2000, cols=10);
   eps = 0.1;

Reply via email to