This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2576c2e  [SYSTEMDS-2867] Cleanup federated binary operations, incl tests
2576c2e is described below

commit 2576c2e9df350f549e6fd9c3463466c9630d923f
Author: ywcb00 <[email protected]>
AuthorDate: Sat Feb 20 18:04:33 2021 +0100

    [SYSTEMDS-2867] Cleanup federated binary operations, incl tests
    
    Closes #1182.
---
 .../instructions/fed/BinaryFEDInstruction.java     |  19 ++
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  60 ++---
 .../instructions/fed/FEDInstructionUtils.java      |   2 +
 .../federated/primitives/FederatedLogicalTest.java | 257 ++++++++++++++++-----
 .../binary/FederatedLogicalMatrixMatrixTest.dml    |  23 +-
 .../FederatedLogicalMatrixMatrixTestReference.dml  |  15 +-
 .../binary/FederatedLogicalMatrixScalarTest.dml    |  23 +-
 .../FederatedLogicalMatrixScalarTestReference.dml  |  15 +-
 8 files changed, 318 insertions(+), 96 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index bfe0c27..9f0c91a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -20,6 +20,9 @@
 package org.apache.sysds.runtime.instructions.fed;
 
 import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.BinaryM.VectorType;
+import org.apache.sysds.lops.Lop;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -33,6 +36,11 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
        }
 
        public static BinaryFEDInstruction parseInstruction(String str) {
+               if(str.startsWith(ExecType.SPARK.name())) {
+                       // rewrite the spark instruction to a cp instruction
+                       str = rewriteSparkInstructionToCP(str);
+               }
+
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                InstructionUtils.checkNumFields(parts, 3, 4);
                String opcode = parts[0];
@@ -65,4 +73,15 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
                        throw new DMLRuntimeException("Element-wise matrix 
operations between variables " + in1.getName() +
                                " and " + in2.getName() + " must produce a 
matrix, which " + out.getName() + " is not");
        }
+
+       private static String rewriteSparkInstructionToCP(String inst_str) {
+               // rewrite the spark instruction to a cp instruction
+               inst_str = inst_str.replace(ExecType.SPARK.name(), 
ExecType.CP.name());
+               inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + "map", 
Lop.OPERAND_DELIMITOR);
+               inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + "RIGHT", 
"");
+               inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + 
VectorType.ROW_VECTOR.name(), "");
+               inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + 
VectorType.COL_VECTOR.name(), "");
+
+               return inst_str;
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 0ba1935..6f7dcc9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -29,6 +29,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+
 public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
 {
        protected BinaryMatrixMatrixFEDInstruction(Operator op,
@@ -62,46 +63,49 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
                                        + "federated right input are only 
supported for special cases yet.");
                        }
                }
-               else {
-                       //matrix-matrix binary operations -> lhs fed input -> 
fed output
-                       if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { 
//MV col vector
-                               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+               else { // matrix-matrix binary operations -> lhs fed input -> 
fed output
+                       if(mo1.isFederated(FType.FULL)) {
+                               // full federated (row and col)
+                               if(mo1.getFedMapping().getSize() == 1) {
+                                       // only one partition (MM on a single 
fed worker)
+                                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                                       fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
+                                       new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
+                                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+                                       //execute federated instruction and 
cleanup intermediates
+                                       mo1.getFedMapping().execute(getTID(), 
true, fr1, fr2, fr3);
+                               }
+                               else {
+                                       throw new 
DMLRuntimeException("Matrix-matrix binary operations with a full partitioned 
federated input with multiple partitions are not supported yet.");
+                               }
+                       }
+                       else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() 
== 1 && mo2.getNumColumns() > 1)
+                               || (mo1.isFederated(FType.COL) && 
mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
+                               // MV row partitioned row vector, MV col 
partitioned col vector
+                               FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
-                                       new long[]{mo1.getFedMapping().getID(), 
fr1[0].getID()});
-                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+                               new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
                                //execute federated instruction and cleanup 
intermediates
                                mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
-                       else if(mo2.getNumRows() == 1 && mo2.getNumColumns() > 
1) { //MV row vector
-                               FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                       else if(mo1.isFederated(FType.ROW) ^ 
mo1.isFederated(FType.COL)) {
+                               // row partitioned MM or col partitioned MM
+                               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
-                                       new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
-                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+                                       new long[]{mo1.getFedMapping().getID(), 
fr1[0].getID()});
+                               FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
                                //execute federated instruction and cleanup 
intermediates
                                mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
-                       else { //MM
-                               if(mo1.isFederated(FType.ROW)) {
-                                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
-                                       fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
-                                               new 
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
-                                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
-                                       //execute federated instruction and 
cleanup intermediates
-                                       mo1.getFedMapping().execute(getTID(), 
true, fr1, fr2, fr3);
-                               }
-                               else {
-                                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
-                                       fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
-                                               new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
-                                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
-                                       //execute federated instruction and 
cleanup intermediates
-                                       mo1.getFedMapping().execute(getTID(), 
true, fr1, fr2, fr3);
-                               }
+                       else {
+                               throw new DMLRuntimeException("Matrix-matrix 
binary operations are only supported with a row partitioned or column 
partitioned federated input yet.");
                        }
                }
 
-               //derive new fed mapping for output
+               // derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
+
                out.getDataCharacteristics().set(mo1.getDataCharacteristics());
                
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 2a608f3..613ff31 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOp
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
@@ -266,6 +267,7 @@ public class FEDInstructionUtils {
                        }
                        else if (inst instanceof BinaryMatrixScalarSPInstruction
                                || inst instanceof 
BinaryMatrixMatrixSPInstruction
+                               || inst instanceof 
BinaryMatrixBVectorSPInstruction
                                || inst instanceof 
BinaryTensorTensorSPInstruction
                                || inst instanceof 
BinaryTensorTensorBroadcastSPInstruction) {
                                BinarySPInstruction instruction = 
(BinarySPInstruction) inst;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
index 53dfb2e..a79e3b7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
@@ -36,6 +36,12 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 
+/*
+ * Testing following logical operations:
+ *   >, <, ==, !=, >=, <=
+ * with a row/col partitioned federated matrix X and a scalar/vector/matrix Y
+*/
+
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedLogicalTest extends AutomatedTestBase
@@ -47,9 +53,9 @@ public class FederatedLogicalTest extends AutomatedTestBase
 
        private final static String OUTPUT_NAME = "Z";
        private final static double TOLERANCE = 0;
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
-       public enum Type{
+       private enum Type {
                GREATER,
                LESS,
                EQUALS,
@@ -58,12 +64,29 @@ public class FederatedLogicalTest extends AutomatedTestBase
                LESS_EQUALS
        }
 
+       private enum FederationType {
+               SINGLE_FED_WORKER,
+               ROW_PARTITIONED,
+               COL_PARTITIONED,
+               FULL_PARTITIONED
+       }
+
+       private enum YType {
+               MATRIX,
+               ROW_VEC,
+               COL_VEC
+       }
+
        @Parameterized.Parameter()
        public int rows;
        @Parameterized.Parameter(1)
        public int cols;
        @Parameterized.Parameter(2)
        public double sparsity;
+       @Parameterized.Parameter(3)
+       public FederationType fed_type;
+       @Parameterized.Parameter(4)
+       public YType y_type;
 
        @Override
        public void setUp() {
@@ -73,13 +96,73 @@ public class FederatedLogicalTest extends AutomatedTestBase
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               // rows must be even
+               // rows must be divisable by 4 for row partitioned data
+               // cols must be divisable by 4 for col partitioned data
+               // rows and cols must be divisable by 2 for full partitioned 
data
                return Arrays.asList(new Object[][] {
-                       // {rows, cols, sparsity}
-                       {100, 75, 0.01},
-                       {100, 75, 0.9},
-                       {2, 75, 0.01},
-                       {2, 75, 0.9}
+                       // {rows, cols, sparsity, fed_type, y_type}
+
+                       // row partitioned MM
+                       {100, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+                       {100, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+                       // {4, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+                       // {4, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+                       // {100, 1, 0.01, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+                       // {100, 1, 0.9, FederationType.ROW_PARTITIONED, 
YType.MATRIX},
+
+                       // row partitioned MV row vector
+                       {100, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+                       {100, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+                       // {4, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+                       // {4, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+                       // {100, 1, 0.01, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+                       // {100, 1, 0.9, FederationType.ROW_PARTITIONED, 
YType.ROW_VEC},
+
+                       // row partitioned MV col vector
+                       {100, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+                       {100, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+                       // {4, 75, 0.01, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+                       // {4, 75, 0.9, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+                       // {100, 1, 0.01, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+                       // {100, 1, 0.9, FederationType.ROW_PARTITIONED, 
YType.COL_VEC},
+
+                       // col partitioned MM
+                       {100, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+                       {100, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+                       // {1, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+                       // {1, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+                       // {100, 4, 0.01, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+                       // {100, 4, 0.9, FederationType.COL_PARTITIONED, 
YType.MATRIX},
+
+                       // col partitioned MV row vector
+                       {100, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+                       {100, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+                       // {1, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+                       // {1, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+                       // {100, 4, 0.01, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+                       // {100, 4, 0.9, FederationType.COL_PARTITIONED, 
YType.ROW_VEC},
+
+                       // col partitioned MV col vector
+                       {100, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+                       {100, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+                       // {1, 76, 0.01, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+                       // {1, 76, 0.9, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+                       // {100, 4, 0.01, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+                       // {100, 4, 0.9, FederationType.COL_PARTITIONED, 
YType.COL_VEC},
+
+                       // single federated worker MM
+                       {100, 75, 0.01, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+                       {100, 75, 0.9, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+                       // {1, 75, 0.01, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+                       // {1, 75, 0.9, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+                       // {100, 1, 0.01, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+                       // {100, 1, 0.9, FederationType.SINGLE_FED_WORKER, 
YType.MATRIX},
+
+                       // full partitioned (not supported yet)
+                       // {70, 80, 0.01, FederationType.FULL_PARTITIONED, 
YType.MATRIX},
+                       // {70, 80, 0.9, FederationType.FULL_PARTITIONED, 
YType.MATRIX},
+                       // {2, 2, 0.01, FederationType.FULL_PARTITIONED, 
YType.MATRIX},
+                       // {2, 2, 0.9, FederationType.FULL_PARTITIONED, 
YType.MATRIX}
                });
        }
 
@@ -99,15 +182,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
                federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER, 
ExecMode.SPARK);
        }
 
-       @Test
-       public void federatedLogicalScalarLessSingleNode() {
-               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       public void federatedLogicalScalarLessSpark() {
-               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
-       }
+//     @Test
+//     public void federatedLogicalScalarLessSingleNode() {
+//             federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
+//     }
+//
+//     @Test
+//     public void federatedLogicalScalarLessSpark() {
+//             federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
+//     }
 
        @Test
        public void federatedLogicalScalarEqualsSingleNode() {
@@ -139,15 +222,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
                federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SPARK);
        }
 
-       @Test
-       public void federatedLogicalScalarLessEqualsSingleNode() {
-               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       public void federatedLogicalScalarLessEqualsSpark() {
-               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
-       }
+//     @Test
+//     public void federatedLogicalScalarLessEqualsSingleNode() {
+//             federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
+//     }
+//
+//     @Test
+//     public void federatedLogicalScalarLessEqualsSpark() {
+//             federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
+//     }
 
        //---------------------------MATRIX MATRIX--------------------------
        @Test
@@ -160,15 +243,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
                federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, 
ExecMode.SPARK);
        }
 
-       @Test
-       public void federatedLogicalMatrixLessSingleNode() {
-               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       public void federatedLogicalMatrixLessSpark() {
-               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
-       }
+//     @Test
+//     public void federatedLogicalMatrixLessSingleNode() {
+//             federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
+//     }
+//
+//     @Test
+//     public void federatedLogicalMatrixLessSpark() {
+//             federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
+//     }
 
        @Test
        public void federatedLogicalMatrixEqualsSingleNode() {
@@ -200,15 +283,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
                federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SPARK);
        }
 
-       @Test
-       public void federatedLogicalMatrixLessEqualsSingleNode() {
-               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
-       }
-
-       @Test
-       public void federatedLogicalMatrixLessEqualsSpark() {
-               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
-       }
+//     @Test
+//     public void federatedLogicalMatrixLessEqualsSingleNode() {
+//             federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
+//     }
+//
+//     @Test
+//     public void federatedLogicalMatrixLessEqualsSpark() {
+//             federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
+//     }
 
 // 
-----------------------------------------------------------------------------
 
@@ -220,39 +303,78 @@ public class FederatedLogicalTest extends AutomatedTestBase
                getAndLoadTestConfiguration(testname);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
-               int fed_rows = rows / 2;
-               int fed_cols = cols;
+               int fed_rows = 0;
+               int fed_cols = 0;
+               switch(fed_type) {
+                       case SINGLE_FED_WORKER:
+                               fed_rows = rows;
+                               fed_cols = cols;
+                               break;
+                       case ROW_PARTITIONED:
+                               fed_rows = rows / 4;
+                               fed_cols = cols;
+                               break;
+                       case COL_PARTITIONED:
+                               fed_rows = rows;
+                               fed_cols = cols / 4;
+                               break;
+                       case FULL_PARTITIONED:
+                               fed_rows = rows / 2;
+                               fed_cols = cols / 2;
+                               break;
+               }
+
+               boolean single_fed_worker = (fed_type == 
FederationType.SINGLE_FED_WORKER);
 
                // generate dataset
-               // matrix handled by two federated workers
-               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 
13);
-               double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 2);
-
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               // matrix handled by four federated workers
+               // X2, X3, X4 not used if single_fed_worker == true
+               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 
sparsity, 13);
+               double[][] X2 = (!single_fed_worker ? getRandomMatrix(fed_rows, 
fed_cols, 0, 1, sparsity, 2) : null);
+               double[][] X3 = (!single_fed_worker ? getRandomMatrix(fed_rows, 
fed_cols, 0, 1, sparsity, 211) : null);
+               double[][] X4 = (!single_fed_worker ? getRandomMatrix(fed_rows, 
fed_cols, 0, 1, sparsity, 65) : null);
+
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               if(!single_fed_worker) {
+                       writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+                       writeInputMatrixWithMTD("X3", X3, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+                       writeInputMatrixWithMTD("X4", X4, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               }
 
                boolean is_matrix_test = testname.equals(MATRIX_TEST_NAME);
 
                double[][] Y_mat = null;
                double Y_scal = 0;
                if(is_matrix_test) {
-                       Y_mat = getRandomMatrix(rows, cols, 0, 1, sparsity, 
5040);
-                       writeInputMatrixWithMTD("Y", Y_mat, true);
+                       int y_rows = (y_type == YType.ROW_VEC ? 1 : rows);
+                       int y_cols = (y_type == YType.COL_VEC ? 1 : cols);
+
+                       Y_mat = getRandomMatrix(y_rows, y_cols, 0, 1, sparsity, 
5040);
+                       writeInputMatrixWithMTD("Y", Y_mat, false, new 
MatrixCharacteristics(y_rows, y_cols, BLOCKSIZE, y_rows * y_cols));
                }
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
                int port1 = getRandomAvailablePort();
-               int port2 = getRandomAvailablePort();
-               Thread thread1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
-               Thread thread2 = startLocalFedWorkerThread(port2);
+               int port2 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+               int port3 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+               int port4 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+               Thread thread1 = startLocalFedWorkerThread(port1, 
(!single_fed_worker ? FED_WORKER_WAIT_S : FED_WORKER_WAIT));
+               Thread thread2 = (!single_fed_worker ? 
startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S) : null);
+               Thread thread3 = (!single_fed_worker ? 
startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S) : null);
+               Thread thread4 = (!single_fed_worker ? 
startLocalFedWorkerThread(port4) : null);
 
                getAndLoadTestConfiguration(testname);
 
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + testname + "Reference.dml";
-               programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), 
"in_X2=" + input("X2"),
+               programArgs = new String[] {"-nvargs",
+                       "in_X1=" + input("X1"),
+                       "in_X2=" + (!single_fed_worker ? input("X2") : 
input("X1")), // not needed in case of a single federated worker
+                       "in_X3=" + (!single_fed_worker ? input("X3") : 
input("X1")), // not needed in case of a single federated worker
+                       "in_X4=" + (!single_fed_worker ? input("X4") : 
input("X1")), // not needed in case of a single federated worker
                        "in_Y=" + (is_matrix_test ? input("Y") : 
Double.toString(Y_scal)),
+                       "in_fed_type=" + Integer.toString(fed_type.ordinal()),
                        "in_op_type=" + Integer.toString(op_type.ordinal()),
                        "out_Z=" + expected(OUTPUT_NAME)};
                runTest(true, false, null, -1);
@@ -260,10 +382,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + testname + ".dml";
                programArgs = new String[] {"-stats", "-nvargs",
-                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + (!single_fed_worker ? 
TestUtils.federatedAddress(port2, input("X2")) : null),
+                       "in_X3=" + (!single_fed_worker ? 
TestUtils.federatedAddress(port3, input("X3")) : null),
+                       "in_X4=" + (!single_fed_worker ? 
TestUtils.federatedAddress(port4, input("X4")) : null),
                        "in_Y=" + (is_matrix_test ? input("Y") : 
Double.toString(Y_scal)),
+                       "in_fed_type=" + Integer.toString(fed_type.ordinal()),
                        "in_op_type=" + Integer.toString(op_type.ordinal()),
-                       "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + 
output(OUTPUT_NAME)};
+                       "rows=" + Integer.toString(fed_rows), "cols=" + 
Integer.toString(fed_cols),
+                       "out_Z=" + output(OUTPUT_NAME)};
                runTest(true, false, null, -1);
 
                // compare the results via files
@@ -271,7 +398,9 @@ public class FederatedLogicalTest extends AutomatedTestBase
                HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
                TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, 
"Fed", "Ref");
 
-               TestUtils.shutdownThreads(thread1, thread2);
+               TestUtils.shutdownThreads(thread1);
+               if(!single_fed_worker)
+                       TestUtils.shutdownThreads(thread2, thread3, thread4);
 
                // check for federated operations
                switch(op_type)
@@ -298,7 +427,11 @@ public class FederatedLogicalTest extends AutomatedTestBase
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               if(!single_fed_worker) {
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+               }
 
                resetExecMode(platform_old);
        }
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
index 7fe350c..b84ff6f 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
@@ -19,8 +19,27 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+  X = federated(addresses=list($in_X1),
+    ranges=list(list(0, 0), list($rows, $cols)));
+}
+else if(fed_type == 1) { # row partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols),
+      list($rows * 2, 0), list($rows * 3, $cols), list($rows * 3, 0), list($rows * 4, $cols)));
+}
+else if(fed_type == 2) { # col partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows, $cols * 2),
+      list(0, $cols * 2), list($rows, $cols * 3), list(0, $cols * 3), list($rows, $cols * 4)));
+}
+else if(fed_type == 3) { # full partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows, $cols * 2),
+      list($rows, 0), list($rows * 2, $cols), list($rows, $cols), list($rows * 2, $cols * 2)));
+}
 
 Y = read($in_Y);
 op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
index 1285eb0..e217ae8 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
@@ -19,7 +19,20 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($in_X1), read($in_X2));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+  X = read($in_X1);
+}
+else if(fed_type == 1) { # row partitioned
+  X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 2) { # col partitioned
+  X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 3) { # full partitioned
+  X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3), read($in_X4)));
+}
 
 Y = read($in_Y);
 op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
index 1dfb762..b4a520c 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
@@ -19,8 +19,27 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+  X = federated(addresses=list($in_X1),
+    ranges=list(list(0, 0), list($rows, $cols)));
+}
+else if(fed_type == 1) { # row partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols),
+      list($rows * 2, 0), list($rows * 3, $cols), list($rows * 3, 0), list($rows * 4, $cols)));
+}
+else if(fed_type == 2) { # col partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows, $cols * 2),
+      list(0, $cols * 2), list($rows, $cols * 3), list(0, $cols * 3), list($rows, $cols * 4)));
+}
+else if(fed_type == 3) { # full partitioned
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows, $cols * 2),
+      list($rows, 0), list($rows * 2, $cols), list($rows, $cols), list($rows * 2, $cols * 2)));
+}
 
 y = $in_Y;
 op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
index 4682aea..40bb906 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
@@ -19,7 +19,20 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($in_X1), read($in_X2));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+  X = read($in_X1);
+}
+else if(fed_type == 1) { # row partitioned
+  X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 2) { # col partitioned
+  X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 3) { # full partitioned
+  X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3), read($in_X4)));
+}
 
 y = $in_Y;
 op_type = $in_op_type;

Reply via email to