This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 6523457 [SYSTEMDS-3018] Costing of Federated Execution Plans
6523457 is described below
commit 65234573b986e8d0ba5bd27bd3d7641b921a7641
Author: sebwrede <[email protected]>
AuthorDate: Fri Jun 18 17:31:31 2021 +0200
[SYSTEMDS-3018] Costing of Federated Execution Plans
Closes #1367.
---
src/main/java/org/apache/sysds/hops/Hop.java | 93 +++++--
.../codegen/opt/PlanSelectionFuseCostBasedV2.java | 182 +-------------
.../org/apache/sysds/hops/cost/ComputeCost.java | 225 +++++++++++++++++
.../sysds/hops/cost/CostEstimationWrapper.java | 13 +-
.../org/apache/sysds/hops/cost/CostEstimator.java | 72 +++---
.../hops/cost/CostEstimatorStaticRuntime.java | 40 +--
.../org/apache/sysds/hops/cost/FederatedCost.java | 117 +++++++++
.../sysds/hops/cost/FederatedCostEstimator.java | 214 ++++++++++++++++
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 1 +
.../hops/rewrite/RewriteFederatedExecution.java | 133 ++++++++--
.../rewrite/RewriteFederatedStatementBlocks.java | 66 +++++
.../runtime/instructions/FEDInstructionParser.java | 1 +
.../fed/AggregateUnaryFEDInstruction.java | 56 ++++-
.../instructions/fed/AppendFEDInstruction.java | 5 +-
.../instructions/fed/CtableFEDInstruction.java | 4 +-
.../runtime/instructions/fed/FEDInstruction.java | 3 +
.../instructions/fed/ReorgFEDInstruction.java | 36 ++-
.../org/apache/sysds/test/AutomatedTestBase.java | 26 ++
.../fedplanning/FederatedCostEstimatorTest.java | 279 +++++++++++++++++++++
.../fedplanning/FederatedMultiplyPlanningTest.java | 33 +--
.../fedplanning/BinaryCostEstimatorTest.dml | 26 ++
.../FederatedMultiplyCostEstimatorTest.dml | 31 +++
.../fedplanning/ForLoopCostEstimatorTest.dml | 27 ++
.../fedplanning/FunctionCostEstimatorTest.dml | 28 +++
.../fedplanning/IfElseCostEstimatorTest.dml | 30 +++
.../fedplanning/ParForLoopCostEstimatorTest.dml | 27 ++
.../privacy/fedplanning/WhileCostEstimatorTest.dml | 27 ++
27 files changed, 1483 insertions(+), 312 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index ececf52..45fc3af 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.lops.CSVReBlock;
@@ -91,8 +92,9 @@ public abstract class Hop implements ParseInfo {
* If it is lout, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+ protected FederatedCost _federatedCost = new FederatedCost();
- // Estimated size for the output produced from this Hop
+ // Estimated size for the output produced from this Hop in bytes
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
// Estimated size for the entire operation represented by this Hop
@@ -535,7 +537,7 @@ public abstract class Hop implements ParseInfo {
* only use getMemEstimate(), which gives memory required to store
* all inputs and the output.
*
- * @return output size memory estimate
+ * @return output size memory estimate in bytes
*/
protected double getOutputSize() {
return _outputMemEstimate;
@@ -545,14 +547,22 @@ public abstract class Hop implements ParseInfo {
return getInputSize(null);
}
- protected double getInputSize(Collection<String> exclVars) {
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in
bytes.
+ * @param exclVars name of input hops to exclude from the input estimate
+ * @param injectedDefault default memory estimate (bytes) used when the
memory estimate of the input is negative
+ * @return input memory estimate in bytes
+ */
+ protected double getInputSize(Collection<String> exclVars, double
injectedDefault){
double sum = 0;
int len = _input.size();
for( int i=0; i<len; i++ ) { //for all inputs
Hop hi = _input.get(i);
if( exclVars != null && exclVars.contains(hi.getName())
)
continue;
- double hmout = hi.getOutputMemEstimate();
+ double hmout = hi.getOutputMemEstimate(injectedDefault);
+ if (hmout < 0)
+ hmout =
injectedDefault*(Math.max(hi.getDim1(),1) * Math.max(hi.getDim2(),1));
if( hmout > 1024*1024 ) {//for relevant sizes
//check if already included in estimate (if an
input is used
//multiple times it is still only required once
in memory)
@@ -564,10 +574,19 @@ public abstract class Hop implements ParseInfo {
}
sum += hmout;
}
-
+
return sum;
}
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in
bytes.
+ * @param exclVars name of input hops to exclude from the input estimate
+ * @return input memory estimate in bytes
+ */
+ protected double getInputSize(Collection<String> exclVars) {
+ return getInputSize(exclVars, OptimizerUtils.INVALID_SIZE);
+ }
+
protected double getInputSize( int pos ){
double ret = 0;
if( _input.size()>pos )
@@ -582,12 +601,11 @@ public abstract class Hop implements ParseInfo {
/**
* NOTES:
* * Purpose: Whenever the output dimensions / sparsity of a hop are
unknown, this hop
- * should store its worst-case output statistics (if known) in that
table. Subsequent
- * hops can then
+ * should store its worst-case output statistics (if known) in that
table.
* * Invocation: Intended to be called for ALL root nodes of one Hops
DAG with the same
* (initially empty) memo table.
*
- * @return memory estimate
+ * @return memory estimate in bytes
*/
public double getMemEstimate() {
if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
@@ -620,17 +638,46 @@ public abstract class Hop implements ParseInfo {
}
//wrappers for meaningful public names to memory estimates.
-
+
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in
bytes.
+ * @return input memory estimate in bytes
+ */
public double getInputMemEstimate()
{
return getInputSize();
}
-
+
+ /**
+ * Get the memory estimate of inputs as the sum of input estimates in
bytes.
+ * @param injectedDefault default memory estimate (bytes) used when the
memory estimate of the input is negative
+ * @return input memory estimate in bytes
+ */
+ public double getInputMemEstimate(double injectedDefault){
+ return getInputSize(null, injectedDefault);
+ }
+
+ /**
+ * Output memory estimate in bytes.
+ * @return output memory estimate in bytes
+ */
public double getOutputMemEstimate()
{
return getOutputSize();
}
+ /**
+ * Output memory estimate in bytes with negative memory estimates
replaced by the injected default.
+ * The injected default represents the memory estimate per output cell,
hence it is multiplied by the estimated
+ * dimensions of the output of the hop.
+ * @param injectedDefault memory estimate to be returned in case the
memory estimate defaults to a negative number
+ * @return output memory estimate in bytes
+ */
+ public double getOutputMemEstimate(double injectedDefault)
+ {
+ return
Math.max(getOutputMemEstimate(),injectedDefault*(Math.max(getDim1(),1) *
Math.max(getDim2(),1)));
+ }
+
public double getIntermediateMemEstimate()
{
return getIntermediateSize();
@@ -823,17 +870,13 @@ public abstract class Hop implements ParseInfo {
* This method only has an effect if FEDERATED_COMPILATION is activated.
*/
protected void updateETFed(){
- if ( _federatedOutput == FederatedOutput.FOUT ||
_federatedOutput == FederatedOutput.LOUT )
+ if ( _federatedOutput.isForced() )
_etype = ExecType.FED;
}
public boolean isFederated(){
return getExecType() == ExecType.FED;
}
-
- public boolean isFederatedOutput(){
- return _federatedOutput == FederatedOutput.FOUT;
- }
public boolean someInputFederated(){
return getInput().stream().anyMatch(Hop::hasFederatedOutput);
@@ -889,6 +932,26 @@ public abstract class Hop implements ParseInfo {
return _federatedOutput == FederatedOutput.FOUT;
}
+ public boolean hasLocalOutput(){
+ return _federatedOutput == FederatedOutput.LOUT;
+ }
+
+ /**
+ * Check if federated cost has been initialized for this Hop.
+ * @return true if federated cost has been initialized
+ */
+ public boolean federatedCostInitialized(){
+ return _federatedCost.getTotal() > 0;
+ }
+
+ public FederatedCost getFederatedCost(){
+ return _federatedCost;
+ }
+
+ public void setFederatedCost(FederatedCost cost){
+ _federatedCost = cost;
+ }
+
public void setUpdateType(UpdateType update){
_updateType = update;
}
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 0b20876..c6cfe9f 100644
---
a/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++
b/src/main/java/org/apache/sysds/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -46,17 +46,8 @@ import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.IndexingOp;
-import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.ParameterizedBuiltinOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
-import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.opt.ReachabilityGraph.SubProblem;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateOuterProduct;
@@ -64,6 +55,7 @@ import org.apache.sysds.hops.codegen.template.TemplateRow;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
import org.apache.sysds.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
@@ -175,8 +167,8 @@ public class PlanSelectionFuseCostBasedV2 extends
PlanSelection
//obtain hop compute costs per cell once
HashMap<Long, Double> computeCosts = new HashMap<>();
for( Long hopID : part.getPartition() )
- getComputeCosts(memo.getHopRefs().get(hopID),
computeCosts);
-
+ computeCosts.put(hopID,
ComputeCost.getHOPComputeCost(memo.getHopRefs().get(hopID)));
+
//prepare pruning helpers and prune memo table w/
determined mat points
StaticCosts costs = new StaticCosts(computeCosts,
sumComputeCost(computeCosts),
getReadCost(part, memo),
getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
@@ -1011,174 +1003,6 @@ public class PlanSelectionFuseCostBasedV2 extends
PlanSelection
return costs;
}
- private static void getComputeCosts(Hop current, HashMap<Long, Double>
computeCosts)
- {
- //get costs for given hop
- double costs = 1;
- if( current instanceof UnaryOp ) {
- switch( ((UnaryOp)current).getOp() ) {
- case ABS:
- case ROUND:
- case CEIL:
- case FLOOR:
- case SIGN: costs = 1; break;
- case SPROP:
- case SQRT: costs = 2; break;
- case EXP: costs = 18; break;
- case SIGMOID: costs = 21; break;
- case LOG:
- case LOG_NZ: costs = 32; break;
- case NCOL:
- case NROW:
- case PRINT:
- case ASSERT:
- case CAST_AS_BOOLEAN:
- case CAST_AS_DOUBLE:
- case CAST_AS_INT:
- case CAST_AS_MATRIX:
- case CAST_AS_SCALAR: costs = 1; break;
- case SIN: costs = 18; break;
- case COS: costs = 22; break;
- case TAN: costs = 42; break;
- case ASIN: costs = 93; break;
- case ACOS: costs = 103; break;
- case ATAN: costs = 40; break;
- case SINH: costs = 93; break; // TODO:
- case COSH: costs = 103; break;
- case TANH: costs = 40; break;
- case CUMSUM:
- case CUMMIN:
- case CUMMAX:
- case CUMPROD: costs = 1; break;
- case CUMSUMPROD: costs = 2; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for:
"+((UnaryOp)current).getOp());
- }
- }
- else if( current instanceof BinaryOp ) {
- switch( ((BinaryOp)current).getOp() ) {
- case MULT:
- case PLUS:
- case MINUS:
- case MIN:
- case MAX:
- case AND:
- case OR:
- case EQUAL:
- case NOTEQUAL:
- case LESS:
- case LESSEQUAL:
- case GREATER:
- case GREATEREQUAL:
- case CBIND:
- case RBIND: costs = 1; break;
- case INTDIV: costs = 6; break;
- case MODULUS: costs = 8; break;
- case DIV: costs = 22; break;
- case LOG:
- case LOG_NZ: costs = 32; break;
- case POW: costs =
(HopRewriteUtils.isLiteralOfValue(
- current.getInput().get(1), 2) ?
1 : 16); break;
- case MINUS_NZ:
- case MINUS1_MULT: costs = 2; break;
- case MOMENT:
- int type = (int)
(current.getInput().get(1) instanceof LiteralOp ?
-
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
- switch( type ) {
- case 0: costs = 1; break;
//count
- case 1: costs = 8; break; //mean
- case 2: costs = 16; break; //cm2
- case 3: costs = 31; break; //cm3
- case 4: costs = 51; break; //cm4
- case 5: costs = 16; break;
//variance
- }
- break;
- case COV: costs = 23; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for:
"+((BinaryOp)current).getOp());
- }
- }
- else if( current instanceof TernaryOp ) {
- switch( ((TernaryOp)current).getOp() ) {
- case IFELSE:
- case PLUS_MULT:
- case MINUS_MULT: costs = 2; break;
- case CTABLE: costs = 3; break;
- case MOMENT:
- int type = (int)
(current.getInput().get(1) instanceof LiteralOp ?
-
HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
- switch( type ) {
- case 0: costs = 2; break;
//count
- case 1: costs = 9; break; //mean
- case 2: costs = 17; break; //cm2
- case 3: costs = 32; break; //cm3
- case 4: costs = 52; break; //cm4
- case 5: costs = 17; break;
//variance
- }
- break;
- case COV: costs = 23; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for:
"+((TernaryOp)current).getOp());
- }
- }
- else if( current instanceof NaryOp ) {
- costs = HopRewriteUtils.isNary(current, OpOpN.MIN,
OpOpN.MAX, OpOpN.PLUS) ?
- current.getInput().size() : 1;
- }
- else if( current instanceof ParameterizedBuiltinOp ) {
- costs = 1;
- }
- else if( current instanceof IndexingOp ) {
- costs = 1;
- }
- else if( current instanceof ReorgOp ) {
- costs = 1;
- }
- else if( current instanceof DnnOp ) {
- switch( ((DnnOp)current).getOp() ) {
- case BIASADD:
- case BIASMULT:
- costs = 2;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for:
"+((DnnOp)current).getOp());
- }
- }
- else if( current instanceof AggBinaryOp ) {
- //outer product template w/ matrix-matrix
- //or row template w/ matrix-vector or matrix-matrix
- costs = 2 * current.getInput().get(0).getDim2();
- if( current.getInput().get(0).dimsKnown(true) )
- costs *=
current.getInput().get(0).getSparsity();
- }
- else if( current instanceof AggUnaryOp) {
- switch(((AggUnaryOp)current).getOp()) {
- case SUM: costs = 4; break;
- case SUM_SQ: costs = 5; break;
- case MIN:
- case MAX: costs = 1; break;
- default:
- LOG.warn("Cost model not "
- + "implemented yet for:
"+((AggUnaryOp)current).getOp());
- }
- switch(((AggUnaryOp)current).getDirection()) {
- case Col: costs *=
Math.max(current.getInput().get(0).getDim1(),1); break;
- case Row: costs *=
Math.max(current.getInput().get(0).getDim2(),1); break;
- case RowCol: costs *=
getSize(current.getInput().get(0)); break;
- }
- }
-
- //scale by current output size in order to correctly reflect
- //a mix of row and cell operations in the same fused operator
- //(e.g., row template with fused column vector operations)
- costs *= getSize(current);
-
- computeCosts.put(current.getHopID(), costs);
- }
-
private static boolean hasNoRefToMatPoint(long hopID,
MemoTableEntry me, InterestingPoint[] M, boolean[]
plan) {
return !InterestingPoint.isMatPoint(M, hopID, me, plan);
diff --git a/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
new file mode 100644
index 0000000..3ac64b6
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/ComputeCost.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DnnOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+
+/**
+ * Class with methods estimating compute costs of operations.
+ */
+public class ComputeCost {
+ private static final Log LOG =
LogFactory.getLog(ComputeCost.class.getName());
+
+ /**
+ * Get compute cost for given HOP based on the number of floating point
operations per output cell
+ * and the total number of output cells.
+ * @param currentHop for which compute cost is returned
+ * @return compute cost of currentHop as number of floating point
operations
+ */
+ public static double getHOPComputeCost(Hop currentHop){
+ double costs = 1;
+ if( currentHop instanceof UnaryOp) {
+ switch( ((UnaryOp)currentHop).getOp() ) {
+ case ABS:
+ case ROUND:
+ case CEIL:
+ case FLOOR:
+ case SIGN: costs = 1; break;
+ case SPROP:
+ case SQRT: costs = 2; break;
+ case EXP: costs = 18; break;
+ case SIGMOID: costs = 21; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ case NCOL:
+ case NROW:
+ case PRINT:
+ case ASSERT:
+ case CAST_AS_BOOLEAN:
+ case CAST_AS_DOUBLE:
+ case CAST_AS_INT:
+ case CAST_AS_MATRIX:
+ case CAST_AS_SCALAR: costs = 1; break;
+ case SIN: costs = 18; break;
+ case COS: costs = 22; break;
+ case TAN: costs = 42; break;
+ case ASIN: costs = 93; break;
+ case ACOS: costs = 103; break;
+ case ATAN: costs = 40; break;
+ case SINH: costs = 93; break; // TODO:
+ case COSH: costs = 103; break;
+ case TANH: costs = 40; break;
+ case CUMSUM:
+ case CUMMIN:
+ case CUMMAX:
+ case CUMPROD: costs = 1; break;
+ case CUMSUMPROD: costs = 2; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for:
"+((UnaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof BinaryOp) {
+ switch( ((BinaryOp)currentHop).getOp() ) {
+ case MULT:
+ case PLUS:
+ case MINUS:
+ case MIN:
+ case MAX:
+ case AND:
+ case OR:
+ case EQUAL:
+ case NOTEQUAL:
+ case LESS:
+ case LESSEQUAL:
+ case GREATER:
+ case GREATEREQUAL:
+ case CBIND:
+ case RBIND: costs = 1; break;
+ case INTDIV: costs = 6; break;
+ case MODULUS: costs = 8; break;
+ case DIV: costs = 22; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ case POW: costs =
(HopRewriteUtils.isLiteralOfValue(
+ currentHop.getInput().get(1), 2) ? 1 :
16); break;
+ case MINUS_NZ:
+ case MINUS1_MULT: costs = 2; break;
+ case MOMENT:
+ int type = (int)
(currentHop.getInput().get(1) instanceof LiteralOp ?
+
HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 1; break;
//count
+ case 1: costs = 8; break; //mean
+ case 2: costs = 16; break; //cm2
+ case 3: costs = 31; break; //cm3
+ case 4: costs = 51; break; //cm4
+ case 5: costs = 16; break;
//variance
+ }
+ break;
+ case COV: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for:
"+((BinaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof TernaryOp) {
+ switch( ((TernaryOp)currentHop).getOp() ) {
+ case IFELSE:
+ case PLUS_MULT:
+ case MINUS_MULT: costs = 2; break;
+ case CTABLE: costs = 3; break;
+ case MOMENT:
+ int type = (int)
(currentHop.getInput().get(1) instanceof LiteralOp ?
+
HopRewriteUtils.getIntValueSafe((LiteralOp)currentHop.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 2; break;
//count
+ case 1: costs = 9; break; //mean
+ case 2: costs = 17; break; //cm2
+ case 3: costs = 32; break; //cm3
+ case 4: costs = 52; break; //cm4
+ case 5: costs = 17; break;
//variance
+ }
+ break;
+ case COV: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for:
"+((TernaryOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof NaryOp) {
+ costs = HopRewriteUtils.isNary(currentHop,
Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) ?
+ currentHop.getInput().size() : 1;
+ }
+ else if( currentHop instanceof ParameterizedBuiltinOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof IndexingOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof ReorgOp) {
+ costs = 1;
+ }
+ else if( currentHop instanceof DnnOp) {
+ switch( ((DnnOp)currentHop).getOp() ) {
+ case BIASADD:
+ case BIASMULT:
+ costs = 2;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for:
"+((DnnOp)currentHop).getOp());
+ }
+ }
+ else if( currentHop instanceof AggBinaryOp) {
+ //outer product template w/ matrix-matrix
+ //or row template w/ matrix-vector or matrix-matrix
+ costs = 2 * currentHop.getInput().get(0).getDim2();
+ if( currentHop.getInput().get(0).dimsKnown(true) )
+ costs *=
currentHop.getInput().get(0).getSparsity();
+ }
+ else if( currentHop instanceof AggUnaryOp) {
+ switch(((AggUnaryOp)currentHop).getOp()) {
+ case SUM: costs = 4; break;
+ case SUM_SQ: costs = 5; break;
+ case MIN:
+ case MAX: costs = 1; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for:
"+((AggUnaryOp)currentHop).getOp());
+ }
+ switch(((AggUnaryOp)currentHop).getDirection()) {
+ case Col: costs *=
Math.max(currentHop.getInput().get(0).getDim1(),1); break;
+ case Row: costs *=
Math.max(currentHop.getInput().get(0).getDim2(),1); break;
+ case RowCol: costs *=
getSize(currentHop.getInput().get(0)); break;
+ }
+ }
+
+ //scale by current output size in order to correctly reflect
+ //a mix of row and cell operations in the same fused operator
+ //(e.g., row template with fused column vector operations)
+ costs *= getSize(currentHop);
+ return costs;
+ }
+
+ /**
+ * Get number of output cells of given hop.
+ * @param hop for which the number of output cells are found
+ * @return number of output cells of given hop
+ */
+ private static long getSize(Hop hop) {
+ return Math.max(hop.getDim1(),1)
+ * Math.max(hop.getDim2(),1);
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
index f8d6a2d..23fcf51 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimationWrapper.java
@@ -32,7 +32,6 @@ import
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
public class CostEstimationWrapper
{
-
public enum CostType {
NUM_MRJOBS, //based on number of MR jobs, [number MR jobs]
STATIC // based on FLOPS, read/write, etc, [time in sec]
@@ -44,17 +43,13 @@ public class CostEstimationWrapper
private static CostEstimator _costEstim = null;
- static
- {
-
+ static {
//create cost estimator
- try
- {
+ try {
//TODO config parameter?
_costEstim = createCostEstimator(DEFAULT_COSTTYPE);
}
- catch(Exception ex)
- {
+ catch(Exception ex) {
LOG.error("Failed cost estimator initialization.", ex);
}
}
@@ -89,5 +84,5 @@ public class CostEstimationWrapper
default:
throw new DMLRuntimeException("Unknown cost
type: "+type);
}
- }
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index bb8753c..03948d4 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -58,9 +58,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
-public abstract class CostEstimator
+public abstract class CostEstimator
{
-
protected static final Log LOG =
LogFactory.getLog(CostEstimator.class.getName());
private static final int DEFAULT_NUMITER = 15;
@@ -84,7 +83,7 @@ public abstract class CostEstimator
public double getTimeEstimate(ProgramBlock pb, LocalVariableMap vars,
HashMap<String,VarStats> stats, boolean recursive) {
//obtain stats from symboltable (e.g., during recompile)
maintainVariableStatistics(vars, stats);
-
+
//get cost estimate
return rGetTimeEstimate(pb, stats, new HashSet<String>(),
recursive);
}
@@ -281,10 +280,8 @@ public abstract class CostEstimator
VarStats[] vs = new VarStats[3];
String[] attr = null;
- if( inst instanceof UnaryCPInstruction )
- {
- if( inst instanceof DataGenCPInstruction )
- {
+ if( inst instanceof UnaryCPInstruction ) {
+ if( inst instanceof DataGenCPInstruction ) {
DataGenCPInstruction rinst =
(DataGenCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
@@ -298,15 +295,13 @@ public abstract class CostEstimator
type = 1;
attr = new String[]{String.valueOf(type)};
}
- else if( inst instanceof StringInitCPInstruction )
- {
+ else if( inst instanceof StringInitCPInstruction ) {
StringInitCPInstruction rinst =
(StringInitCPInstruction) inst;
vs[0] = _unknownStats;
vs[1] = _unknownStats;
vs[2] = stats.get( rinst.output.getName() );
}
- else //general unary
- {
+ else { //general unary
UnaryCPInstruction uinst = (UnaryCPInstruction)
inst;
vs[0] = stats.get( uinst.input1.getName() );
vs[1] = _unknownStats;
@@ -317,69 +312,61 @@ public abstract class CostEstimator
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
- if( inst instanceof MMTSJCPInstruction )
- {
+ if( inst instanceof MMTSJCPInstruction ) {
String type =
((MMTSJCPInstruction)inst).getMMTSJType().toString();
attr = new String[]{type};
}
- else if( inst instanceof
AggregateUnaryCPInstruction )
- {
+ else if( inst instanceof
AggregateUnaryCPInstruction ) {
String[] parts =
InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
if( opcode.equals("cm") )
- attr = new
String[]{parts[parts.length-2]};
- }
+ attr = new
String[]{parts[parts.length-2]};
+ }
}
}
- else if( inst instanceof BinaryCPInstruction )
- {
+ else if( inst instanceof BinaryCPInstruction ) {
BinaryCPInstruction binst = (BinaryCPInstruction) inst;
vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
-
- if( vs[0] == null ) //scalar input,
+ if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
- if( vs[1] == null ) //scalar input,
+ if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
- }
- else if( inst instanceof AggregateTernaryCPInstruction )
- {
+ }
+ else if( inst instanceof AggregateTernaryCPInstruction ) {
AggregateTernaryCPInstruction binst =
(AggregateTernaryCPInstruction) inst;
//of same dimension anyway but missing third input
- vs[0] = stats.get( binst.input1.getName() );
+ vs[0] = stats.get( binst.input1.getName() );
vs[1] = stats.get( binst.input2.getName() );
vs[2] = stats.get( binst.output.getName() );
- if( vs[0] == null ) //scalar input,
+ if( vs[0] == null ) //scalar input,
vs[0] = _scalarStats;
- if( vs[1] == null ) //scalar input,
+ if( vs[1] == null ) //scalar input,
vs[1] = _scalarStats;
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
- else if( inst instanceof ParameterizedBuiltinCPInstruction )
- {
+ else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
//ParameterizedBuiltinCPInstruction pinst =
(ParameterizedBuiltinCPInstruction) inst;
String[] parts =
InstructionUtils.getInstructionParts(inst.toString());
String opcode = parts[0];
- if( opcode.equals("groupedagg") )
- {
+ if( opcode.equals("groupedagg") ) {
HashMap<String,String> paramsMap =
ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
String fn = paramsMap.get("fn");
String order = paramsMap.get("order");
AggregateOperationTypes type =
CMOperator.getAggOpType(fn, order);
attr = new
String[]{String.valueOf(type.ordinal())};
}
- else if( opcode.equals("rmempty") )
- {
+ else if( opcode.equals("rmempty") ) {
HashMap<String,String> paramsMap =
ParameterizedBuiltinCPInstruction.constructParameterMap(parts);
attr = new
String[]{String.valueOf(paramsMap.get("margin").equals("rows")?0:1)};
}
-
+
vs[0] = stats.get(
parts[1].substring(7).replaceAll(Lop.VARIABLE_NAME_PLACEHOLDER, "") );
vs[1] = _unknownStats; //TODO
vs[2] = stats.get( parts[parts.length-1] );
@@ -389,16 +376,14 @@ public abstract class CostEstimator
if( vs[2] == null ) //scalar output
vs[2] = _scalarStats;
}
- else if( inst instanceof MultiReturnBuiltinCPInstruction )
- {
+ else if( inst instanceof MultiReturnBuiltinCPInstruction ) {
//applies to qr, lu, eigen (cost computation on input1)
MultiReturnBuiltinCPInstruction minst =
(MultiReturnBuiltinCPInstruction) inst;
vs[0] = stats.get( minst.input1.getName() );
vs[1] = stats.get( minst.getOutput(0).getName() );
vs[2] = stats.get( minst.getOutput(1).getName() );
}
- else if( inst instanceof VariableCPInstruction )
- {
+ else if( inst instanceof VariableCPInstruction ) {
setUnknownStats(vs);
VariableCPInstruction varinst = (VariableCPInstruction)
inst;
@@ -407,11 +392,10 @@ public abstract class CostEstimator
if( stats.containsKey(
varinst.getInput1().getName() ) )
vs[0] = stats.get(
varinst.getInput1().getName() );
attr = new
String[]{varinst.getInput3().getName()};
- }
+ }
}
- else
- {
- setUnknownStats(vs);
+ else {
+ setUnknownStats(vs);
}
//maintain var status (CP output always inmem)
@@ -426,7 +410,7 @@ public abstract class CostEstimator
private static void setUnknownStats(VarStats[] vs) {
vs[0] = _unknownStats;
vs[1] = _unknownStats;
- vs[2] = _unknownStats;
+ vs[2] = _unknownStats;
}
private static long getNumIterations(HashMap<String,VarStats> stats,
ForProgramBlock pb) {
diff --git
a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
index e2a4a75..3f29132 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java
@@ -344,30 +344,30 @@ public class CostEstimatorStaticRuntime extends
CostEstimator
}
return (leftSparse) ? xcm *
(d1m * d1s + 1) : xcm * d1m;
}
- else if( optype.equals("uatrace") ||
optype.equals("uaktrace") )
- return 2 * d1m * d1n;
- else if( optype.equals("ua+") ||
optype.equals("uar+") || optype.equals("uac+") ){
- //sparse safe operations
- if( !leftSparse ) //dense
- return d1m * d1n;
- else //sparse
- return d1m * d1n * d1s;
- }
- else if( optype.equals("uak+") ||
optype.equals("uark+") || optype.equals("uack+"))
- return 4 * d1m * d1n; //1*k+
- else if( optype.equals("uasqk+") ||
optype.equals("uarsqk+") || optype.equals("uacsqk+"))
+ else if( optype.equals("uatrace") ||
optype.equals("uaktrace") )
+ return 2 * d1m * d1n;
+ else if( optype.equals("ua+") ||
optype.equals("uar+") || optype.equals("uac+") ){
+ //sparse safe operations
+ if( !leftSparse ) //dense
+ return d1m * d1n;
+ else //sparse
+ return d1m * d1n * d1s;
+ }
+ else if( optype.equals("uak+") ||
optype.equals("uark+") || optype.equals("uack+"))
+ return 4 * d1m * d1n; //1*k+
+ else if( optype.equals("uasqk+") ||
optype.equals("uarsqk+") || optype.equals("uacsqk+"))
return 5 * d1m * d1n; // +1 for
multiplication to square term
- else if( optype.equals("uamean") ||
optype.equals("uarmean") || optype.equals("uacmean"))
+ else if( optype.equals("uamean") ||
optype.equals("uarmean") || optype.equals("uacmean"))
return 7 * d1m * d1n; //1*k+
- else if( optype.equals("uavar") ||
optype.equals("uarvar") || optype.equals("uacvar"))
+ else if( optype.equals("uavar") ||
optype.equals("uarvar") || optype.equals("uacvar"))
return 14 * d1m * d1n;
- else if( optype.equals("uamax") ||
optype.equals("uarmax") || optype.equals("uacmax")
- || optype.equals("uamin") ||
optype.equals("uarmin") || optype.equals("uacmin")
- || optype.equals("uarimax") ||
optype.equals("ua*") )
- return d1m * d1n;
+ else if( optype.equals("uamax") ||
optype.equals("uarmax") || optype.equals("uacmax")
+ || optype.equals("uamin") ||
optype.equals("uarmin") || optype.equals("uacmin")
+ || optype.equals("uarimax") ||
optype.equals("ua*") )
+ return d1m * d1n;
- return 0;
-
+ return 0;
+
case Binary: //opcodes: +, -, *, /, ^ (incl.
^2, *2),
//max, min, solve, ==, !=, <, >, <=, >=
//note: all relational ops are not
sparsesafe
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
new file mode 100644
index 0000000..f4f8db4
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+/**
+ * Execution cost estimate of a federated execution, split into the categories compute, read,
+ * input transfer, output transfer, and accumulated input cost.
+ */
+public class FederatedCost {
+	protected double _computeCost = 0;
+	protected double _readCost = 0;
+	protected double _inputTransferCost = 0;
+	protected double _outputTransferCost = 0;
+	protected double _inputTotalCost = 0;
+
+	/**
+	 * Create a cost object in which every category is zero.
+	 */
+	public FederatedCost(){}
+
+	/**
+	 * Create a cost object with the given value for each category.
+	 * @param readCost cost of reading the input
+	 * @param inputTransferCost cost of transferring the inputs
+	 * @param outputTransferCost cost of transferring the output
+	 * @param computeCost cost of the computation itself
+	 * @param inputTotalCost accumulated cost of the inputs
+	 */
+	public FederatedCost(double readCost, double inputTransferCost, double outputTransferCost,
+		double computeCost, double inputTotalCost){
+		_computeCost = computeCost;
+		_inputTotalCost = inputTotalCost;
+		_readCost = readCost;
+		_inputTransferCost = inputTransferCost;
+		_outputTransferCost = outputTransferCost;
+	}
+
+	/**
+	 * Sum of all cost categories held by this object.
+	 * @return total cost
+	 */
+	public double getTotal(){
+		// accumulate in a fixed order: compute, read, input transfer, output transfer, inputs
+		double total = _computeCost;
+		total += _readCost;
+		total += _inputTransferCost;
+		total += _outputTransferCost;
+		total += _inputTotalCost;
+		return total;
+	}
+
+	/**
+	 * Scale the accumulated input cost by the number of repetitions.
+	 * The other cost categories are left unchanged.
+	 * @param repetitionNumber number of repetitions of the costs
+	 */
+	public void addRepetitionCost(int repetitionNumber){
+		_inputTotalCost = _inputTotalCost * repetitionNumber;
+	}
+
+	/**
+	 * Accumulated input cost.
+	 * @return summed input costs
+	 */
+	public double getInputTotalCost(){
+		return _inputTotalCost;
+	}
+
+	/**
+	 * Overwrite the accumulated input cost.
+	 * @param inputTotalCost new accumulated input cost
+	 */
+	public void setInputTotalCost(double inputTotalCost){
+		_inputTotalCost = inputTotalCost;
+	}
+
+	/**
+	 * Increase the accumulated input cost by the given amount.
+	 * @param additionalCost to add to total input cost
+	 */
+	public void addInputTotalCost(double additionalCost){
+		_inputTotalCost += additionalCost;
+	}
+
+	/**
+	 * Increase the accumulated input cost by the total of another cost object.
+	 * @param federatedCost cost object whose total is added
+	 */
+	public void addInputTotalCost(FederatedCost federatedCost){
+		addInputTotalCost(federatedCost.getTotal());
+	}
+
+	/**
+	 * Add every category of another cost object to the matching category of this object.
+	 * @param additionalCost object to add to this object
+	 */
+	public void addFederatedCost(FederatedCost additionalCost){
+		_computeCost += additionalCost._computeCost;
+		_readCost += additionalCost._readCost;
+		_inputTransferCost += additionalCost._inputTransferCost;
+		_outputTransferCost += additionalCost._outputTransferCost;
+		_inputTotalCost += additionalCost._inputTotalCost;
+	}
+
+	@Override
+	public String toString(){
+		return " computeCost: " + _computeCost
+			+ "\n readCost: " + _readCost
+			+ "\n inputTransferCost: " + _inputTransferCost
+			+ "\n outputTransferCost: " + _outputTransferCost
+			+ "\n inputTotalCost: " + _inputTotalCost
+			+ "\n total cost: " + getTotal();
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
new file mode 100644
index 0000000..3e2f994
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+
+/**
+ * Cost estimator for federated executions with methods and constants for going through DML
+ * programs to estimate costs.
+ * <p>
+ * Per hop, memory estimates are divided by the byte-per-second bandwidth constants and compute
+ * costs by the FLOPS constant below. The resulting estimates are therefore time-like quantities,
+ * not byte counts.
+ */
+public class FederatedCostEstimator {
+	public int DEFAULT_MEMORY_ESTIMATE = 8;
+	public int DEFAULT_ITERATION_NUMBER = 15;
+	public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; //Default network bandwidth in bytes per second
+	public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+	public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of parallel processes for workers
+	public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; //Default read bandwidth in bytes per second
+
+	public boolean printCosts = false; //Temporary for debugging purposes
+
+	/**
+	 * Estimate the cost of the given DML program as the sum of the totals of its statement blocks.
+	 * @param dmlProgram for which the cost is estimated
+	 * @return federated cost object holding the program estimate as input total cost
+	 */
+	public FederatedCost costEstimate(DMLProgram dmlProgram){
+		FederatedCost programTotalCost = new FederatedCost();
+		for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() )
+			programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
+		return programTotalCost;
+	}
+
+	/**
+	 * Cost estimate of the given statement block, recursing into the bodies of while, if, for
+	 * (including parfor) and function statement blocks.
+	 * @param sb statement block
+	 * @return federated cost object with the cost estimate of the block
+	 */
+	private FederatedCost costEstimate(StatementBlock sb){
+		if ( sb instanceof WhileStatementBlock){
+			WhileStatementBlock whileSB = (WhileStatementBlock) sb;
+			FederatedCost whileSBCost = costEstimate(whileSB.getPredicateHops());
+			for ( Statement statement : whileSB.getStatements() ){
+				WhileStatement whileStatement = (WhileStatement) statement;
+				for ( StatementBlock bodyBlock : whileStatement.getBody() )
+					whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
+			}
+			//while loops use a fixed default number of iterations
+			whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
+			return whileSBCost;
+		}
+		else if ( sb instanceof IfStatementBlock){
+			//Get cost of if-block + else-block and divide by two
+			// since only one of the code blocks will be executed in the end
+			IfStatementBlock ifSB = (IfStatementBlock) sb;
+			FederatedCost ifSBCost = new FederatedCost();
+			for ( Statement statement : ifSB.getStatements() ){
+				IfStatement ifStatement = (IfStatement) statement;
+				for ( StatementBlock ifBodySB : ifStatement.getIfBody() )
+					ifSBCost.addInputTotalCost(costEstimate(ifBodySB));
+				for ( StatementBlock elseBodySB : ifStatement.getElseBody() )
+					ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
+			}
+			ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
+			//the predicate is always evaluated, so its cost is added after halving
+			ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
+			return ifSBCost;
+		}
+		else if ( sb instanceof ForStatementBlock){
+			// This also includes ParForStatementBlocks
+			ForStatementBlock forSB = (ForStatementBlock) sb;
+			ArrayList<Hop> predicateHops = new ArrayList<>();
+			predicateHops.add(forSB.getFromHops());
+			predicateHops.add(forSB.getToHops());
+			//may be null; a null hop contributes zero cost in costEstimate(Hop)
+			predicateHops.add(forSB.getIncrementHops());
+			FederatedCost forSBCost = costEstimate(predicateHops);
+			for ( Statement statement : forSB.getStatements() ){
+				ForStatement forStatement = (ForStatement) statement;
+				for ( StatementBlock forStatementBlockBody : forStatement.getBody() )
+					forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
+			}
+			forSBCost.addRepetitionCost(forSB.getEstimateReps());
+			return forSBCost;
+		}
+		else if ( sb instanceof FunctionStatementBlock){
+			//NOTE(review): the initial cost is that of all top-level statement blocks of the
+			//enclosing DML program (see addInitialInputCost) -- confirm this is intended
+			FederatedCost funcCost = addInitialInputCost(sb);
+			FunctionStatementBlock funcSB = (FunctionStatementBlock) sb;
+			for(Statement statement : funcSB.getStatements()) {
+				FunctionStatement funcStatement = (FunctionStatement) statement;
+				for ( StatementBlock funcStatementBody : funcStatement.getBody() )
+					funcCost.addInputTotalCost(costEstimate(funcStatementBody));
+			}
+			return funcCost;
+		}
+		else {
+			// StatementBlock type (no subclass)
+			return costEstimate(sb.getHops());
+		}
+	}
+
+	/**
+	 * Creates new FederatedCost object and adds the totals of all top-level statement blocks
+	 * of the DML program that sb belongs to.
+	 * @param sb statement block
+	 * @return new FederatedCost estimate object with the totals of those statement blocks added
+	 */
+	private FederatedCost addInitialInputCost(StatementBlock sb){
+		FederatedCost basicCost = new FederatedCost();
+		for ( StatementBlock childSB : sb.getDMLProg().getStatementBlocks() )
+			basicCost.addInputTotalCost(costEstimate(childSB).getTotal());
+		return basicCost;
+	}
+
+	/**
+	 * Cost estimate of given list of roots.
+	 * The individual cost estimates of the hops are summed.
+	 * @param roots list of hops; a null list or null entries contribute zero cost
+	 * @return new FederatedCost object with sum of cost estimates of given hops
+	 */
+	private FederatedCost costEstimate(ArrayList<Hop> roots){
+		FederatedCost basicCost = new FederatedCost();
+		if ( roots == null )
+			return basicCost;
+		for ( Hop root : roots )
+			basicCost.addInputTotalCost(costEstimate(root));
+		return basicCost;
+	}
+
+	/**
+	 * Return cost estimate of Hop DAG starting from given root.
+	 * The estimate is cached on the hop, so repeated calls return the stored object.
+	 * @param root of Hop DAG for which cost is estimated; null yields a zero cost
+	 * @return cost estimation of Hop DAG starting from given root
+	 */
+	private FederatedCost costEstimate(Hop root){
+		if ( root == null )
+			return new FederatedCost();
+		if ( root.federatedCostInitialized() )
+			return root.getFederatedCost();
+
+		// If no input has FOUT, the root will be processed by the coordinator
+		boolean hasFederatedInput = root.someInputFederated();
+		//the input cost is included the first time the input hop is used
+		//for additional usage, the additional cost is zero (disregarding potential read cost)
+		double inputCosts = root.getInput().stream()
+			.mapToDouble( in -> in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
+			.sum();
+		//inputs for which hasLocalOutput() holds have to be sent over the network
+		double inputTransferCost = hasFederatedInput ? root.getInput().stream()
+			.filter(Hop::hasLocalOutput)
+			.mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+			.map(inMem -> inMem/ WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+			.sum() : 0;
+		double computingCost = ComputeCost.getHOPComputeCost(root);
+		if ( hasFederatedInput ){
+			//Find the number of inputs that has FOUT set.
+			int numWorkers = (int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
+			//hasFederatedInput comes from someInputFederated(), a different method than
+			//hasFederatedOutput(); guard against a zero count, which would turn the
+			//compute cost into Infinity or NaN
+			numWorkers = Math.max(numWorkers, 1);
+			//divide memory usage by the number of workers the computation would be split to multiplied by
+			//the number of parallel processes at each worker multiplied by the FLOPS of each process
+			//This assumes uniform workload among the workers with FOUT data involved in the operation
+			//and assumes that the degree of parallelism and compute bandwidth are equal for all workers
+			computingCost = computingCost / (numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+		}
+		else
+			computingCost = computingCost / (WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+		//Calculate output transfer cost if the operation is computed at federated workers and the output is forced to the coordinator
+		double outputTransferCost = ( root.hasLocalOutput() && hasFederatedInput ) ?
+			root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+		double readCost = root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
+
+		FederatedCost rootFedCost =
+			new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost, inputCosts);
+		root.setFederatedCost(rootFedCost);
+
+		if ( printCosts )
+			printCosts(root);
+
+		return rootFedCost;
+	}
+
+	/**
+	 * Prints costs and information about root for debugging purposes
+	 * @param root hop for which information is printed
+	 */
+	private static void printCosts(Hop root){
+		System.out.println("===============================");
+		System.out.println(root);
+		System.out.println("Is federated: " + root.isFederated());
+		System.out.println("Has federated output: " + root.hasFederatedOutput());
+		System.out.println(root.getText());
+		System.out.println("Pure computeCost: " + ComputeCost.getHOPComputeCost(root));
+		System.out.println("Dim1: " + root.getDim1() + " Dim2: " + root.getDim2());
+		System.out.println(root.getFederatedCost().toString());
+		System.out.println("===============================");
+	}
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 2e3edb0..04cdf32 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,6 +140,7 @@ public class ProgramRewriter
}
if ( OptimizerUtils.FEDERATED_COMPILATION ) {
_dagRuleSet.add( new
RewriteFederatedExecution() );
+ _sbRuleSet.add( new
RewriteFederatedStatementBlocks() );
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 29cda4a..e6a92ce 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,9 +23,14 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -36,6 +41,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -52,6 +58,9 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
+import java.util.Collections;
+import java.util.EnumMap;
+import java.util.Map;
import java.util.concurrent.Future;
public class RewriteFederatedExecution extends HopRewriteRule {
@@ -61,18 +70,115 @@ public class RewriteFederatedExecution extends
HopRewriteRule {
return null;
for ( Hop root : roots )
visitHop(root);
+
+ return selectFederatedExecutionPlan(roots);
+ }
+
+	/**
+	 * Single-root variant of rewriteHopDAGs, used for DAGs with exactly one root such as
+	 * predicate hops. It runs the visitHop pass and then selects the federated execution plan
+	 * for the DAG. The root is returned rather than null, because callers replace their DAG
+	 * with the returned hop.
+	 * @param root root of the hop DAG, may be null
+	 * @param state program rewrite status (not used by this rule)
+	 * @return the given root, or null if root is null
+	 */
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root == null )
+			return null;
+		visitHop(root);
+		//visitHop marks hops as visited; reset before the federated plan traversal
+		root.resetVisitStatus();
+		visitFedPlanHop(root);
+		return root;
+	}
+
+	/**
+	 * Assign a FederatedOutput value to every hop reachable from the given roots.
+	 * The visit flags of all roots are cleared first, and then each root is traversed.
+	 * @param roots starting points of the hop DAG traversal
+	 * @return the same list of roots, whose hops now carry updated FederatedOutput fields
+	 */
+	private static ArrayList<Hop> selectFederatedExecutionPlan(ArrayList<Hop> roots){
+		roots.forEach(rootHop -> rootHop.resetVisitStatus());
+		roots.forEach(rootHop -> visitFedPlanHop(rootHop));
return roots;
}
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if( root == null )
- return null;
- visitHop(root);
- return root;
+	/**
+	 * Depth-first walk of the hop DAG. It assigns a FederatedOutput value to currentHop after
+	 * the inputs of currentHop have been handled. Hops already marked as visited are skipped.
+	 * @param currentHop the Hop from which the DAG is visited
+	 */
+	private static void visitFedPlanHop(Hop currentHop){
+		if ( currentHop.isVisited() )
+			return;
+		if ( isFederatedDataOp(currentHop) ) {
+			// leaf federated node
+			//TODO: This will block the cases where the federated DataOp is based on input that are also federated.
+			// This means that the actual federated leaf nodes will never be reached.
+			currentHop.setFederatedOutput(FederatedOutput.FOUT);
+		}
+		else if ( currentHop.getInput() != null && !currentHop.getInput().isEmpty() ) {
+			// inputs first, so their FederatedOutput is set before this hop is decided
+			for ( Hop inputHop : currentHop.getInput() )
+				visitFedPlanHop(inputHop);
+		}
+		// The Hop can be FOUT or LOUT or None: take the highest-utility option when a
+		// federated instruction is possible, otherwise NONE
+		FederatedOutput selected = isFedInstSupportedHop(currentHop)
+			? getHighestUtilFedOut(currentHop)
+			: FederatedOutput.NONE;
+		currentHop.setFederatedOutput(selected);
+		currentHop.setVisited();
+	}
+
+	/**
+	 * Pick the FederatedOutput value with the highest utility among the values allowed for the hop.
+	 * FOUT is allowed only if isFOUTSupported(hop) holds. LOUT is allowed only when the hop has no
+	 * privacy constraints. NONE is always allowed, with utility 0.
+	 * @param hop for which the utility is found
+	 * @return the FederatedOutput value with highest utility for the given Hop
+	 */
+	private static FederatedOutput getHighestUtilFedOut(Hop hop){
+		Map<FederatedOutput,Long> utilities = new EnumMap<>(FederatedOutput.class);
+		if ( isFOUTSupported(hop) )
+			utilities.put(FederatedOutput.FOUT, getUtilFout());
+		boolean unconstrained = hop.getPrivacy() == null || !hop.getPrivacy().hasConstraints();
+		if ( unconstrained )
+			utilities.put(FederatedOutput.LOUT, getUtilLout(hop));
+		utilities.put(FederatedOutput.NONE, 0L);
+
+		// ties keep the entry that comes first in EnumMap (declaration) order
+		Map.Entry<FederatedOutput, Long> best = Collections.max(utilities.entrySet(), Map.Entry.comparingByValue());
+		return best.getKey();
}
-
- private void visitHop(Hop hop){
+
+	/**
+	 * Utility of choosing FOUT for a hop. This placeholder returns the same value for every hop.
+	 * @return constant utility of 1
+	 */
+	private static long getUtilFout(){
+		//TODO: Make better utility estimation
+		final long constantUtility = 1L;
+		return constantUtility;
+	}
+
+	/**
+	 * Utility of choosing LOUT for a hop. It is the negated memory estimate, truncated to a long,
+	 * so hops with larger memory estimates get a lower utility.
+	 * @param hop for which utility is calculated
+	 * @return utility if hop is LOUT
+	 */
+	private static long getUtilLout(Hop hop){
+		//TODO: Make better utility estimation
+		long memEstimate = (long) hop.getMemEstimate();
+		return -memEstimate;
+	}
+
+	/**
+	 * Check whether a federated instruction can be considered for the given hop.
+	 * Apart from federated DataOps, at least one input must have FOUT. The hop must also be
+	 * one of the supported hop types.
+	 * @param hop hop to check
+	 * @return true if a federated instruction is possible for the hop
+	 */
+	private static boolean isFedInstSupportedHop(Hop hop){
+		// Without a FOUT input none of the fed instructions will run, unless it is fedinit
+		if ( !isFederatedDataOp(hop) && !hop.getInput().stream().anyMatch(Hop::hasFederatedOutput) )
+			return false;
+		// hop types with federated instruction support
+		if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof BinaryOp )
+			return true;
+		return hop instanceof ReorgOp || hop instanceof TernaryOp || hop instanceof DataOp;
+	}
+
+	/**
+	 * Check whether the given hop may have FOUT. Only AggUnaryOps with a scalar output are excluded.
+	 * @param associatedHop for which FOUT support is checked
+	 * @return true if FOUT is supported by the associatedHop
+	 */
+	private static boolean isFOUTSupported(Hop associatedHop){
+		return !(associatedHop instanceof AggUnaryOp) || !associatedHop.isScalar();
+	}
+
+ private static void visitHop(Hop hop){
if (hop.isVisited())
return;
@@ -84,15 +190,6 @@ public class RewriteFederatedExecution extends
HopRewriteRule {
hop.setVisited();
}
- private static void privacyBasedHopDecision(Hop hop){
- PrivacyPropagator.hopPropagation(hop);
- PrivacyConstraint privacyConstraint = hop.getPrivacy();
- if ( privacyConstraint != null &&
privacyConstraint.hasConstraints() )
-
hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
- else if ( hop.someInputFederated() )
-
hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
- }
-
/**
* Get privacy constraints of DataOps from federated worker,
* propagate privacy constraints from input to current hop,
@@ -101,7 +198,7 @@ public class RewriteFederatedExecution extends
HopRewriteRule {
*/
private static void privacyBasedHopDecisionWithFedCall(Hop hop){
loadFederatedPrivacyConstraints(hop);
- privacyBasedHopDecision(hop);
+ PrivacyPropagator.hopPropagation(hop);
}
/**
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
new file mode 100644
index 0000000..18b36d5
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Statement block rewrite rule for federated compilation.
+ * In its current form it returns every statement block, and every list of blocks, unchanged.
+ */
+public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule {
+
+	/**
+	 * Pass a list of statement blocks through unchanged; no sequence rewrites are applied.
+	 * @param sbs list of statement blocks
+	 * @param state program rewrite status
+	 * @return the given list itself
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		return sbs;
+	}
+
+	/**
+	 * Pass a single statement block through unchanged.
+	 * @param sb statement block
+	 * @param state program rewrite status
+	 * @return list containing only sb
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		return Arrays.asList(sb);
+	}
+
+	/**
+	 * This rule never splits DAGs, which matters for the phase ordering of rewrites.
+	 * @return always false
+	 */
+	@Override
+	public boolean createsSplitDag() {
+		return false;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index bae38a2..755287a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -64,6 +64,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+ String2FEDInstructionType.put( "rev" , FEDType.Reorg );
// Ternary Instruction Opcodes
String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index fb0647e..2e5366e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -124,9 +125,60 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()}, true);
map.execute(getTID(), fr1);
- // derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
-
out.setFedMapping(in.getFedMapping().copyWithNewID(fr1.getID()));
+ deriveNewOutputFedMapping(in, out, fr1);
+ }
+
+ /**
+ * Set the federation map of the output of a FOUT aggregation ("uack+" or "uark+").
+ * The input federation map is copied with the ID of fr1. The copied ranges and FType are then
+ * adjusted to the output dimensions, depending on whether the input is row- or
+ * column-partitioned and whether the aggregation is column-wise ("uack+") or row-wise ("uark+").
+ * NOTE(review): inFtype is read once before any branch runs. If an FType reports both
+ * isRowPartitioned() and isColPartitioned(), more than one branch below applies and the later
+ * one wins. If it reports neither, the copied mapping is used unchanged. Confirm against
+ * FederationMap.FType.
+ * @param in matrix object whose federation map is copied
+ * @param out matrix object holding the dimensions of the instruction output; receives the new mapping
+ * @param fr1 federated request holding the instruction execution call; its ID becomes the ID of the new mapping
+ * @throws DMLRuntimeException if the opcode is neither "uack+" nor "uark+"
+ */
+ private void deriveNewOutputFedMapping(MatrixObject in, MatrixObject
out, FederatedRequest fr1){
+ //Get agg type: only column sums (uack+) and row sums (uark+) are handled
+ if ( !(instOpcode.equals("uack+") ||
instOpcode.equals("uark+")) )
+ throw new DMLRuntimeException("Operation " + instOpcode
+ " is unknown to FOUT processing");
+ boolean isColAgg = instOpcode.equals("uack+");
+ //Get partition type of the input (read before the copy below is modified)
+ FederationMap.FType inFtype = in.getFedMapping().getType();
+ //Copy the input fed map with the ID of the request that produced the output
+ FederationMap inputFedMapCopy =
in.getFedMapping().copyWithNewID(fr1.getID());
+
+ //Row-partitioned input, row aggregation:
+ // keep the row split of the input, set the column end dimension to the output column count,
+ // and set FType to ROW
+ if ( inFtype.isRowPartitioned() && !isColAgg ){
+ for ( FederatedRange range :
inputFedMapCopy.getFederatedRanges() )
+ range.setEndDim(1,out.getNumColumns());
+ inputFedMapCopy.setType(FederationMap.FType.ROW);
+ }
+ //Row-partitioned input with column aggregation, or column-partitioned input with row aggregation:
+ // every range covers the full output dimensions (each worker holds a partial result of the
+ // whole output), and FType is set to PART
+ if ( (inFtype.isRowPartitioned() && isColAgg) ||
(inFtype.isColPartitioned() && !isColAgg) ){
+ for ( FederatedRange range :
inputFedMapCopy.getFederatedRanges() ){
+ range.setBeginDim(0,0);
+ range.setBeginDim(1,0);
+ range.setEndDim(0,out.getNumRows());
+ range.setEndDim(1,out.getNumColumns());
+ }
+ inputFedMapCopy.setType(FederationMap.FType.PART);
+ }
+ //Column-partitioned input, column aggregation:
+ // keep the column split of the input, set the row end dimension to the output row count,
+ // and set FType to COL
+ if ( inFtype.isColPartitioned() && isColAgg ){
+ for ( FederatedRange range :
inputFedMapCopy.getFederatedRanges() )
+ range.setEndDim(0,out.getNumRows());
+ inputFedMapCopy.setType(FederationMap.FType.COL);
+ }
+
+ //set out fedmap in the end
+ out.setFedMapping(inputFedMapCopy);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 825b984..c2e7ab1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -130,8 +130,9 @@ public class AppendFEDInstruction extends
BinaryFEDInstruction {
}
else {
throw new DMLRuntimeException("Unsupported federated
append: "
- + (mo1.isFederated() ?
mo1.getFedMapping().getType().name():"LOCAL") + " "
- + (mo2.isFederated() ?
mo2.getFedMapping().getType().name():"LOCAL") + " " + _cbind);
+ + " input 1 FType is " + (mo1.isFederated() ?
mo1.getFedMapping().getType().name():"LOCAL")
+ + ", input 2 FType is " + (mo2.isFederated() ?
mo2.getFedMapping().getType().name():"LOCAL")
+ + ", and column bind is " + _cbind);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2a1cbb1..8795308 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -217,10 +217,10 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
* @param mo2 input matrix object mo2
* @return boolean indicating if the output can be kept on the
federated sites
*/
- private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+ private static boolean isFedOutput(FederationMap fedMap, MatrixObject
mo2) {
MatrixBlock mb = mo2.acquireReadAndRelease();
FederatedRange[] fedRanges = fedMap.getFederatedRanges(); //
federated ranges of mo1
- SortedMap<Double, Double> fedDims = new TreeMap<Double,
Double>(); // <beginDim, endDim>
+ SortedMap<Double, Double> fedDims = new TreeMap<>(); //
<beginDim, endDim>
// collect min and max of the corresponding slices of mo2
IntStream.range(0, fedRanges.length).forEach(i -> {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f35030f..0e00faa 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -60,6 +60,9 @@ public abstract class FEDInstruction extends Instruction {
public boolean isForcedLocal() {
return this == LOUT;
}
+ public boolean isForced(){
+ return this == FOUT || this == LOUT;
+ }
}
protected final FEDType _fedType;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 19847a2..cb3074f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.Future;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
@@ -52,7 +53,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
-
+ private static boolean fedoutFlagInString = false;
+
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out,
String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
}
@@ -80,6 +82,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
return new ReorgFEDInstruction(new
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
}
else if ( opcode.equalsIgnoreCase("rev") ) {
+ fedoutFlagInString = parts.length > 3;
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgFEDInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
@@ -96,24 +99,34 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
+ "Federated input expected, but invoked w/
"+mo1.isFederated());
+ if ( !( mo1.isFederated(FederationMap.FType.COL) ||
mo1.isFederated(FederationMap.FType.ROW)) )
+ throw new DMLRuntimeException("Federation type " +
mo1.getFedMapping().getType()
+ + " is not supported for Reorg processing");
if(instOpcode.equals("r'")) {
//execute transpose at federated site
FederatedRequest fr1 =
FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1},
new long[] {mo1.getFedMapping().getID()}, true);
- mo1.getFedMapping().execute(getTID(), true, fr1);
+ if (_fedOut != null && !_fedOut.isForcedLocal()){
+ mo1.getFedMapping().execute(getTID(), true,
fr1);
- //drive output federated mapping
- MatrixObject out = ec.getMatrixObject(output);
- out.getDataCharacteristics().set(mo1.getNumColumns(),
mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
-
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+ //drive output federated mapping
+ MatrixObject out = ec.getMatrixObject(output);
+
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int)
mo1.getBlocksize(), mo1.getNnz());
+
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+ } else {
+ FederatedRequest getRequest = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+ Future<FederatedResponse>[] execResponse =
mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
+ ec.setMatrixOutput(output.getName(),
+ FederationUtils.bind(execResponse,
mo1.isFederated(FederationMap.FType.COL)));
+ }
}
else if(instOpcode.equalsIgnoreCase("rev")) {
//execute transpose at federated site
FederatedRequest fr1 =
FederationUtils.callInstruction(instString,
output, new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()}, true);
+ new long[] {mo1.getFedMapping().getID()},
fedoutFlagInString);
mo1.getFedMapping().execute(getTID(), true, fr1);
if(mo1.isFederated(FederationMap.FType.ROW))
@@ -123,6 +136,11 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(),
mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+
+ if ( _fedOut != null && _fedOut.isForcedLocal() ){
+ out.acquireReadAndRelease();
+ out.getFedMapping().cleanup(getTID(),
fr1.getID());
+ }
}
else if (instOpcode.equals("rdiag")) {
RdiagResult result;
@@ -158,6 +176,10 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
.set(diagFedMap.getMaxIndexInRange(0),
diagFedMap.getMaxIndexInRange(1),
(int) mo1.getBlocksize());
rdiag.setFedMapping(diagFedMap);
+ if ( _fedOut != null && _fedOut.isForcedLocal() ){
+ rdiag.acquireReadAndRelease();
+ rdiag.getFedMapping().cleanup(getTID(),
rdiag.getFedMapping().getID());
+ }
}
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3d96cd9..3ac462c 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
+import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
@@ -2097,6 +2098,31 @@ public abstract class AutomatedTestBase {
return false;
}
+ /**
+ * Checks if given strings are all in the set of heavy hitters.
+ * @param str opcodes for which it is checked if all are in the heavy
hitters
+ * @return true if all given strings are in the set of heavy hitters
+ */
+ protected boolean heavyHittersContainsAllString(String... str){
+ Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+ return Arrays.stream(str).allMatch(heavyHitters::contains);
+ }
+
+ /**
+ * Returns an array of the given opcodes which are not present in the
set of heavy hitter opcodes.
+ * @param opcodes for which it is checked if they are among the heavy
hitters
+ * @return array of opcodes not found in heavy hitters
+ */
+ protected String[] missingHeavyHitters(String... opcodes){
+ Set<String> heavyHitters = Statistics.getCPHeavyHitterOpCodes();
+ List<String> missingHeavyHitters = new ArrayList<>();
+ for (String opcode : opcodes){
+ if ( !heavyHitters.contains(opcode) )
+ missingHeavyHitters.add(opcode);
+ }
+ return missingHeavyHitters.toArray(new String[0]);
+ }
+
protected boolean heavyHittersContainsString(String str, int minCount) {
int count = 0;
for(String opcode : Statistics.getCPHeavyHitterOpCodes())
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000..0092a3a
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.cost.FederatedCost;
+import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.sysds.common.Types.OpOp2.MULT;
+
+public class FederatedCostEstimatorTest extends AutomatedTestBase {
+
+ private static final String TEST_DIR = "functions/privacy/fedplanning/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedCostEstimatorTest.class.getSimpleName() + "/";
+ FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+
+ @Override
+ public void setUp() {}
+
+ @Test
+ public void simpleBinary() {
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+
+ /*
+ * HOP Occurences ComputeCost
ReadCost ComputeCostFinal ReadCostFinal
+ *
------------------------------------------------------------------------------------------
+ * LiteralOp 16 1
0 0.0625 0
+ * DataGenOp 2 100
64 6.25 6.4
+ * BinaryOp 1 100
1600 6.25 160
+ * TOSTRING 1 1
800 0.0625
80
+ * UnaryOp 1 1
8 0.0625
0.8
+ */
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+
+ double expectedCost = computeCost + readCost;
+ runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void ifElseTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 +
0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+ runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void whileTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = (computeCost + readCost + 0.0625) *
fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+ runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void forLoopTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 +
0.0625 + 0.0625 + 0.8 + 0.0625;
+ double expectedCost = (computeCost + readCost + predicateCost)
* 5;
+ runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void parForLoopTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 +
0.0625 + 0.0625 + 0.8 + 0.0625;
+ double expectedCost = (computeCost + readCost + predicateCost)
* 5;
+ runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void functionTest(){
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ double computeCost = (16+2*100+100+1+1) /
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+ double readCost = (2*64+1600+800+8) /
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+ double expectedCost = (computeCost + readCost);
+ runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+ }
+
+ @Test
+ public void federatedMultiply() {
+ fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+ fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+ fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+
+ double literalOpCost = 10*0.0625;
+ double naryOpCostSpecial = (0.125+2.2);
+ double naryOpCostSpecial2 = (0.25+6.4);
+ double naryOpCost = 4*(0.125+1.6);
+ double reorgOpCost = 6250+80015.2+160030.4;
+ double binaryOpMultCost = 3125+160000;
+ double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+ double dataOpCost = 2*(6250+5.6);
+ double dataOpWriteCost = 6.25+100.3;
+
+ double expectedCost = literalOpCost + naryOpCost +
naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+ + binaryOpMultCost + aggBinaryOpCost + dataOpCost +
dataOpWriteCost;
+ runTest("FederatedMultiplyCostEstimatorTest.dml", false,
expectedCost);
+
+ double aggBinaryActualCost = hops.stream()
+ .filter(hop -> hop instanceof AggBinaryOp)
+ .mapToDouble(aggHop ->
aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost())
+ .sum();
+ Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost,
0.0001);
+
+ double writeActualCost = hops.stream()
+ .filter(hop -> hop instanceof DataOp)
+ .mapToDouble(writeHop ->
writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost())
+ .sum();
+ Assert.assertEquals(dataOpWriteCost+dataOpCost,
writeActualCost, 0.0001);
+ }
+
+ Set<Hop> hops = new HashSet<>();
+
+ /**
+ * Recursively adds the hop and its inputs to the set of hops.
+ * @param hop root to be added to set of hops
+ */
+ private void addHop(Hop hop){
+ hops.add(hop);
+ for(Hop inHop : hop.getInput()){
+ addHop(inHop);
+ }
+ }
+
+ /**
+ * Sets dimensions of federated X and Y and sets binary multiplication
to FOUT.
+ * @param prog dml program where the HOPS are modified
+ */
+ private void modifyFedouts(DMLProgram prog){
+ prog.getStatementBlocks().forEach(stmBlock ->
stmBlock.getHops().forEach(this::addHop));
+ hops.forEach(hop -> {
+ if ( hop instanceof DataOp || (hop instanceof BinaryOp
&& ((BinaryOp) hop).getOp() == MULT ) ){
+
hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+ hop.setExecType(Types.ExecType.FED);
+ } else {
+
hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+ }
+ if ( hop.getOpString().equals("Fed Y") ||
hop.getOpString().equals("Fed X") ){
+ hop.setDim1(10000);
+ hop.setDim2(10);
+ }
+ });
+ }
+
+ @SuppressWarnings("unused")
+ private void printHopsInfo(){
+ //LiteralOp
+ long literalCount = hops.stream().filter(hop -> hop instanceof
LiteralOp).count();
+ System.out.println("LiteralOp Count: " + literalCount);
+ //NaryOp
+ long naryCount = hops.stream().filter(hop -> hop instanceof
NaryOp).count();
+ System.out.println("NaryOp Count " + naryCount);
+ //ReorgOp
+ long reorgCount = hops.stream().filter(hop -> hop instanceof
ReorgOp).count();
+ System.out.println("ReorgOp Count: " + reorgCount);
+ //BinaryOp
+ long binaryCount = hops.stream().filter(hop -> hop instanceof
BinaryOp).count();
+ System.out.println("Binary count: " + binaryCount);
+ //AggBinaryOp
+ long aggBinaryCount = hops.stream().filter(hop -> hop
instanceof AggBinaryOp).count();
+ System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+ //DataOp
+ long dataOpCount = hops.stream().filter(hop -> hop instanceof
DataOp).count();
+ System.out.println("DataOp Count: " + dataOpCount);
+
+
hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+ }
+
+ private void runTest( String scriptFilename, boolean expectedException,
double expectedCost ) {
+ boolean raisedException = false;
+ try
+ {
+ setTestConfig(scriptFilename);
+ String dmlScriptString = readScript(scriptFilename);
+
+ //parsing, dependency analysis and constructing hops
(step 3 and 4 in DMLScript.java)
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new
HashMap<>());
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ if (
scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+ modifyFedouts(prog);
+ dmlt.rewriteHopsDAG(prog);
+ hops = new HashSet<>();
+ prog.getStatementBlocks().forEach(stmBlock ->
stmBlock.getHops().forEach(this::addHop));
+ }
+
+ FederatedCost actualCost =
fedCostEstimator.costEstimate(prog);
+ Assert.assertEquals(expectedCost,
actualCost.getTotal(), 0.0001);
+ }
+ catch(LanguageException ex) {
+ raisedException = true;
+ if(raisedException!=expectedException)
+ ex.printStackTrace();
+ }
+ catch(Exception ex2) {
+ ex2.printStackTrace();
+ throw new RuntimeException(ex2);
+ }
+
+ Assert.assertEquals("Expected exception does not match raised
exception",
+ expectedException, raisedException);
+ }
+
+ private void setTestConfig(String scriptFilename) throws
FileNotFoundException {
+ int index = scriptFilename.lastIndexOf(".dml");
+ String testName = scriptFilename.substring(0, index > 0 ? index
: scriptFilename.length());
+ TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+ addTestConfiguration(testName, testConfig);
+ loadTestConfiguration(testConfig);
+
+ DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+ ConfigurationManager.setLocalConfig(conf);
+ }
+
+ private static String readScript(String scriptFilename) throws
IOException {
+ return DMLScript.readDMLScript(true, HOME + scriptFilename);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 342a26b..e0ef884 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -35,7 +35,7 @@ import org.apache.sysds.test.TestUtils;
import java.util.Arrays;
import java.util.Collection;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
@@ -76,34 +76,40 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
@Test
public void federatedMultiplyCP() {
- federatedTwoMatricesSingleNodeTest(TEST_NAME);
+ String[] expectedHeavyHitters = new String[]{"fed_*",
"fed_fedinit", "fed_r'", "fed_ba+*"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME,
expectedHeavyHitters);
}
@Test
public void federatedRowSum(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_2);
+ String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'",
"fed_fedinit", "fed_ba+*", "fed_uark+"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_2,
expectedHeavyHitters);
}
@Test
public void federatedTernarySequence(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_3);
+ String[] expectedHeavyHitters = new String[]{"fed_+*",
"fed_1-*", "fed_fedinit", "fed_uak+"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_3,
expectedHeavyHitters);
}
@Test
public void federatedAggregateBinarySequence(){
cols = rows;
- federatedTwoMatricesSingleNodeTest(TEST_NAME_4);
+ String[] expectedHeavyHitters = new String[]{"fed_ba+*",
"fed_*", "fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_4,
expectedHeavyHitters);
}
@Test
public void federatedAggregateBinaryColFedSequence(){
cols = rows;
- federatedTwoMatricesSingleNodeTest(TEST_NAME_5);
+ String[] expectedHeavyHitters = new
String[]{"fed_ba+*","fed_*","fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_5,
expectedHeavyHitters);
}
@Test
public void federatedAggregateBinarySequence2(){
- federatedTwoMatricesSingleNodeTest(TEST_NAME_6);
+ String[] expectedHeavyHitters = new
String[]{"fed_ba+*","fed_fedinit"};
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_6,
expectedHeavyHitters);
}
private void writeStandardMatrix(String matrixName, long seed){
@@ -166,11 +172,11 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
}
}
- private void federatedTwoMatricesSingleNodeTest(String testName){
- federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName);
+ private void federatedTwoMatricesSingleNodeTest(String testName,
String[] expectedHeavyHitters){
+ federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName,
expectedHeavyHitters);
}
- private void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName) {
+ private void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName, String[] expectedHeavyHitters) {
OptimizerUtils.FEDERATED_COMPILATION = true;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -213,10 +219,9 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
// compare via files
compareResults(1e-9);
- if ( testName.equals(TEST_NAME_3) )
- assertTrue(heavyHittersContainsString("fed_+*",
"fed_1-*"));
- else
- assertTrue(heavyHittersContainsString("fed_*",
"fed_ba+*"));
+ if (!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are missing:
"
+ +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
TestUtils.shutdownThreads(t1, t2);
diff --git
a/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
new file mode 100644
index 0000000..1899614
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/BinaryCostEstimatorTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+a = matrix (1, rows=10, cols=10);
+b = matrix (2, rows=10, cols=10);
+c = a * b;
+print(toString(c));
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
new file mode 100644
index 0000000..dfeacec
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyCostEstimatorTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+ph = "placeholder"
+rows = 10000
+cols = 10
+X = federated(addresses=list(ph, ph),
+ ranges=list(list(0, 0), list(rows / 2, cols), list(rows / 2, 0),
list(rows, cols)))
+Y = federated(addresses=list(ph, ph),
+ ranges=list(list(0, 0), list(rows/2, cols), list(rows / 2, 0),
list(rows, cols)))
+Z0 = X * Y
+Z = t(Z0) %*% X
+write(Z, ph)
diff --git
a/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..b80e745
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/ForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+for ( i in 1:5 ){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
\ No newline at end of file
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
new file mode 100644
index 0000000..1f0d876
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FunctionCostEstimatorTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+multiplication = function (){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
+multiplication();
\ No newline at end of file
diff --git
a/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
new file mode 100644
index 0000000..b0194e3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/IfElseCostEstimatorTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ( 1 ){
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b;
+ print(toString(c));
+}
+else {
+ print("No result");
+}
\ No newline at end of file
diff --git
a/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
new file mode 100644
index 0000000..32a49ec
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/ParForLoopCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+parfor ( i in 1:5 ){ # fixture: parfor with 5 iterations; loop variable i is not used in the body
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b; # cell-wise (element-wise) multiply, not the matrix product %*%
+ print(toString(c));
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
new file mode 100644
index 0000000..faea01d
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/WhileCostEstimatorTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+while ( 1 ){ # NOTE(review): infinite loop (constant-true predicate, no stop()); presumably only compiled for cost estimation, never executed - confirm
+ a = matrix (1, rows=10, cols=10);
+ b = matrix (2, rows=10, cols=10);
+ c = a * b; # cell-wise (element-wise) multiply, not the matrix product %*%
+ print(toString(c));
+}
\ No newline at end of file