This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 65ea7f3189 [SYSTEMDS-3018] Add Function Parameters to Cost-Based Federated Planner
65ea7f3189 is described below

commit 65ea7f318957127e3e75f5bc8cc7d1b5a356c885
Author: sebwrede <[email protected]>
AuthorDate: Tue May 17 10:32:12 2022 +0200

    [SYSTEMDS-3018] Add Function Parameters to Cost-Based Federated Planner
    
    This commit will also:
    
    - Add Null Check to Repetition Estimate Update
    
    - Add Transient Writes to Terminal Hops
    
    - Edit Transpose FEDInstruction So That LOUT Binds Output Fedmapping Correctly
    
    - Edit L2SVM Fed Planning Test To Prepare for L2SVM Function Call Tests
    
    Closes #1618.
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |   4 +-
 .../hops/fedplanner/FederatedPlannerCostbased.java | 122 ++++++++++++++-------
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |   4 +-
 .../org/apache/sysds/parser/ForStatementBlock.java |   9 +-
 .../apache/sysds/parser/WhileStatementBlock.java   |   3 +-
 .../instructions/fed/ReorgFEDInstruction.java      |   4 +-
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  46 ++++++--
 .../fedplanning/FederatedMultiplyPlanningTest.java |   1 -
 .../FederatedL2SVMFunctionPlanningTest.dml         |  36 ++++++
 ...FederatedL2SVMFunctionPlanningTestReference.dml |  35 ++++++
 10 files changed, 203 insertions(+), 61 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 403a0466f0..16b42c4840 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -44,7 +44,7 @@ import org.apache.sysds.lops.PMMJ;
 import org.apache.sysds.lops.PMapMult;
 import org.apache.sysds.lops.Transform;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -677,7 +677,7 @@ public class AggBinaryOp extends MultiThreadedHop {
                setLineNumbers(mult);
 
                //result transpose (dimensions set outside)
-               ExecType outTransposeExecType = ( _federatedOutput == 
FEDInstruction.FederatedOutput.FOUT ) ?
+               ExecType outTransposeExecType = ( _federatedOutput == 
FederatedOutput.FOUT ) ?
                        ExecType.FED : ExecType.CP;
                Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(), 
getValueType(), outTransposeExecType, k);
 
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index e9a25206f8..1f9abb4c18 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -78,7 +78,7 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
        @Override
        public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
                prog.updateRepetitionEstimates();
-               rewriteStatementBlocks(prog, prog.getStatementBlocks());
+               rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
                setFinalFedouts();
                updateExplain();
        }
@@ -89,12 +89,13 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
         *
         * @param prog dml program
         * @param sbs  list of statement blocks
+        * @param paramMap map of parameters in function call
         * @return list of statement blocks with the federated output value 
updated for each hop
         */
-       private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram 
prog, List<StatementBlock> sbs) {
+       private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram 
prog, List<StatementBlock> sbs, Map<String, Hop> paramMap) {
                ArrayList<StatementBlock> rewrittenStmBlocks = new 
ArrayList<>();
                for(StatementBlock stmBlock : sbs)
-                       rewrittenStmBlocks.addAll(rewriteStatementBlock(prog, 
stmBlock));
+                       rewrittenStmBlocks.addAll(rewriteStatementBlock(prog, 
stmBlock, paramMap));
                return rewrittenStmBlocks;
        }
 
@@ -104,79 +105,99 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
         *
         * @param prog dml program
         * @param sb   statement block
+        * @param paramMap map of parameters in function call
         * @return list of statement blocks with the federated output value 
updated for each hop
         */
-       public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog, 
StatementBlock sb) {
+       public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog, 
StatementBlock sb, Map<String, Hop> paramMap) {
                if(sb instanceof WhileStatementBlock)
-                       return rewriteWhileStatementBlock(prog, 
(WhileStatementBlock) sb);
+                       return rewriteWhileStatementBlock(prog, 
(WhileStatementBlock) sb, paramMap);
                else if(sb instanceof IfStatementBlock)
-                       return rewriteIfStatementBlock(prog, (IfStatementBlock) 
sb);
+                       return rewriteIfStatementBlock(prog, (IfStatementBlock) 
sb, paramMap);
                else if(sb instanceof ForStatementBlock) {
                        // This also includes ParForStatementBlocks
-                       return rewriteForStatementBlock(prog, 
(ForStatementBlock) sb);
+                       return rewriteForStatementBlock(prog, 
(ForStatementBlock) sb, paramMap);
                }
                else if(sb instanceof FunctionStatementBlock)
-                       return rewriteFunctionStatementBlock(prog, 
(FunctionStatementBlock) sb);
+                       return rewriteFunctionStatementBlock(prog, 
(FunctionStatementBlock) sb, paramMap);
                else {
                        // StatementBlock type (no subclass)
-                       return rewriteDefaultStatementBlock(prog, sb);
+                       return rewriteDefaultStatementBlock(prog, sb, paramMap);
                }
        }
 
-       private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram 
prog, WhileStatementBlock whileSB) {
+       private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram 
prog, WhileStatementBlock whileSB, Map<String, Hop> paramMap) {
                Hop whilePredicateHop = whileSB.getPredicateHops();
-               selectFederatedExecutionPlan(whilePredicateHop);
+               selectFederatedExecutionPlan(whilePredicateHop, paramMap);
                for(Statement stm : whileSB.getStatements()) {
                        WhileStatement whileStm = (WhileStatement) stm;
-                       whileStm.setBody(rewriteStatementBlocks(prog, 
whileStm.getBody()));
+                       whileStm.setBody(rewriteStatementBlocks(prog, 
whileStm.getBody(), paramMap));
                }
                return new ArrayList<>(Collections.singletonList(whileSB));
        }
 
-       private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram 
prog, IfStatementBlock ifSB) {
-               selectFederatedExecutionPlan(ifSB.getPredicateHops());
+       private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram 
prog, IfStatementBlock ifSB, Map<String, Hop> paramMap) {
+               selectFederatedExecutionPlan(ifSB.getPredicateHops(), paramMap);
                for(Statement statement : ifSB.getStatements()) {
                        IfStatement ifStatement = (IfStatement) statement;
-                       ifStatement.setIfBody(rewriteStatementBlocks(prog, 
ifStatement.getIfBody()));
-                       ifStatement.setElseBody(rewriteStatementBlocks(prog, 
ifStatement.getElseBody()));
+                       ifStatement.setIfBody(rewriteStatementBlocks(prog, 
ifStatement.getIfBody(), paramMap));
+                       ifStatement.setElseBody(rewriteStatementBlocks(prog, 
ifStatement.getElseBody(), paramMap));
                }
                return new ArrayList<>(Collections.singletonList(ifSB));
        }
 
-       private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram 
prog, ForStatementBlock forSB) {
-               selectFederatedExecutionPlan(forSB.getFromHops());
-               selectFederatedExecutionPlan(forSB.getToHops());
-               selectFederatedExecutionPlan(forSB.getIncrementHops());
+       private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram 
prog, ForStatementBlock forSB, Map<String, Hop> paramMap) {
+               selectFederatedExecutionPlan(forSB.getFromHops(), paramMap);
+               selectFederatedExecutionPlan(forSB.getToHops(), paramMap);
+               selectFederatedExecutionPlan(forSB.getIncrementHops(), 
paramMap);
                for(Statement statement : forSB.getStatements()) {
                        ForStatement forStatement = ((ForStatement) statement);
-                       forStatement.setBody(rewriteStatementBlocks(prog, 
forStatement.getBody()));
+                       forStatement.setBody(rewriteStatementBlocks(prog, 
forStatement.getBody(), paramMap));
                }
                return new ArrayList<>(Collections.singletonList(forSB));
        }
 
-       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB) {
+       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB, 
Map<String, Hop> paramMap) {
                for(Statement statement : funcSB.getStatements()) {
                        FunctionStatement funcStm = (FunctionStatement) 
statement;
-                       funcStm.setBody(rewriteStatementBlocks(prog, 
funcStm.getBody()));
+                       funcStm.setBody(rewriteStatementBlocks(prog, 
funcStm.getBody(), paramMap));
                }
                return new ArrayList<>(Collections.singletonList(funcSB));
        }
 
-       private ArrayList<StatementBlock> 
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
+       private ArrayList<StatementBlock> 
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map<String, 
Hop> paramMap) {
                if(sb.hasHops()) {
                        for(Hop sbHop : sb.getHops()) {
+                               selectFederatedExecutionPlan(sbHop, paramMap);
                                if(sbHop instanceof FunctionOp) {
                                        String funcName = ((FunctionOp) 
sbHop).getFunctionName();
+                                       Map<String, Hop> funcParamMap = 
getParamMap((FunctionOp) sbHop);
+                                       if ( paramMap != null && funcParamMap 
!= null)
+                                               funcParamMap.putAll(paramMap);
+                                       paramMap = funcParamMap;
                                        FunctionStatementBlock sbFuncBlock = 
prog.getBuiltinFunctionDictionary().getFunction(funcName);
-                                       rewriteStatementBlock(prog, 
sbFuncBlock);
+                                       rewriteStatementBlock(prog, 
sbFuncBlock, paramMap);
                                }
-                               else
-                                       selectFederatedExecutionPlan(sbHop);
                        }
                }
                return new ArrayList<>(Collections.singletonList(sb));
        }
 
+       /**
+        * Return parameter map containing the mapping from parameter name to 
input hop
+        * for all parameters of the function hop.
+        * @param funcOp hop for which the mapping of parameter names to input 
hops are made
+        * @return parameter map or empty map if function has no parameters
+        */
+       private Map<String,Hop> getParamMap(FunctionOp funcOp){
+               String[] inputNames = funcOp.getInputVariableNames();
+               Map<String,Hop> paramMap = new HashMap<>();
+               if ( inputNames != null ){
+                       for ( int i = 0; i < funcOp.getInput().size(); i++ )
+                               paramMap.put(inputNames[i],funcOp.getInput(i));
+               }
+               return paramMap;
+       }
+
        /**
         * Set final fedouts of all hops starting from terminal hops.
         */
@@ -266,21 +287,23 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
         * The cost estimates of the hops are also updated when FederatedOutput 
is updated in the hops.
         *
         * @param roots starting point for going through the Hop DAG to update 
the FederatedOutput fields.
+        * @param paramMap map of parameters in function call
         */
        @SuppressWarnings("unused")
-       private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+       private void selectFederatedExecutionPlan(ArrayList<Hop> roots, 
Map<String, Hop> paramMap){
                for ( Hop root : roots )
-                       selectFederatedExecutionPlan(root);
+                       selectFederatedExecutionPlan(root, paramMap);
        }
 
        /**
         * Select federated execution plan for every Hop in the DAG starting 
from given root.
         *
         * @param root starting point for going through the Hop DAG to update 
the federatedOutput fields
+        * @param paramMap map of parameters in function call
         */
-       private void selectFederatedExecutionPlan(Hop root) {
+       private void selectFederatedExecutionPlan(Hop root, Map<String, Hop> 
paramMap) {
                if ( root != null ){
-                       visitFedPlanHop(root);
+                       visitFedPlanHop(root, paramMap);
                        if ( HopRewriteUtils.isTerminalHop(root) )
                                terminalHops.add(root);
                }
@@ -290,17 +313,18 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
         * Go through the Hop DAG and set the FederatedOutput field and cost 
estimate for each Hop from leaf to given currentHop.
         *
         * @param currentHop the Hop from which the DAG is visited
+        * @param paramMap map of parameters in function call
         */
-       private void visitFedPlanHop(Hop currentHop) {
+       private void visitFedPlanHop(Hop currentHop, Map<String, Hop> paramMap) 
{
                // If the currentHop is in the hopRelMemo table, it means that 
it has been visited
                if(hopRelMemo.containsHop(currentHop))
                        return;
                debugLog(currentHop);
                // If the currentHop has input, then the input should be 
visited depth-first
                for(Hop input : currentHop.getInput())
-                       visitFedPlanHop(input);
+                       visitFedPlanHop(input, paramMap);
                // Put FOUT and LOUT HopRels into the memo table
-               ArrayList<HopRel> hopRels = getFedPlans(currentHop);
+               ArrayList<HopRel> hopRels = getFedPlans(currentHop, paramMap);
                // Put NONE HopRel into memo table if no FOUT or LOUT HopRels 
were added
                if(hopRels.isEmpty())
                        hopRels.add(getNONEHopRel(currentHop));
@@ -319,17 +343,14 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
        /**
         * Get the alternative plans regarding the federated output for given 
currentHop.
         * @param currentHop for which alternative federated plans are generated
+        * @param paramMap map of parameters in function call
         * @return list of alternative plans
         */
-       private ArrayList<HopRel> getFedPlans(Hop currentHop){
+       private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop> 
paramMap){
                ArrayList<HopRel> hopRels = new ArrayList<>();
                ArrayList<Hop> inputHops = currentHop.getInput();
-               if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) ){
-                       Hop tWriteHop = 
transientWrites.get(currentHop.getName());
-                       if ( tWriteHop == null )
-                               throw new DMLRuntimeException("Transient write 
not found for " + currentHop);
-                       inputHops = new 
ArrayList<>(Collections.singletonList(tWriteHop));
-               }
+               if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) )
+                       inputHops = getTransientInputs(currentHop, paramMap);
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTWRITE) )
                        transientWrites.put(currentHop.getName(), currentHop);
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.FEDERATED) )
@@ -341,6 +362,25 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                return hopRels;
        }
 
+       /**
+        * Get transient inputs from either paramMap or transientWrites.
+        * Inputs from paramMap has higher priority than inputs from 
transientWrites.
+        * @param currentHop hop for which inputs are read from maps
+        * @param paramMap of local parameters
+        * @return inputs of currentHop
+        */
+       private ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String, 
Hop> paramMap){
+               Hop tWriteHop = null;
+               if ( paramMap != null)
+                       tWriteHop = paramMap.get(currentHop.getName());
+               if ( tWriteHop == null )
+                       tWriteHop = transientWrites.get(currentHop.getName());
+               if ( tWriteHop == null )
+                       throw new DMLRuntimeException("Transient write not 
found for " + currentHop);
+               else
+                       return new 
ArrayList<>(Collections.singletonList(tWriteHop));
+       }
+
        /**
         * Generate a collection of FOUT HopRels representing the different 
possible FType outputs.
         * For each FType output, only the minimum cost input combination is 
chosen.
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 6dbc5e35b6..d10a43e810 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1167,7 +1167,9 @@ public class HopRewriteUtils {
        public static boolean isTerminalHop(Hop hop){
                return isUnary(hop, OpOp1.PRINT)
                        || isNary(hop, OpOpN.PRINTF)
-                       || isData(hop, OpOpData.PERSISTENTWRITE);
+                       || isData(hop, OpOpData.PERSISTENTWRITE)
+                       || isData(hop, OpOpData.TRANSIENTWRITE)
+                       || hop instanceof FunctionOp;
        }
        
        public static boolean isMatrixMultiply(Hop hop) {
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index b21b9b58a6..ce31ae9bcf 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -453,9 +453,12 @@ public class ForStatementBlock extends StatementBlock
        @Override
        public void updateRepetitionEstimates(double repetitions){
                this.repetitions = repetitions * getEstimateReps();
-               _fromHops.updateRepetitionEstimates(this.repetitions);
-               _toHops.updateRepetitionEstimates(this.repetitions);
-               _incrementHops.updateRepetitionEstimates(this.repetitions);
+               if ( _fromHops != null )
+                       _fromHops.updateRepetitionEstimates(this.repetitions);
+               if ( _toHops != null )
+                       _toHops.updateRepetitionEstimates(this.repetitions);
+               if ( _incrementHops != null )
+                       
_incrementHops.updateRepetitionEstimates(this.repetitions);
                for(Statement statement : getStatements()) {
                        List<StatementBlock> children = ((ForStatement) 
statement).getBody();
                        for ( StatementBlock stmBlock : children ){
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index b28e6825bb..8a92f3bc23 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -322,7 +322,8 @@ public class WhileStatementBlock extends StatementBlock
        @Override
        public void updateRepetitionEstimates(double repetitions){
                this.repetitions = repetitions * DEFAULT_LOOP_REPETITIONS;
-               getPredicateHops().updateRepetitionEstimates(this.repetitions);
+               if ( getPredicateHops() != null )
+                       
getPredicateHops().updateRepetitionEstimates(this.repetitions);
                for(Statement statement : getStatements()) {
                        List<StatementBlock> children = 
((WhileStatement)statement).getBody();
                        for ( StatementBlock stmBlock : children ){
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index aff69a24a6..2f9e26a2a1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -104,7 +104,7 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                if( !mo1.isFederated() )
                        throw new DMLRuntimeException("Federated Reorg: "
                                + "Federated input expected, but invoked w/ 
"+mo1.isFederated());
-               if ( !( mo1.isFederated(FType.COL) || 
mo1.isFederated(FType.ROW) || mo1.isFederated(FType.PART) ) )
+               if ( !( mo1.isFederated(FType.COL) || 
mo1.isFederated(FType.ROW) ) )
                        throw new DMLRuntimeException("Federation type " + 
mo1.getFedMapping().getType()
                                + " is not supported for Reorg processing");
 
@@ -126,7 +126,7 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                                FederatedRequest getRequest = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
                                Future<FederatedResponse>[] execResponse = 
mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
                                ec.setMatrixOutput(output.getName(),
-                                       FederationUtils.bind(execResponse, 
mo1.isFederated(FType.COL)));
+                                       FederationUtils.bind(execResponse, 
mo1.isFederated(FType.ROW)));
                        }
                } else if ( mo1.isFederated(FType.PART) ){
                        throw new DMLRuntimeException("Operation with opcode " 
+ instOpcode + " is not supported with PART input");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index 1ba9966773..60ab0d93ce 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
 import org.junit.Test;
 
 import java.io.File;
@@ -41,6 +42,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/privacy/fedplanning/";
        private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+       private final static String TEST_NAME_2 = 
"FederatedL2SVMFunctionPlanningTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
        private static File TEST_CONF_FILE;
 
@@ -52,6 +54,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
        }
 
        @Test
@@ -59,24 +62,47 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
                        "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
                setTestConf("SystemDS-config-fout.xml");
-               loadAndRunTest(expectedHeavyHitters);
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
        }
 
        @Test
        public void runL2SVMHeuristicTest(){
                String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
                setTestConf("SystemDS-config-heuristic.xml");
-               loadAndRunTest(expectedHeavyHitters);
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
        }
 
        @Test
        public void runL2SVMCostBasedTest(){
-               //String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
-               //      "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
                String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
-                       "fed_max", "fed_1-*", "fed_>"};
+                       "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+               setTestConf("SystemDS-config-cost-based.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       @Ignore
+       public void runL2SVMFunctionFOUTTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+                       "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+       }
+
+       @Test
+       @Ignore
+       public void runL2SVMFunctionHeuristicTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+       }
+
+       @Test
+       public void runL2SVMFunctionCostBasedTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+                       "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
                setTestConf("SystemDS-config-cost-based.xml");
-               loadAndRunTest(expectedHeavyHitters);
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
        }
 
        private void setTestConf(String test_conf){
@@ -117,7 +143,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                writeStandardMatrix(matrixName, seed, halfRows, 
privacyConstraint);
        }
 
-       private void loadAndRunTest(String[] expectedHeavyHitters){
+       private void loadAndRunTest(String[] expectedHeavyHitters, String 
testName){
 
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
@@ -126,7 +152,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                Thread t1 = null, t2 = null;
 
                try {
-                       getAndLoadTestConfiguration(TEST_NAME);
+                       getAndLoadTestConfiguration(testName);
                        String HOME = SCRIPT_DIR + TEST_DIR;
 
                        writeInputMatrices();
@@ -137,7 +163,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                        t2 = startLocalFedWorkerThread(port2);
 
                        // Run actual dml script with federated matrix
-                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       fullDMLScriptName = HOME + testName + ".dml";
                        programArgs = new String[] { "-stats", "-explain", 
"hops", "-nvargs",
                                "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                                "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
@@ -145,7 +171,7 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                        runTest(true, false, null, -1);
 
                        // Run reference dml script with normal matrix
-                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
                        programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
                                "Y=" + input("Y"), "Z=" + expected("Z")};
                        runTest(true, false, null, -1);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e8d16f6bcb..b9a3a14fd5 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -130,7 +130,6 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        }
 
        @Test
-       @Ignore
        public void federatedMultiplyDoubleHop() {
                String[] expectedHeavyHitters = new String[]{"fed_*", 
"fed_fedinit", "fed_r'", "fed_ba+*"};
                federatedTwoMatricesSingleNodeTest(TEST_NAME_7, 
expectedHeavyHitters);
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
new file mode 100644
index 0000000000..134d1b35c2
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+  maxii = 20
+  verbose = FALSE
+  columnId = -1
+  Y = read($Y)
+  X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+  intercept = FALSE
+  epsilon = 1e-12
+  reg = 1
+  maxIterations = 100
+
+  model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = epsilon, reg = reg, 
maxIterations = maxIterations)
+
+  write(model, $Z)
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
new file mode 100644
index 0000000000..7fec5d2a20
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+  maxii = 20
+  verbose = FALSE
+  columnId = -1
+  Y = read($Y)
+  X = rbind(read($X1), read($X2))
+  intercept = FALSE
+  epsilon = 1e-12
+  reg = 1
+  maxIterations = 100
+
+  model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = epsilon, reg = reg, 
maxIterations = maxIterations)
+
+  write(model, $Z)

Reply via email to