Baunsgaard commented on a change in pull request #1395:
URL: https://github.com/apache/systemds/pull/1395#discussion_r722214525
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
Review comment:
Please consistently add a space after the `//` that starts a comment.
If you use our code formatter, it does this for you.
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
+ double inputCosts = root.inputDependency.stream()
+ .mapToDouble( in -> {
+ double inCost =
in.existingCostPointer(root.hopRef.getHopID()) ?
+ 0 : costEstimate(in,
hopRelMemo).getTotal();
+
in.addCostPointer(root.hopRef.getHopID());
+ return inCost;
+ } )
+ .sum();
+ double inputTransferCost =
inputTransferCostEstimate(hasFederatedInput, root);
+ double computingCost =
ComputeCost.getHOPComputeCost(root.hopRef);
+ if ( hasFederatedInput ){
+ //Find the number of inputs that has FOUT set.
+ int numWorkers =
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+ //divide memory usage by the number of workers
the computation would be split to multiplied by
+ //the number of parallel processes at each
worker multiplied by the FLOPS of each process
+ //This assumes uniform workload among the
workers with FOUT data involved in the operation
+ //and assumes that the degree of parallelism
and compute bandwidth are equal for all workers
+ computingCost = computingCost /
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ } else computingCost = computingCost /
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ //Calculate output transfer cost if the operation is
computed at federated workers and the output is forced to the coordinator
+ //If the root is a federated DataOp, the data is forced
to the coordinator even if no input is LOUT
+ double outputTransferCost = ( root.hasLocalOutput() &&
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+ double readCost =
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+ return new FederatedCost(readCost, inputTransferCost,
outputTransferCost, computingCost, inputCosts);
+ }
+ }
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of
LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has
to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hopRel for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput,
HopRel root){
+ if ( hasFederatedInput )
+ return root.inputDependency.stream()
+ .filter(input ->
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() :
input.hasLocalOutput() )
+ .mapToDouble(in ->
in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .map(inMem -> inMem/
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+ .sum();
+ else return 0;
+ }
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of
LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has
to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hop for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput, Hop
root){
+ if ( hasFederatedInput )
+ return root.getInput().stream()
+ .filter(input -> (root.isFederatedDataOp()) ?
input.hasFederatedOutput() : input.hasLocalOutput() )
+ .mapToDouble(in ->
in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .map(inMem -> inMem/
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
Review comment:
If you divide each element by a value, it is the same as dividing the
sum by that value. Therefore you don't need the map here; divide the sum once.
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
+ double inputCosts = root.inputDependency.stream()
+ .mapToDouble( in -> {
+ double inCost =
in.existingCostPointer(root.hopRef.getHopID()) ?
+ 0 : costEstimate(in,
hopRelMemo).getTotal();
+
in.addCostPointer(root.hopRef.getHopID());
+ return inCost;
+ } )
+ .sum();
+ double inputTransferCost =
inputTransferCostEstimate(hasFederatedInput, root);
+ double computingCost =
ComputeCost.getHOPComputeCost(root.hopRef);
+ if ( hasFederatedInput ){
+ //Find the number of inputs that has FOUT set.
+ int numWorkers =
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+ //divide memory usage by the number of workers
the computation would be split to multiplied by
+ //the number of parallel processes at each
worker multiplied by the FLOPS of each process
+ //This assumes uniform workload among the
workers with FOUT data involved in the operation
+ //and assumes that the degree of parallelism
and compute bandwidth are equal for all workers
+ computingCost = computingCost /
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ } else computingCost = computingCost /
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ //Calculate output transfer cost if the operation is
computed at federated workers and the output is forced to the coordinator
+ //If the root is a federated DataOp, the data is forced
to the coordinator even if no input is LOUT
+ double outputTransferCost = ( root.hasLocalOutput() &&
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+ double readCost =
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+ return new FederatedCost(readCost, inputTransferCost,
outputTransferCost, computingCost, inputCosts);
+ }
+ }
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of
LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has
to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hopRel for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput,
HopRel root){
+ if ( hasFederatedInput )
+ return root.inputDependency.stream()
Review comment:
Very fond of the stream API here :)
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
+ double inputCosts = root.inputDependency.stream()
+ .mapToDouble( in -> {
+ double inCost =
in.existingCostPointer(root.hopRef.getHopID()) ?
+ 0 : costEstimate(in,
hopRelMemo).getTotal();
+
in.addCostPointer(root.hopRef.getHopID());
+ return inCost;
+ } )
+ .sum();
Review comment:
This seems like overkill with a stream sum here; I would suggest making a
method where you use a for loop instead.
##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+ protected Hop hopRef;
+ protected FEDInstruction.FederatedOutput fedOut;
+ protected FederatedCost cost;
+ protected Set<Long> costPointerSet = new HashSet<>();
+ protected List<HopRel> inputDependency = new ArrayList<>();
+
+ /**
+ * Constructs a HopRel with input dependency and cost estimate based on
entries in hopRelMemo.
Review comment:
If you use our formatter, it will add a blank line after the description and
before the @param tags.
##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+ protected Hop hopRef;
+ protected FEDInstruction.FederatedOutput fedOut;
+ protected FederatedCost cost;
+ protected Set<Long> costPointerSet = new HashSet<>();
+ protected List<HopRel> inputDependency = new ArrayList<>();
+
+ /**
+ * Constructs a HopRel with input dependency and cost estimate based on
entries in hopRelMemo.
+ * @param associatedHop hop associated with this HopRel
+ * @param fedOut FederatedOutput value assigned to this HopRel
+ * @param hopRelMemo memo table storing other HopRels including the
inputs of associatedHop
+ */
+ public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut,
Map<Long, List<HopRel>> hopRelMemo){
+ hopRef = associatedHop;
+ this.fedOut = fedOut;
+ setInputDependency(hopRelMemo);
+ cost = new FederatedCostEstimator().costEstimate(this,
hopRelMemo);
+ }
+
+ /**
+ * Adds hopID to set of hops pointing to this HopRel.
+ * By storing the hopID it can later be determined if the cost
+ * stored in this HopRel is already used as input cost in another
HopRel.
+ * @param hopID added to set of stored cost pointers
+ */
+ public void addCostPointer(long hopID){
+ costPointerSet.add(hopID);
+ }
+
+ /**
+ * Checks if another Hop is refering to this HopRel in memo table.
+ * A reference from a HopRel with same Hop ID is allowed, so this
+ * ID is ignored when checking references.
+ * @param currentHopID to ignore when checking references
+ * @return true if another Hop refers to this HopRel in memo table
+ */
+ public boolean existingCostPointer(long currentHopID){
+ if ( costPointerSet.contains(currentHopID) )
+ return costPointerSet.size() > 1;
+ else return costPointerSet.size() > 0;
+ }
+
+ public boolean hasLocalOutput(){
+ return fedOut == FederatedOutput.LOUT;
+ }
+
+ public boolean hasFederatedOutput(){
+ return fedOut == FederatedOutput.FOUT;
+ }
+
+ public FederatedOutput getFederatedOutput(){
+ return fedOut;
+ }
+
+ public List<HopRel> getInputDependency(){
+ return inputDependency;
+ }
+
+ public Hop getHopRef(){
+ return hopRef;
+ }
+
+ /**
+ * Returns FOUT HopRel for given hop found in hopRelMemo or returns
null if HopRel not found.
+ * @param hop to look for in hopRelMemo
+ * @param hopRelMemo memo table storing HopRels
+ * @return FOUT HopRel found in hopRelMemo
+ */
+ private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>>
hopRelMemo){
+ return
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
+ }
+
+ /**
+ * Get the HopRel with minimum cost for given hop
+ * @param hopRelMemo memo table storing HopRels
+ * @param input hop for which minimum cost HopRel is found
+ * @return HopRel with minimum cost for given hop
+ */
+ private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop
input){
+ return hopRelMemo.get(input.getHopID()).stream()
+ .min(Comparator.comparingDouble(a -> a.cost.getTotal()))
+ .orElseThrow(() -> new DMLException("No element in Memo
Table found for input"));
+ }
+
+ /**
+ * Set valid and optimal input dependency for this HopRel as a field.
+ * @param hopRelMemo memo table storing input HopRels
+ */
+ private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+ if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
+ if ( fedOut == FederatedOutput.FOUT &&
!hopRef.isFederatedDataOp() ) {
+ int lowestFOUTIndex = 0;
+ HopRel lowestFOUTHopRel =
getFOUTHopRel(hopRef.getInput().get(0), hopRelMemo);
+ for(int i = 1; i < hopRef.getInput().size();
i++) {
+ Hop input = hopRef.getInput(i);
+ HopRel foutHopRel =
getFOUTHopRel(input, hopRelMemo);
+ if(lowestFOUTHopRel == null) {
+ lowestFOUTHopRel = foutHopRel;
+ lowestFOUTIndex = i;
+ }
+ else if(foutHopRel != null) {
+ if(foutHopRel.getCost() <
lowestFOUTHopRel.getCost()) {
+ lowestFOUTHopRel =
foutHopRel;
+ lowestFOUTIndex = i;
+ }
+ }
+ }
+
+ HopRel[] inputHopRels = new
HopRel[hopRef.getInput().size()];
+ for(int i = 0; i < hopRef.getInput().size();
i++) {
+ if(i != lowestFOUTIndex) {
+ Hop input = hopRef.getInput(i);
+ inputHopRels[i] =
getMinOfInput(hopRelMemo, input);
+ }
+ else {
+ inputHopRels[i] =
lowestFOUTHopRel;
+ }
+ }
+
inputDependency.addAll(Arrays.asList(inputHopRels));
+ } else {
+ inputDependency.addAll(
+ hopRef.getInput().stream()
+ .map(input ->
getMinOfInput(hopRelMemo, input))
+ .collect(Collectors.toList()));
+ }
+ }
+ validateInputDependency();
+ }
+
+ /**
+ * Throws exception if any input dependency is null.
+ * If any of the input dependencies are null, it is not possible to
build a federated execution plan.
+ * If this null-state is not found here, an exception will be thrown at
another difficult-to-debug place.
+ */
+ private void validateInputDependency(){
+ for ( int i = 0; i < inputDependency.size(); i++){
+ if ( inputDependency.get(i) == null)
+ throw new DMLException("HopRel input number " +
i + " (" + hopRef.getInput(i) + ")"
+ + " is null for root: \n" + this);
+ }
Review comment:
you can remove the last outer "}"
##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+ protected Hop hopRef;
+ protected FEDInstruction.FederatedOutput fedOut;
+ protected FederatedCost cost;
+ protected Set<Long> costPointerSet = new HashSet<>();
+ protected List<HopRel> inputDependency = new ArrayList<>();
Review comment:
Can some of these fields be made final?
##########
File path:
src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
##########
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting
costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in
OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+ private final static Map<Long, List<HopRel>> hopRelMemo = new
HashMap<>();
+
+ /**
+ * Indicates if an IPA pass is applicable for the current configuration.
+ * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+ *
+ * @param fgraph function call graph
+ * @return true if federated compilation is activated.
+ */
+ @Override
+ public boolean isApplicable(FunctionCallGraph fgraph) {
+ return OptimizerUtils.FEDERATED_COMPILATION;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the program.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @param fcallSizes function call size infos
+ * @return false since the function call graph never has to be rebuilt
+ */
+ @Override
+ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph
fgraph, FunctionCallSizeInfo fcallSizes) {
+ rewriteStatementBlocks(prog.getStatementBlocks());
+ return false;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement
blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sbs list of statement blocks
+ * @return list of statement blocks with the federated output value
updated for each hop
+ */
+ public ArrayList<StatementBlock>
rewriteStatementBlocks(List<StatementBlock> sbs) {
+ ArrayList<StatementBlock> rewrittenStmBlocks = new
ArrayList<>();
+ for ( StatementBlock stmBlock : sbs )
+
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+ return rewrittenStmBlocks;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement
blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sb statement block
+ * @return list of statement blocks with the federated output value
updated for each hop
+ */
+ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock
sb) {
+ if ( sb instanceof WhileStatementBlock)
+ return rewriteWhileStatementBlock((WhileStatementBlock)
sb);
+ else if ( sb instanceof IfStatementBlock)
+ return rewriteIfStatementBlock((IfStatementBlock) sb);
+ else if ( sb instanceof ForStatementBlock){
+ // This also includes ParForStatementBlocks
+ return rewriteForStatementBlock((ForStatementBlock) sb);
+ }
+ else if ( sb instanceof FunctionStatementBlock)
+ return
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+ else {
+ // StatementBlock type (no subclass)
+ selectFederatedExecutionPlan(sb.getHops());
+ }
+ return new ArrayList<>(Collections.singletonList(sb));
+ }
+
+ private ArrayList<StatementBlock>
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+ Hop whilePredicateHop = whileSB.getPredicateHops();
+ selectFederatedExecutionPlan(whilePredicateHop);
+ for ( Statement stm : whileSB.getStatements() ){
+ WhileStatement whileStm = (WhileStatement) stm;
+
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(whileSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteIfStatementBlock(IfStatementBlock ifSB){
+ selectFederatedExecutionPlan(ifSB.getPredicateHops());
+ for ( Statement statement : ifSB.getStatements() ){
+ IfStatement ifStatement = (IfStatement) statement;
+
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(ifSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteForStatementBlock(ForStatementBlock forSB){
+ selectFederatedExecutionPlan(forSB.getFromHops());
+ selectFederatedExecutionPlan(forSB.getToHops());
+ selectFederatedExecutionPlan(forSB.getIncrementHops());
+ for ( Statement statement : forSB.getStatements() ){
+ ForStatement forStatement = ((ForStatement)statement);
+
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(forSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+ for ( Statement statement : funcSB.getStatements() ){
+ FunctionStatement funcStm = (FunctionStatement)
statement;
+
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(funcSB));
+ }
+
+ /**
+ * Sets FederatedOutput field of all hops in DAG starting from given
root.
+ * The FederatedOutput chosen for root is the minimum cost HopRel found
in memo table for the given root.
+ * The FederatedOutput values chosen for the inputs to the root are
chosen based on the input dependencies.
+ * @param root hop for which FederatedOutput needs to be set
+ */
+ private void setFinalFedout(Hop root){
+ HopRel optimalRootHopRel =
hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+ .orElseThrow(() -> new DMLException("Hop root " + root
+ " has no feasible federated output alternatives"));
+ setFinalFedout(root, optimalRootHopRel);
+ }
+
+ /**
+ * Update the FederatedOutput value and cost based on information
stored in given rootHopRel.
+ * @param root hop for which FederatedOutput is set
+ * @param rootHopRel from which FederatedOutput value and cost is
retrieved
+ */
+ private void setFinalFedout(Hop root, HopRel rootHopRel){
+ updateFederatedOutput(root, rootHopRel);
+ visitInputDependency(rootHopRel);
+ }
+
+ /**
+ * Sets FederatedOutput value for each of the inputs of rootHopRel
+ * @param rootHopRel which has its input values updated
+ */
+ private void visitInputDependency(HopRel rootHopRel){
+ List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+ for ( HopRel input : hopRelInputs )
+ setFinalFedout(input.getHopRef(), input);
+ }
+
+ /**
+ * Updates FederatedOutput value and cost estimate based on
updateHopRel values.
+ * @param root which has its values updated
+ * @param updateHopRel from which the values are retrieved
+ */
+ private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+ root.setFederatedOutput(updateHopRel.getFederatedOutput());
+ root.setFederatedCost(updateHopRel.getCostObject());
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting
from given roots.
+ * The cost estimates of the hops are also updated when FederatedOutput
is updated in the hops.
+ * @param roots starting point for going through the Hop DAG to update
the FederatedOutput fields.
+ */
+ private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+ for ( Hop root : roots ){
+ selectFederatedExecutionPlan(root);
+ }
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting
from given root.
+ * @param root starting point for going through the Hop DAG to update
the federatedOutput fields
+ */
+ private void selectFederatedExecutionPlan(Hop root){
+ visitFedPlanHop(root);
+ setFinalFedout(root);
+ }
+
+ /**
+ * Go through the Hop DAG and set the FederatedOutput field and cost
estimate for each Hop from leaf to given currentHop.
+ * @param currentHop the Hop from which the DAG is visited
+ */
+ private void visitFedPlanHop(Hop currentHop){
+ // If the currentHop is in the hopRelMemo table, it means that
it has been visited
+ if ( hopRelMemo.containsKey(currentHop.getHopID()) )
+ return;
+ // If the currentHop has input, then the input should be
visited depth-first
+ if ( currentHop.getInput() != null &&
currentHop.getInput().size() > 0 ){
+ for ( Hop input : currentHop.getInput() )
+ visitFedPlanHop(input);
+ }
+ // Put FOUT, LOUT, and None HopRels into the memo table
+ ArrayList<HopRel> hopRels = new ArrayList<>();
+ if ( isFedInstSupportedHop(currentHop) ){
+ for ( FEDInstruction.FederatedOutput fedoutValue :
FEDInstruction.FederatedOutput.values() )
+ if ( isFedOutSupported(currentHop, fedoutValue)
)
+ hopRels.add(new
HopRel(currentHop,fedoutValue, hopRelMemo));
+ }
+ if ( hopRels.isEmpty() )
+ hopRels.add(new HopRel(currentHop,
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
+ hopRelMemo.put(currentHop.getHopID(), hopRels);
+ currentHop.setVisited();
+ }
+
+ /**
+ * Checks if the instructions related to the given hop supports
FOUT/LOUT processing.
+ * @param hop to check for federated support
+ * @return true if federated instructions related to hop supports
FOUT/LOUT processing
+ */
+ private boolean isFedInstSupportedHop(Hop hop){
+ // The following operations are supported given that the above
conditions have not returned already
+ return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp
|| hop instanceof ReorgOp
+ || hop instanceof AggUnaryOp || hop instanceof
TernaryOp || hop instanceof DataOp);
+ }
+
+ /**
+ * Checks if the associatedHop supports the given federated output
value.
+ * @param associatedHop to check support of
+ * @param fedOut federated output value
+ * @return true if associatedHop supports fedOut
+ */
+ private boolean isFedOutSupported(Hop associatedHop,
FEDInstruction.FederatedOutput fedOut){
+ switch(fedOut){
+ case FOUT:
+ return isFOUTSupported(associatedHop);
+ case LOUT:
+ return isLOUTSupported(associatedHop);
+ case NONE:
+ return false;
+ default:
+ return true;
+ }
+ }
+
+ /**
+ * Checks to see if the associatedHop supports FOUT.
+ * @param associatedHop for which FOUT support is checked
+ * @return true if FOUT is supported by the associatedHop
+ */
+ private boolean isFOUTSupported(Hop associatedHop){
+ // If the output of AggUnaryOp is a scalar, the operation
cannot be FOUT
+ if ( associatedHop instanceof AggUnaryOp &&
associatedHop.isScalar() )
+ return false;
+ // It can only be FOUT if at least one of the inputs are FOUT,
except if it is a federated DataOp
+ if ( associatedHop.getInput().stream().noneMatch(
+ input ->
hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput) )
+ && !associatedHop.isFederatedDataOp() )
+ return false;
+ return true;
+ }
+
+ /**
+ * Checks to see if the associatedHop supports LOUT.
+ * It supports LOUT if the output has no privacy constraints.
+ * @param associatedHop for which LOUT support is checked.
+ * @return true if LOUT is supported by the associatedHop
+ */
+ private boolean isLOUTSupported(Hop associatedHop){
+ return associatedHop.getPrivacy() == null
Review comment:
I would remove the first part of the parenthesis, because it is
already implied by the first operand of the or.
##########
File path:
src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
##########
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting
costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in
OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+ private final static Map<Long, List<HopRel>> hopRelMemo = new
HashMap<>();
+
+ /**
+ * Indicates if an IPA pass is applicable for the current configuration.
+ * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+ *
+ * @param fgraph function call graph
+ * @return true if federated compilation is activated.
+ */
+ @Override
+ public boolean isApplicable(FunctionCallGraph fgraph) {
+ return OptimizerUtils.FEDERATED_COMPILATION;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the program.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @param fcallSizes function call size infos
+ * @return false since the function call graph never has to be rebuilt
+ */
+ @Override
+ public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph
fgraph, FunctionCallSizeInfo fcallSizes) {
+ rewriteStatementBlocks(prog.getStatementBlocks());
+ return false;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement
blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sbs list of statement blocks
+ * @return list of statement blocks with the federated output value
updated for each hop
+ */
+ public ArrayList<StatementBlock>
rewriteStatementBlocks(List<StatementBlock> sbs) {
+ ArrayList<StatementBlock> rewrittenStmBlocks = new
ArrayList<>();
+ for ( StatementBlock stmBlock : sbs )
+
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+ return rewrittenStmBlocks;
+ }
+
+ /**
+ * Estimates cost and selects a federated execution plan
+ * by setting the federated output value of each hop in the statement
blocks.
+ * The method calls the contained statement blocks recursively.
+ *
+ * @param sb statement block
+ * @return list of statement blocks with the federated output value
updated for each hop
+ */
+ public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock
sb) {
+ if ( sb instanceof WhileStatementBlock)
+ return rewriteWhileStatementBlock((WhileStatementBlock)
sb);
+ else if ( sb instanceof IfStatementBlock)
+ return rewriteIfStatementBlock((IfStatementBlock) sb);
+ else if ( sb instanceof ForStatementBlock){
+ // This also includes ParForStatementBlocks
+ return rewriteForStatementBlock((ForStatementBlock) sb);
+ }
+ else if ( sb instanceof FunctionStatementBlock)
+ return
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+ else {
+ // StatementBlock type (no subclass)
+ selectFederatedExecutionPlan(sb.getHops());
+ }
+ return new ArrayList<>(Collections.singletonList(sb));
+ }
+
+ private ArrayList<StatementBlock>
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+ Hop whilePredicateHop = whileSB.getPredicateHops();
+ selectFederatedExecutionPlan(whilePredicateHop);
+ for ( Statement stm : whileSB.getStatements() ){
+ WhileStatement whileStm = (WhileStatement) stm;
+
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(whileSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteIfStatementBlock(IfStatementBlock ifSB){
+ selectFederatedExecutionPlan(ifSB.getPredicateHops());
+ for ( Statement statement : ifSB.getStatements() ){
+ IfStatement ifStatement = (IfStatement) statement;
+
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(ifSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteForStatementBlock(ForStatementBlock forSB){
+ selectFederatedExecutionPlan(forSB.getFromHops());
+ selectFederatedExecutionPlan(forSB.getToHops());
+ selectFederatedExecutionPlan(forSB.getIncrementHops());
+ for ( Statement statement : forSB.getStatements() ){
+ ForStatement forStatement = ((ForStatement)statement);
+
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(forSB));
+ }
+
+ private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+ for ( Statement statement : funcSB.getStatements() ){
+ FunctionStatement funcStm = (FunctionStatement)
statement;
+
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+ }
+ return new ArrayList<>(Collections.singletonList(funcSB));
+ }
+
+ /**
+ * Sets FederatedOutput field of all hops in DAG starting from given
root.
+ * The FederatedOutput chosen for root is the minimum cost HopRel found
in memo table for the given root.
+ * The FederatedOutput values chosen for the inputs to the root are
chosen based on the input dependencies.
+ * @param root hop for which FederatedOutput needs to be set
+ */
+ private void setFinalFedout(Hop root){
+ HopRel optimalRootHopRel =
hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+ .orElseThrow(() -> new DMLException("Hop root " + root
+ " has no feasible federated output alternatives"));
+ setFinalFedout(root, optimalRootHopRel);
+ }
+
+ /**
+ * Update the FederatedOutput value and cost based on information
stored in given rootHopRel.
+ * @param root hop for which FederatedOutput is set
+ * @param rootHopRel from which FederatedOutput value and cost is
retrieved
+ */
+ private void setFinalFedout(Hop root, HopRel rootHopRel){
+ updateFederatedOutput(root, rootHopRel);
+ visitInputDependency(rootHopRel);
+ }
+
+ /**
+ * Sets FederatedOutput value for each of the inputs of rootHopRel
+ * @param rootHopRel which has its input values updated
+ */
+ private void visitInputDependency(HopRel rootHopRel){
+ List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+ for ( HopRel input : hopRelInputs )
+ setFinalFedout(input.getHopRef(), input);
+ }
+
+ /**
+ * Updates FederatedOutput value and cost estimate based on
updateHopRel values.
+ * @param root which has its values updated
+ * @param updateHopRel from which the values are retrieved
+ */
+ private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+ root.setFederatedOutput(updateHopRel.getFederatedOutput());
+ root.setFederatedCost(updateHopRel.getCostObject());
+ }
+
+ /**
+ * Select federated execution plan for every Hop in the DAG starting
from given roots.
+ * The cost estimates of the hops are also updated when FederatedOutput
is updated in the hops.
+ * @param roots starting point for going through the Hop DAG to update
the FederatedOutput fields.
+ */
+ private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+ for ( Hop root : roots ){
Review comment:
Fix the formatting and remove the braces { } around the single-statement loop body.
##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+ protected Hop hopRef;
+ protected FEDInstruction.FederatedOutput fedOut;
+ protected FederatedCost cost;
+ protected Set<Long> costPointerSet = new HashSet<>();
+ protected List<HopRel> inputDependency = new ArrayList<>();
+
+ /**
+ * Constructs a HopRel with input dependency and cost estimate based on
entries in hopRelMemo.
+ * @param associatedHop hop associated with this HopRel
+ * @param fedOut FederatedOutput value assigned to this HopRel
+ * @param hopRelMemo memo table storing other HopRels including the
inputs of associatedHop
+ */
+ public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut,
Map<Long, List<HopRel>> hopRelMemo){
+ hopRef = associatedHop;
+ this.fedOut = fedOut;
+ setInputDependency(hopRelMemo);
+ cost = new FederatedCostEstimator().costEstimate(this,
hopRelMemo);
+ }
+
+ /**
+ * Adds hopID to set of hops pointing to this HopRel.
+ * By storing the hopID it can later be determined if the cost
+ * stored in this HopRel is already used as input cost in another
HopRel.
+ * @param hopID added to set of stored cost pointers
+ */
+ public void addCostPointer(long hopID){
+ costPointerSet.add(hopID);
+ }
+
+ /**
+ * Checks if another Hop is refering to this HopRel in memo table.
+ * A reference from a HopRel with same Hop ID is allowed, so this
+ * ID is ignored when checking references.
+ * @param currentHopID to ignore when checking references
+ * @return true if another Hop refers to this HopRel in memo table
+ */
+ public boolean existingCostPointer(long currentHopID){
+ if ( costPointerSet.contains(currentHopID) )
+ return costPointerSet.size() > 1;
+ else return costPointerSet.size() > 0;
+ }
+
+ public boolean hasLocalOutput(){
+ return fedOut == FederatedOutput.LOUT;
+ }
+
+ public boolean hasFederatedOutput(){
+ return fedOut == FederatedOutput.FOUT;
+ }
+
+ public FederatedOutput getFederatedOutput(){
+ return fedOut;
+ }
+
+ public List<HopRel> getInputDependency(){
+ return inputDependency;
+ }
+
+ public Hop getHopRef(){
+ return hopRef;
+ }
+
+ /**
+ * Returns FOUT HopRel for given hop found in hopRelMemo or returns
null if HopRel not found.
+ * @param hop to look for in hopRelMemo
+ * @param hopRelMemo memo table storing HopRels
+ * @return FOUT HopRel found in hopRelMemo
+ */
+ private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>>
hopRelMemo){
+ return
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
Review comment:
fancy one liner.
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
+ double inputCosts = root.inputDependency.stream()
+ .mapToDouble( in -> {
+ double inCost =
in.existingCostPointer(root.hopRef.getHopID()) ?
+ 0 : costEstimate(in,
hopRelMemo).getTotal();
+
in.addCostPointer(root.hopRef.getHopID());
+ return inCost;
+ } )
+ .sum();
+ double inputTransferCost =
inputTransferCostEstimate(hasFederatedInput, root);
+ double computingCost =
ComputeCost.getHOPComputeCost(root.hopRef);
+ if ( hasFederatedInput ){
+ //Find the number of inputs that has FOUT set.
+ int numWorkers =
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+ //divide memory usage by the number of workers
the computation would be split to multiplied by
+ //the number of parallel processes at each
worker multiplied by the FLOPS of each process
+ //This assumes uniform workload among the
workers with FOUT data involved in the operation
+ //and assumes that the degree of parallelism
and compute bandwidth are equal for all workers
+ computingCost = computingCost /
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ } else computingCost = computingCost /
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+ //Calculate output transfer cost if the operation is
computed at federated workers and the output is forced to the coordinator
+ //If the root is a federated DataOp, the data is forced
to the coordinator even if no input is LOUT
+ double outputTransferCost = ( root.hasLocalOutput() &&
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+ double readCost =
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) /
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+ return new FederatedCost(readCost, inputTransferCost,
outputTransferCost, computingCost, inputCosts);
+ }
+ }
+
+ /**
+ * Returns input transfer cost estimate.
+ * The input transfer cost estimate is based on the memory estimate of
LOUT when some input is FOUT
+ * except if root is a federated DataOp, since all input for this has
to be at the coordinator.
+ * When no input is FOUT, the input transfer cost is always 0.
+ * @param hasFederatedInput true if root has any FOUT input
+ * @param root hopRel for which cost is estimated
+ * @return input transfer cost estimate
+ */
+ private double inputTransferCostEstimate(boolean hasFederatedInput,
HopRel root){
+ if ( hasFederatedInput )
+ return root.inputDependency.stream()
+ .filter(input ->
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() :
input.hasLocalOutput() )
+ .mapToDouble(in ->
in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+ .map(inMem -> inMem/
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
Review comment:
Same argument here: don't divide each element; divide the sum instead.
##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
}
}
+ /**
+ * Return cost estimate in bytes of Hop DAG starting from given root
HopRel.
+ * @param root HopRel of Hop DAG for which cost is estimated
+ * @param hopRelMemo memo table of HopRels for calculating input costs
+ * @return cost estimation of Hop DAG starting from given root HopRel
+ */
+ public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>>
hopRelMemo){
+ // Check if root is in memo table.
+ if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+ &&
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut ==
root.fedOut) ){
+ return root.getCostObject();
+ }
+ else {
+ // If no input has FOUT, the root will be processed by
the coordinator
+ boolean hasFederatedInput =
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+ //the input cost is included the first time the input
hop is used
+ //for additional usage, the additional cost is zero
(disregarding potential read cost)
+ double inputCosts = root.inputDependency.stream()
+ .mapToDouble( in -> {
+ double inCost =
in.existingCostPointer(root.hopRef.getHopID()) ?
+ 0 : costEstimate(in,
hopRelMemo).getTotal();
+
in.addCostPointer(root.hopRef.getHopID());
+ return inCost;
+ } )
+ .sum();
Review comment:
If you don't do that, I would suggest that you define a method to
use in the mapToDouble call instead of the lambda function defined here.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]