johnyangk closed pull request #88: [NEMO-138] Rule-based policy. URL: https://github.com/apache/incubator-nemo/pull/88
This is a PR merged from a forked repository. Because GitHub hides the original diff of a merged pull request from a fork, the diff is reproduced below for the sake of provenance: diff --git a/common/src/main/java/edu/snu/nemo/common/pass/Pass.java b/common/src/main/java/edu/snu/nemo/common/pass/Pass.java new file mode 100644 index 000000000..6500b5b6b --- /dev/null +++ b/common/src/main/java/edu/snu/nemo/common/pass/Pass.java @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package edu.snu.nemo.common.pass; + +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; + +import java.io.Serializable; +import java.util.function.Predicate; + +/** + * Abstract class for optimization passes. All passes basically extends this class. + */ +public abstract class Pass implements Serializable { + private Predicate<DAG<IRVertex, IREdge>> condition; + + /** + * Default constructor. + */ + public Pass() { + this((dag) -> true); + } + + /** + * Constructor. + * @param condition condition under which to run the pass. + */ + private Pass(final Predicate<DAG<IRVertex, IREdge>> condition) { + this.condition = condition; + } + + /** + * Getter for the condition under which to apply the pass. + * @return the condition under which to apply the pass. 
+ */ + public final Predicate<DAG<IRVertex, IREdge>> getCondition() { + return this.condition; + } + + /** + * Add the condition to the existing condition to run the pass. + * @param newCondition the new condition to add to the existing condition. + * @return the condition with the new condition added. + */ + public final Pass addCondition(final Predicate<DAG<IRVertex, IREdge>> newCondition) { + this.condition = this.condition.and(newCondition); + return this; + } +} diff --git a/compiler/backend/src/test/java/edu/snu/nemo/compiler/backend/nemo/NemoBackendTest.java b/compiler/backend/src/test/java/edu/snu/nemo/compiler/backend/nemo/NemoBackendTest.java index f8435e4c6..26c00a62f 100644 --- a/compiler/backend/src/test/java/edu/snu/nemo/compiler/backend/nemo/NemoBackendTest.java +++ b/compiler/backend/src/test/java/edu/snu/nemo/compiler/backend/nemo/NemoBackendTest.java @@ -21,7 +21,6 @@ import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.common.ir.vertex.OperatorVertex; import edu.snu.nemo.common.dag.DAGBuilder; -import edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import edu.snu.nemo.common.test.EmptyComponents; import edu.snu.nemo.compiler.optimizer.policy.TransientResourcePolicy; import edu.snu.nemo.conf.JobConf; @@ -59,7 +58,7 @@ public void setUp() throws Exception { .connectVertices(new IREdge(CommunicationPatternProperty.Value.OneToOne, combine, map2)) .build(); - this.dag = CompiletimeOptimizer.optimize(dag, new TransientResourcePolicy(), EMPTY_DAG_DIRECTORY); + this.dag = new TransientResourcePolicy().runCompileTimeOptimization(dag, EMPTY_DAG_DIRECTORY); final Injector injector = Tang.Factory.getTang().newInjector(); injector.bindVolatileParameter(JobConf.DAGDirectory.class, ""); diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/examples/MapReduceDisaggregationOptimization.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/examples/MapReduceDisaggregationOptimization.java index 
81de704b8..0bb5295ce 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/examples/MapReduceDisaggregationOptimization.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/examples/MapReduceDisaggregationOptimization.java @@ -23,7 +23,6 @@ import edu.snu.nemo.common.ir.vertex.OperatorVertex; import edu.snu.nemo.common.test.EmptyComponents; import edu.snu.nemo.compiler.optimizer.policy.DisaggregationPolicy; -import edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,7 +68,7 @@ public static void main(final String[] args) throws Exception { LOG.info(dag.toString()); // Optimize - final DAG optimizedDAG = CompiletimeOptimizer.optimize(dag, new DisaggregationPolicy(), EMPTY_DAG_DIRECTORY); + final DAG optimizedDAG = new DisaggregationPolicy().runCompileTimeOptimization(dag, EMPTY_DAG_DIRECTORY); // After LOG.info("After Optimization"); diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/CompileTimePass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/CompileTimePass.java index 72cbd0105..dfe1fd5bd 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/CompileTimePass.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/CompileTimePass.java @@ -19,19 +19,19 @@ import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.common.dag.DAG; import edu.snu.nemo.common.ir.executionproperty.ExecutionProperty; +import edu.snu.nemo.common.pass.Pass; -import java.io.Serializable; import java.util.Set; import java.util.function.Function; /** - * Interface for compile-time optimization passes that processes the DAG. + * Abstract class for compile-time optimization passes that processes the DAG. * It is a function that takes an original DAG to produce a processed DAG, after an optimization. 
*/ -public interface CompileTimePass extends Function<DAG<IRVertex, IREdge>, DAG<IRVertex, IREdge>>, Serializable { +public abstract class CompileTimePass extends Pass implements Function<DAG<IRVertex, IREdge>, DAG<IRVertex, IREdge>> { /** * Getter for prerequisite execution properties. * @return set of prerequisite execution properties. */ - Set<Class<? extends ExecutionProperty>> getPrerequisiteExecutionProperties(); + public abstract Set<Class<? extends ExecutionProperty>> getPrerequisiteExecutionProperties(); } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/AnnotatingPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/AnnotatingPass.java index b7d369076..3fcb7bdb5 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/AnnotatingPass.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/AnnotatingPass.java @@ -25,7 +25,7 @@ * A compile-time pass that annotates the IR DAG with execution properties. * It is ensured by the compiler that the shape of the IR DAG itself is not modified by an AnnotatingPass. */ -public abstract class AnnotatingPass implements CompileTimePass { +public abstract class AnnotatingPass extends CompileTimePass { private final Class<? extends ExecutionProperty> keyOfExecutionPropertyToModify; private final Set<Class<? 
extends ExecutionProperty>> prerequisiteExecutionProperties; diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/CompositePass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/CompositePass.java index e5628ccd5..378d125c1 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/CompositePass.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/CompositePass.java @@ -27,7 +27,7 @@ /** * A compile-time pass composed of multiple compile-time passes, which each modifies an IR DAG. */ -public abstract class CompositePass implements CompileTimePass { +public abstract class CompositePass extends CompileTimePass { private final List<CompileTimePass> passList; private final Set<Class<? extends ExecutionProperty>> prerequisiteExecutionProperties; @@ -74,7 +74,6 @@ public CompositePass(final List<CompileTimePass> passList) { } } - @Override public final Set<Class<? extends ExecutionProperty>> getPrerequisiteExecutionProperties() { return prerequisiteExecutionProperties; } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/ReshapingPass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/ReshapingPass.java index 3cb5a9ab1..d8159960c 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/ReshapingPass.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/ReshapingPass.java @@ -25,14 +25,14 @@ * A compile-time pass that reshapes the structure of the IR DAG. * It is ensured by the compiler that no execution properties are modified by a ReshapingPass. 
*/ -public abstract class ReshapingPass implements CompileTimePass { +public abstract class ReshapingPass extends CompileTimePass { private final Set<Class<? extends ExecutionProperty>> prerequisiteExecutionProperties; /** * Default constructor. */ public ReshapingPass() { - this.prerequisiteExecutionProperties = new HashSet<>(); + this(new HashSet<>()); } /** @@ -43,7 +43,6 @@ public ReshapingPass(final Set<Class<? extends ExecutionProperty>> prerequisiteE this.prerequisiteExecutionProperties = prerequisiteExecutionProperties; } - @Override public final Set<Class<? extends ExecutionProperty>> getPrerequisiteExecutionProperties() { return prerequisiteExecutionProperties; } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPullPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPullPolicy.java index 7fc4fba56..c2514690b 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPullPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPullPolicy.java @@ -15,26 +15,37 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultScheduleGroupPass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.ArrayList; -import java.util.List; +import org.apache.reef.tang.Injector; /** * Basic pull policy. */ public final class BasicPullPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerCompileTimePass(new DefaultScheduleGroupPass()); + private final Policy policy; + + /** + * Default constructor. 
+ */ + public BasicPullPolicy() { + this.policy = BUILDER.build(); + } + @Override - public List<CompileTimePass> getCompileTimePasses() { - List<CompileTimePass> policy = new ArrayList<>(); - policy.add(new DefaultScheduleGroupPass()); - return policy; + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return new ArrayList<>(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPushPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPushPolicy.java index 53ff914d8..31a5d7b4c 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPushPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/BasicPushPolicy.java @@ -15,28 +15,39 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultScheduleGroupPass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.ShuffleEdgePushPass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.ArrayList; -import java.util.List; +import org.apache.reef.tang.Injector; /** * Basic push policy. 
*/ public final class BasicPushPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerCompileTimePass(new ShuffleEdgePushPass()) + .registerCompileTimePass(new DefaultScheduleGroupPass()); + private final Policy policy; + + /** + * Default constructor. + */ + public BasicPushPolicy() { + this.policy = BUILDER.build(); + } + @Override - public List<CompileTimePass> getCompileTimePasses() { - List<CompileTimePass> policy = new ArrayList<>(); - policy.add(new ShuffleEdgePushPass()); - policy.add(new DefaultScheduleGroupPass()); - return policy; + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return new ArrayList<>(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/ConditionalLargeShufflePolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/ConditionalLargeShufflePolicy.java new file mode 100644 index 000000000..efdf14d84 --- /dev/null +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/ConditionalLargeShufflePolicy.java @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2018 Seoul National University + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package edu.snu.nemo.compiler.optimizer.policy; + +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty; +import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; +import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LargeShuffleCompositePass; +import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LoopOptimizationCompositePass; +import org.apache.reef.tang.Injector; + +/** + * A policy to demonstrate the large shuffle optimization, witch batches disk seek during data shuffle, conditionally. + */ +public final class ConditionalLargeShufflePolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(false) + .registerCompileTimePass(new LargeShuffleCompositePass(), dag -> getMaxParallelism(dag) > 300) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new DefaultCompositePass()); + private final Policy policy; + + /** + * Default constructor. + */ + public ConditionalLargeShufflePolicy() { + this.policy = BUILDER.build(); + } + + /** + * Returns the maximum parallelism of the vertices of a IR DAG. + * @param dag dag to observe. + * @return the maximum parallelism, or 1 by default. 
+ */ + private static int getMaxParallelism(final DAG<IRVertex, IREdge> dag) { + return dag.getVertices().stream() + .mapToInt(vertex -> vertex.getPropertyValue(ParallelismProperty.class).orElse(1)) + .max().orElse(1); + } + + @Override + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); + } + + @Override + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); + } +} diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java index 03be205c3..be23e6c96 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java @@ -15,30 +15,33 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.SkewCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LoopOptimizationCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; import edu.snu.nemo.runtime.common.optimizer.pass.runtime.DataSkewRuntimePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A policy to perform data skew dynamic optimization. 
*/ public final class DataSkewPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerRuntimePass(new DataSkewRuntimePass().setNumSkewedKeys(DataSkewRuntimePass.DEFAULT_NUM_SKEWED_KEYS), + new SkewCompositePass()) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new DefaultCompositePass()); private final Policy policy; /** * Default constructor. */ public DataSkewPolicy() { - this.policy = new PolicyBuilder(true) - .registerRuntimePass(new DataSkewRuntimePass(), new SkewCompositePass()) - .registerCompileTimePass(new LoopOptimizationCompositePass()) - .registerCompileTimePass(new DefaultCompositePass()) - .build(); + this.policy = BUILDER.build(); } public DataSkewPolicy(final int skewness) { @@ -50,12 +53,13 @@ public DataSkewPolicy(final int skewness) { } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicy.java index 851e5dd2d..986aec5e0 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicy.java @@ -15,34 +15,37 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import 
edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A basic default policy, that performs the minimum amount of optimization to be done to a specific DAG. */ public final class DefaultPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerCompileTimePass(new DefaultCompositePass()); private final Policy policy; /** * Default constructor. */ public DefaultPolicy() { - this.policy = new PolicyBuilder(true) - .registerCompileTimePass(new DefaultCompositePass()) - .build(); + this.policy = BUILDER.build(); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicyWithSeparatePass.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicyWithSeparatePass.java index b139ee2db..c832d1fcb 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicyWithSeparatePass.java +++ 
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DefaultPolicyWithSeparatePass.java @@ -15,15 +15,17 @@ */ package edu.snu.nemo.compiler.optimizer.policy; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultDataStorePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultScheduleGroupPass; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.CompositePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; +import org.apache.reef.tang.Injector; import java.util.Arrays; -import java.util.List; /** * A simple example policy to demonstrate a policy with a separate, refactored pass. @@ -31,32 +33,34 @@ * This example simply shows that users can define their own pass in their policy. */ public final class DefaultPolicyWithSeparatePass implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerCompileTimePass(new DefaultParallelismPass()) + .registerCompileTimePass(new RefactoredPass()); private final Policy policy; /** * Default constructor. 
*/ public DefaultPolicyWithSeparatePass() { - this.policy = new PolicyBuilder(true) - .registerCompileTimePass(new DefaultParallelismPass()) - .registerCompileTimePass(new RefactoredPass()) - .build(); + this.policy = BUILDER.build(); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } /** * A simple custom pass consisted of the two passes at the end of the default pass. */ - public final class RefactoredPass extends CompositePass { + public static final class RefactoredPass extends CompositePass { /** * Default constructor. 
*/ diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DisaggregationPolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DisaggregationPolicy.java index 8d69b2495..667b6705d 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DisaggregationPolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DisaggregationPolicy.java @@ -15,38 +15,41 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.*; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.*; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LoopOptimizationCompositePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A policy to demonstrate the disaggregation optimization, that uses GlusterFS as file storage. */ public final class DisaggregationPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(false) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new DefaultCompositePass()) + .registerCompileTimePass(new DisaggregationEdgeDataStorePass()); private final Policy policy; /** * Default constructor. 
*/ public DisaggregationPolicy() { - this.policy = new PolicyBuilder(false) - .registerCompileTimePass(new LoopOptimizationCompositePass()) - .registerCompileTimePass(new DefaultCompositePass()) - .registerCompileTimePass(new DisaggregationEdgeDataStorePass()) - .build(); + this.policy = BUILDER.build(); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/LargeShufflePolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/LargeShufflePolicy.java index 7f0115131..018c403ab 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/LargeShufflePolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/LargeShufflePolicy.java @@ -15,38 +15,41 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LoopOptimizationCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LargeShuffleCompositePass; -import 
edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A policy to demonstrate the large shuffle optimization, witch batches disk seek during data shuffle. */ public final class LargeShufflePolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(false) + .registerCompileTimePass(new LargeShuffleCompositePass()) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new DefaultCompositePass()); private final Policy policy; /** * Default constructor. */ public LargeShufflePolicy() { - this.policy = new PolicyBuilder(false) - .registerCompileTimePass(new LargeShuffleCompositePass()) - .registerCompileTimePass(new LoopOptimizationCompositePass()) - .registerCompileTimePass(new DefaultCompositePass()) - .build(); + this.policy = BUILDER.build(); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/Policy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/Policy.java index 6f789f998..65160121b 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/Policy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/Policy.java @@ -15,11 +15,13 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import 
edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import org.apache.reef.tang.Injector; import java.io.Serializable; -import java.util.List; /** * An interface for policies, each of which is composed of a list of static optimization passes. @@ -27,12 +29,18 @@ */ public interface Policy extends Serializable { /** - * @return the content of the policy: the list of static optimization passes of the policy. + * Optimize the DAG with the compile time optimizations. + * @param dag input DAG. + * @param dagDirectory directory to save the DAG information. + * @return optimized DAG, reshaped or tagged with execution properties. + * @throws Exception throws an exception if there is an exception. */ - List<CompileTimePass> getCompileTimePasses(); + DAG<IRVertex, IREdge> runCompileTimeOptimization(DAG<IRVertex, IREdge> dag, String dagDirectory) throws Exception; /** - * @return the content of the policy: the list of runtime passses of the policy. + * Register runtime optimizations to the event handler. + * @param injector Tang Injector, used in the UserApplicationRunner. + * @param pubSubWrapper pub-sub event handler, used in the UserApplicationRunner. 
*/ - List<RuntimePass<?>> getRuntimePasses(); + void registerRunTimeOptimizations(Injector injector, PubSubEventHandlerWrapper pubSubWrapper); } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilder.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilder.java index 9928991c5..bf939b93e 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilder.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilder.java @@ -15,12 +15,15 @@ */ package edu.snu.nemo.compiler.optimizer.policy; +import edu.snu.nemo.common.dag.DAG; import edu.snu.nemo.common.exception.CompileTimeOptimizationException; +import edu.snu.nemo.common.ir.edge.IREdge; import edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowProperty; import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty; import edu.snu.nemo.common.ir.edge.executionproperty.PartitionerProperty; import edu.snu.nemo.common.ir.executionproperty.ExecutionProperty; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.common.ir.vertex.executionproperty.ResourcePriorityProperty; import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty; import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; @@ -28,10 +31,8 @@ import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.CompositePass; import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; +import java.util.function.Predicate; /** * A builder for policies. @@ -52,7 +53,6 @@ public PolicyBuilder() { /** * Constructor. - * * @param strictPrerequisiteCheckMode whether to use strict prerequisite check mode or not. 
*/ public PolicyBuilder(final Boolean strictPrerequisiteCheckMode) { @@ -72,9 +72,9 @@ public PolicyBuilder(final Boolean strictPrerequisiteCheckMode) { } /** - * Register compile time pass. + * Register a compile time pass. * @param compileTimePass the compile time pass to register. - * @return the PolicyBuilder which registers compileTimePass. + * @return the PolicyBuilder which registers the compileTimePass. */ public PolicyBuilder registerCompileTimePass(final CompileTimePass compileTimePass) { // We decompose CompositePasses. @@ -108,33 +108,65 @@ public PolicyBuilder registerCompileTimePass(final CompileTimePass compileTimePa } /** - * Register run time passes. + * Register compile time pass with its condition under which to run the pass. + * @param compileTimePass the compile time pass to register. + * @param condition condition under which to run the pass. + * @return the PolicyBuilder which registers the compileTimePass. + */ + public PolicyBuilder registerCompileTimePass(final CompileTimePass compileTimePass, + final Predicate<DAG<IRVertex, IREdge>> condition) { + compileTimePass.addCondition(condition); + return this.registerCompileTimePass(compileTimePass); + } + + /** + * Register a run time pass. * @param runtimePass the runtime pass to register. - * @param runtimePassRegistrator the compile time pass that triggers the runtime pass. - * @return the PolicyBuilder which registers runtimePass and runtimePassRegistrator. + * @param runtimePassRegisterer the compile time pass that triggers the runtime pass. + * @return the PolicyBuilder which registers the runtimePass and the runtimePassRegisterer. 
*/ public PolicyBuilder registerRuntimePass(final RuntimePass<?> runtimePass, - final CompileTimePass runtimePassRegistrator) { - registerCompileTimePass(runtimePassRegistrator); + final CompileTimePass runtimePassRegisterer) { + registerCompileTimePass(runtimePassRegisterer); this.runtimePasses.add(runtimePass); return this; } + /** + * Register a run time pass. + * @param runtimePass the runtime pass to register. + * @param runtimePassRegisterer the compile time pass that triggers the runtime pass. + * @param condition condition under which to run the pass. + * @return the PolicyBuilder which registers the runtimePass and the runtimePassRegisterer. + */ + public PolicyBuilder registerRuntimePass(final RuntimePass<?> runtimePass, + final CompileTimePass runtimePassRegisterer, + final Predicate<DAG<IRVertex, IREdge>> condition) { + runtimePass.addCondition(condition); + return this.registerRuntimePass(runtimePass, runtimePassRegisterer); + } + + /** + * Getter for compile time passes. + * @return the list of compile time passes. + */ + public List<CompileTimePass> getCompileTimePasses() { + return compileTimePasses; + } + + /** + * Getter for run time passes. + * @return the list of run time passes. + */ + public List<RuntimePass<?>> getRuntimePasses() { + return runtimePasses; + } + /** * Build a policy using compileTimePasses and runtimePasses in this object. * @return the built Policy. 
*/ public Policy build() { - return new Policy() { - @Override - public List<CompileTimePass> getCompileTimePasses() { - return compileTimePasses; - } - - @Override - public List<RuntimePass<?>> getRuntimePasses() { - return runtimePasses; - } - }; + return new PolicyImpl(compileTimePasses, runtimePasses); } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/CompiletimeOptimizer.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImpl.java similarity index 64% rename from compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/CompiletimeOptimizer.java rename to compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImpl.java index 1c11f1ec2..99dc4c4dc 100644 --- a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/CompiletimeOptimizer.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImpl.java @@ -13,43 +13,49 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package edu.snu.nemo.compiler.optimizer; +package edu.snu.nemo.compiler.optimizer.policy; + +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.eventhandler.RuntimeEventHandler; import edu.snu.nemo.common.exception.CompileTimeOptimizationException; import edu.snu.nemo.common.ir.edge.IREdge; import edu.snu.nemo.common.ir.vertex.IRVertex; -import edu.snu.nemo.common.dag.DAG; import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.AnnotatingPass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.reshaping.ReshapingPass; -import edu.snu.nemo.compiler.optimizer.policy.Policy; +import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; +import org.apache.reef.tang.Injector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.util.*; +import java.util.Iterator; +import java.util.List; /** - * Compile time optimizer class. + * Implementation of the {@link Policy} interface. */ -public final class CompiletimeOptimizer { +public final class PolicyImpl implements Policy { + private final List<CompileTimePass> compileTimePasses; + private final List<RuntimePass<?>> runtimePasses; + private static final Logger LOG = LoggerFactory.getLogger(PolicyImpl.class.getName()); + /** - * Private constructor. + * Constructor. + * @param compileTimePasses compile time passes of the policy. + * @param runtimePasses run time passes of the policy. */ - private CompiletimeOptimizer() { + public PolicyImpl(final List<CompileTimePass> compileTimePasses, final List<RuntimePass<?>> runtimePasses) { + this.compileTimePasses = compileTimePasses; + this.runtimePasses = runtimePasses; } - /** - * Optimize function. - * @param dag input DAG. - * @param optimizationPolicy the optimization policy that we want to use to optimize the DAG. - * @param dagDirectory directory to save the DAG information. 
- * @return optimized DAG, tagged with execution properties. - * @throws Exception throws an exception if there is an exception. - */ - public static DAG<IRVertex, IREdge> optimize(final DAG<IRVertex, IREdge> dag, final Policy optimizationPolicy, - final String dagDirectory) throws Exception { - if (optimizationPolicy == null || optimizationPolicy.getCompileTimePasses().isEmpty()) { - throw new CompileTimeOptimizationException("A policy name should be specified."); - } - return process(dag, optimizationPolicy.getCompileTimePasses().iterator(), dagDirectory); + @Override + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + LOG.info("Launch Compile-time optimizations"); + return process(dag, compileTimePasses.iterator(), dagDirectory); } /** @@ -65,18 +71,26 @@ private CompiletimeOptimizer() { final String dagDirectory) throws Exception { if (passes.hasNext()) { final CompileTimePass passToApply = passes.next(); - // Apply the pass to the DAG. - final DAG<IRVertex, IREdge> processedDAG = passToApply.apply(dag); - // Ensure AnnotatingPass and ReshapingPass functions as intended. - if ((passToApply instanceof AnnotatingPass && !checkAnnotatingPass(dag, processedDAG)) - || (passToApply instanceof ReshapingPass && !checkReshapingPass(dag, processedDAG))) { - throw new CompileTimeOptimizationException(passToApply.getClass().getSimpleName() - + " is implemented in a way that doesn't follow its original intention of annotating or reshaping. " - + "Modify it or use a general CompileTimePass"); + final DAG<IRVertex, IREdge> processedDAG; + + if (passToApply.getCondition().test(dag)) { + LOG.info("Apply {} to the DAG", passToApply.getClass().getSimpleName()); + // Apply the pass to the DAG. + processedDAG = passToApply.apply(dag); + // Ensure AnnotatingPass and ReshapingPass functions as intended. 
+ if ((passToApply instanceof AnnotatingPass && !checkAnnotatingPass(dag, processedDAG)) + || (passToApply instanceof ReshapingPass && !checkReshapingPass(dag, processedDAG))) { + throw new CompileTimeOptimizationException(passToApply.getClass().getSimpleName() + + " is implemented in a way that doesn't follow its original intention of annotating or reshaping. " + + "Modify it or use a general CompileTimePass"); + } + // Save the processed JSON DAG. + processedDAG.storeJSON(dagDirectory, "ir-after-" + passToApply.getClass().getSimpleName(), + "DAG after optimization"); + } else { + LOG.info("Condition unmet for applying {} to the DAG", passToApply.getClass().getSimpleName()); + processedDAG = dag; } - // Save the processed JSON DAG. - processedDAG.storeJSON(dagDirectory, "ir-after-" + passToApply.getClass().getSimpleName(), - "DAG after optimization"); // recursively apply the following passes. return process(processedDAG, passes, dagDirectory); } else { @@ -156,4 +170,19 @@ private static Boolean checkReshapingPass(final DAG<IRVertex, IREdge> before, fi } return true; } + + @Override + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + LOG.info("Register run-time optimizations to the PubSubHandler"); + runtimePasses.forEach(runtimePass -> + runtimePass.getEventHandlerClasses().forEach(runtimeEventHandlerClass -> { + try { + final RuntimeEventHandler runtimeEventHandler = injector.getInstance(runtimeEventHandlerClass); + pubSubWrapper.getPubSubEventHandler() + .subscribe(runtimeEventHandler.getEventClass(), runtimeEventHandler); + } catch (final Exception e) { + throw new RuntimeException(e); + } + })); + } } diff --git a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TransientResourcePolicy.java b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TransientResourcePolicy.java index 19baadeb9..caa29d99e 100644 --- 
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TransientResourcePolicy.java +++ b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TransientResourcePolicy.java @@ -15,38 +15,41 @@ */ package edu.snu.nemo.compiler.optimizer.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.*; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.LoopOptimizationCompositePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.TransientResourceCompositePass; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A policy to perform optimization that uses transient resources in data centers. */ public final class TransientResourcePolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder(true) + .registerCompileTimePass(new TransientResourceCompositePass()) + .registerCompileTimePass(new LoopOptimizationCompositePass()) + .registerCompileTimePass(new DefaultCompositePass()); private final Policy policy; /** * Default constructor. 
*/ public TransientResourcePolicy() { - this.policy = new PolicyBuilder(true) - .registerCompileTimePass(new TransientResourceCompositePass()) - .registerCompileTimePass(new LoopOptimizationCompositePass()) - .registerCompileTimePass(new DefaultCompositePass()) - .build(); + this.policy = BUILDER.build(); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java index 4503fbebc..f75cb8c62 100644 --- a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java +++ b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java @@ -26,23 +26,20 @@ public final class PolicyBuilderTest { @Test public void testDisaggregationPolicy() { - final Policy disaggregationPolicy = new DisaggregationPolicy(); - assertEquals(17, disaggregationPolicy.getCompileTimePasses().size()); - assertEquals(0, disaggregationPolicy.getRuntimePasses().size()); + assertEquals(17, DisaggregationPolicy.BUILDER.getCompileTimePasses().size()); + assertEquals(0, DisaggregationPolicy.BUILDER.getRuntimePasses().size()); } @Test public void testTransientResourcePolicy() { - final Policy transientResourcePolicy = new TransientResourcePolicy(); - assertEquals(19, 
transientResourcePolicy.getCompileTimePasses().size()); - assertEquals(0, transientResourcePolicy.getRuntimePasses().size()); + assertEquals(19, TransientResourcePolicy.BUILDER.getCompileTimePasses().size()); + assertEquals(0, TransientResourcePolicy.BUILDER.getRuntimePasses().size()); } @Test public void testDataSkewPolicy() { - final Policy dataSkewPolicy = new DataSkewPolicy(); - assertEquals(21, dataSkewPolicy.getCompileTimePasses().size()); - assertEquals(1, dataSkewPolicy.getRuntimePasses().size()); + assertEquals(21, DataSkewPolicy.BUILDER.getCompileTimePasses().size()); + assertEquals(1, DataSkewPolicy.BUILDER.getRuntimePasses().size()); } @Test diff --git a/compiler/test/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TestPolicy.java b/compiler/test/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TestPolicy.java index aca479148..061d431aa 100644 --- a/compiler/test/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TestPolicy.java +++ b/compiler/test/src/main/java/edu/snu/nemo/compiler/optimizer/policy/TestPolicy.java @@ -15,9 +15,13 @@ */ package edu.snu.nemo.compiler.optimizer.policy; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.*; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; +import org.apache.reef.tang.Injector; import java.util.*; @@ -25,30 +29,32 @@ * A policy for tests. 
*/ public final class TestPolicy implements Policy { - private final boolean testPushPolicy; + private final Policy policy; public TestPolicy() { this(false); } public TestPolicy(final boolean testPushPolicy) { - this.testPushPolicy = testPushPolicy; - } - - @Override - public List<CompileTimePass> getCompileTimePasses() { - List<CompileTimePass> policy = new ArrayList<>(); + List<CompileTimePass> compileTimePasses = new ArrayList<>(); if (testPushPolicy) { - policy.add(new ShuffleEdgePushPass()); + compileTimePasses.add(new ShuffleEdgePushPass()); } - policy.add(new DefaultScheduleGroupPass()); - return policy; + compileTimePasses.add(new DefaultScheduleGroupPass()); + + this.policy = new PolicyImpl(compileTimePasses, new ArrayList<>()); + } + + @Override + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return new ArrayList<>(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java index 79721fe6d..b4237bfdd 100644 --- a/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java +++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/backend/nemo/DAGConverterTest.java @@ -28,7 +28,6 @@ import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty; import edu.snu.nemo.common.ir.vertex.transform.Transform; import edu.snu.nemo.compiler.frontend.beam.transform.DoTransform; -import edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import edu.snu.nemo.common.test.EmptyComponents; import 
edu.snu.nemo.conf.JobConf; import edu.snu.nemo.runtime.common.plan.PhysicalPlanGenerator; @@ -74,8 +73,8 @@ public void testSimplePlan() throws Exception { final IREdge e = new IREdge(CommunicationPatternProperty.Value.Shuffle, v1, v2); irDAGBuilder.connectVertices(e); - final DAG<IRVertex, IREdge> irDAG = CompiletimeOptimizer.optimize(irDAGBuilder.buildWithoutSourceSinkCheck(), - new TestPolicy(), ""); + final DAG<IRVertex, IREdge> irDAG = new TestPolicy().runCompileTimeOptimization( + irDAGBuilder.buildWithoutSourceSinkCheck(), DAG.EMPTY_DAG_DIRECTORY); final DAG<Stage, StageEdge> DAGOfStages = physicalPlanGenerator.stagePartitionIrDAG(irDAG); final DAG<Stage, StageEdge> physicalDAG = irDAG.convert(physicalPlanGenerator); @@ -208,8 +207,8 @@ public void testComplexPlan() throws Exception { // Stage 5 = {v6} irDAGBuilder.connectVertices(e5); - final DAG<IRVertex, IREdge> irDAG = CompiletimeOptimizer.optimize(irDAGBuilder.build(), - new TestPolicy(), ""); + final DAG<IRVertex, IREdge> irDAG = new TestPolicy().runCompileTimeOptimization(irDAGBuilder.build(), + DAG.EMPTY_DAG_DIRECTORY); final DAG<Stage, StageEdge> logicalDAG = physicalPlanGenerator.stagePartitionIrDAG(irDAG); // Test Logical DAG diff --git a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPassTest.java b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPassTest.java index 64d0bce1d..bf348e04b 100644 --- a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPassTest.java +++ b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/DefaultScheduleGroupPassTest.java @@ -26,7 +26,6 @@ import edu.snu.nemo.common.ir.vertex.OperatorVertex; import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty; import edu.snu.nemo.compiler.CompilerTestUtil; -import 
edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import edu.snu.nemo.compiler.optimizer.policy.TestPolicy; import org.junit.Test; import org.junit.runner.RunWith; @@ -58,8 +57,8 @@ public void testAnnotatingPass() { @Test public void testTopologicalOrdering() throws Exception { final DAG<IRVertex, IREdge> compiledDAG = CompilerTestUtil.compileALSDAG(); - final DAG<IRVertex, IREdge> processedDAG = CompiletimeOptimizer.optimize(compiledDAG, - new TestPolicy(), ""); + final DAG<IRVertex, IREdge> processedDAG = new TestPolicy().runCompileTimeOptimization(compiledDAG, + DAG.EMPTY_DAG_DIRECTORY); for (final IRVertex irVertex : processedDAG.getTopologicalSort()) { final Integer currentScheduleGroup = irVertex.getPropertyValue(ScheduleGroupProperty.class).get(); diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java index 36e30e06a..7d291a923 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/WordCountITCase.java @@ -18,6 +18,7 @@ import edu.snu.nemo.client.JobLauncher; import edu.snu.nemo.common.test.ArgBuilder; import edu.snu.nemo.common.test.ExampleTestUtil; +import edu.snu.nemo.compiler.optimizer.policy.ConditionalLargeShufflePolicy; import edu.snu.nemo.examples.beam.policy.*; import org.junit.After; import org.junit.Before; @@ -73,7 +74,7 @@ public void test() throws Exception { public void testLargeShuffle() throws Exception { JobLauncher.main(builder .addResourceJson(executorResourceFileName) - .addJobId(WordCountITCase.class.getSimpleName() + "_largeshuffle") + .addJobId(WordCountITCase.class.getSimpleName() + "_largeShuffle") .addOptimizationPolicy(LargeShufflePolicyParallelismFive.class.getCanonicalName()) .build()); } @@ -87,6 +88,15 @@ public void testLargeShuffleInOneExecutor() throws Exception { .build()); } + @Test (timeout = TIMEOUT) + public 
void testConditionalLargeShuffle() throws Exception { + JobLauncher.main(builder + .addResourceJson(executorResourceFileName) + .addJobId(WordCountITCase.class.getSimpleName() + "_conditionalLargeShuffle") + .addOptimizationPolicy(ConditionalLargeShufflePolicy.class.getCanonicalName()) + .build()); + } + @Test (timeout = TIMEOUT) public void testTransientResource() throws Exception { JobLauncher.main(builder diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DataSkewPolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DataSkewPolicyParallelismFive.java index eed73f42e..e7fa7bd1b 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DataSkewPolicyParallelismFive.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DataSkewPolicyParallelismFive.java @@ -15,12 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.policy.DataSkewPolicy; import edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; +import org.apache.reef.tang.Injector; /** * A data-skew policy with fixed parallelism 5 for tests. 
@@ -29,16 +31,19 @@ private final Policy policy; public DataSkewPolicyParallelismFive() { - this.policy = PolicyTestUtil.overwriteParallelism(5, DataSkewPolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(5, DataSkewPolicy.BUILDER.getCompileTimePasses()), + DataSkewPolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DefaultPolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DefaultPolicyParallelismFive.java index 4db1611e9..0287e1ce9 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DefaultPolicyParallelismFive.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DefaultPolicyParallelismFive.java @@ -15,12 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.policy.DefaultPolicy; import edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import 
edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; +import org.apache.reef.tang.Injector; /** * A default policy with fixed parallelism 5 for tests. @@ -29,16 +31,19 @@ private final Policy policy; public DefaultPolicyParallelismFive() { - this.policy = PolicyTestUtil.overwriteParallelism(5, DefaultPolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(5, DefaultPolicy.BUILDER.getCompileTimePasses()), + DefaultPolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DisaggregationPolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DisaggregationPolicyParallelismFive.java index 1e3c9810c..718ea9a67 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DisaggregationPolicyParallelismFive.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/DisaggregationPolicyParallelismFive.java @@ -15,13 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; - -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.policy.DisaggregationPolicy; import 
edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; +import org.apache.reef.tang.Injector; /** * A disaggregation policy with fixed parallelism 5 for tests. @@ -30,16 +31,20 @@ private final Policy policy; public DisaggregationPolicyParallelismFive() { - this.policy = PolicyTestUtil.overwriteParallelism(5, DisaggregationPolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(5, + DisaggregationPolicy.BUILDER.getCompileTimePasses()), + DisaggregationPolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/LargeShufflePolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/LargeShufflePolicyParallelismFive.java index b4c6339a4..1c949eee9 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/LargeShufflePolicyParallelismFive.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/LargeShufflePolicyParallelismFive.java @@ -15,12 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; -import edu.snu.nemo.compiler.optimizer.policy.Policy; +import edu.snu.nemo.common.dag.DAG; 
+import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.optimizer.policy.LargeShufflePolicy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import edu.snu.nemo.compiler.optimizer.policy.Policy; +import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; +import org.apache.reef.tang.Injector; /** * A large shuffle policy with fixed parallelism 5 for tests. @@ -29,17 +31,19 @@ private final Policy policy; public LargeShufflePolicyParallelismFive() { - this.policy = - PolicyTestUtil.overwriteParallelism(5, LargeShufflePolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(5, LargeShufflePolicy.BUILDER.getCompileTimePasses()), + LargeShufflePolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/PolicyTestUtil.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/PolicyTestUtil.java index 24ec98dbf..cb60d7eac 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/PolicyTestUtil.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/PolicyTestUtil.java @@ -17,8 +17,6 @@ import 
edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; import edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.DefaultParallelismPass; -import edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; import java.util.List; @@ -30,32 +28,14 @@ * Overwrite the parallelism of existing policy. * * @param desiredSourceParallelism the desired source parallelism to set. - * @param policyToOverwriteCanonicalName the name of the policy to overwrite parallelism. + * @param compileTimePassesToOverwrite the list of compile time passes to overwrite. * @return the overwritten policy. */ - public static Policy overwriteParallelism(final int desiredSourceParallelism, - final String policyToOverwriteCanonicalName) { - final Policy policyToOverwrite; - try { - policyToOverwrite = (Policy) Class.forName(policyToOverwriteCanonicalName).newInstance(); - } catch (final ClassNotFoundException | InstantiationException | IllegalAccessException e) { - throw new RuntimeException(e); - } - final List<CompileTimePass> compileTimePasses = policyToOverwrite.getCompileTimePasses(); - final int parallelismPassIdx = compileTimePasses.indexOf(new DefaultParallelismPass()); - compileTimePasses.set(parallelismPassIdx, new DefaultParallelismPass(desiredSourceParallelism, 2)); - final List<RuntimePass<?>> runtimePasses = policyToOverwrite.getRuntimePasses(); - - return new Policy() { - @Override - public List<CompileTimePass> getCompileTimePasses() { - return compileTimePasses; - } - - @Override - public List<RuntimePass<?>> getRuntimePasses() { - return runtimePasses; - } - }; + public static List<CompileTimePass> overwriteParallelism(final int desiredSourceParallelism, + final List<CompileTimePass> compileTimePassesToOverwrite) { + final int parallelismPassIdx = compileTimePassesToOverwrite.indexOf(new DefaultParallelismPass()); + compileTimePassesToOverwrite.set(parallelismPassIdx, + new 
DefaultParallelismPass(desiredSourceParallelism, 2)); + return compileTimePassesToOverwrite; } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismFive.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismFive.java index 676091712..fbf95d08b 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismFive.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismFive.java @@ -15,12 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; import edu.snu.nemo.compiler.optimizer.policy.TransientResourcePolicy; import edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A transient resource policy with fixed parallelism 5 for tests. 
@@ -29,17 +31,20 @@ private final Policy policy; public TransientResourcePolicyParallelismFive() { - this.policy = - PolicyTestUtil.overwriteParallelism(5, TransientResourcePolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(5, + TransientResourcePolicy.BUILDER.getCompileTimePasses()), + TransientResourcePolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismTen.java b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismTen.java index 4227b8045..fa617223d 100644 --- a/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismTen.java +++ b/examples/beam/src/test/java/edu/snu/nemo/examples/beam/policy/TransientResourcePolicyParallelismTen.java @@ -15,12 +15,14 @@ */ package edu.snu.nemo.examples.beam.policy; -import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass; +import edu.snu.nemo.common.dag.DAG; +import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; +import edu.snu.nemo.common.ir.edge.IREdge; +import edu.snu.nemo.common.ir.vertex.IRVertex; +import edu.snu.nemo.compiler.optimizer.policy.PolicyImpl; import edu.snu.nemo.compiler.optimizer.policy.TransientResourcePolicy; import 
edu.snu.nemo.compiler.optimizer.policy.Policy; -import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass; - -import java.util.List; +import org.apache.reef.tang.Injector; /** * A transient resource policy with fixed parallelism 10 for tests. @@ -29,17 +31,20 @@ private final Policy policy; public TransientResourcePolicyParallelismTen() { - this.policy = - PolicyTestUtil.overwriteParallelism(10, TransientResourcePolicy.class.getCanonicalName()); + this.policy = new PolicyImpl( + PolicyTestUtil.overwriteParallelism(10, + TransientResourcePolicy.BUILDER.getCompileTimePasses()), + TransientResourcePolicy.BUILDER.getRuntimePasses()); } @Override - public List<CompileTimePass> getCompileTimePasses() { - return this.policy.getCompileTimePasses(); + public DAG<IRVertex, IREdge> runCompileTimeOptimization(final DAG<IRVertex, IREdge> dag, final String dagDirectory) + throws Exception { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); } @Override - public List<RuntimePass<?>> getRuntimePasses() { - return this.policy.getRuntimePasses(); + public void registerRunTimeOptimizations(final Injector injector, final PubSubEventHandlerWrapper pubSubWrapper) { + this.policy.registerRunTimeOptimizations(injector, pubSubWrapper); } } diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java index 63cd56eda..f4f7d2236 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/eventhandler/DynamicOptimizationEventHandler.java @@ -22,7 +22,7 @@ import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; import edu.snu.nemo.common.eventhandler.RuntimeEventHandler; import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex; -import 
edu.snu.nemo.runtime.common.optimizer.RuntimeOptimizer; +import edu.snu.nemo.runtime.common.optimizer.RunTimeOptimizer; import edu.snu.nemo.runtime.common.plan.PhysicalPlan; import org.apache.reef.wake.impl.PubSubEventHandler; @@ -56,7 +56,7 @@ public void onNext(final DynamicOptimizationEvent dynamicOptimizationEvent) { final Pair<String, String> taskInfo = dynamicOptimizationEvent.getTaskInfo(); - final PhysicalPlan newPlan = RuntimeOptimizer.dynamicOptimization(physicalPlan, + final PhysicalPlan newPlan = RunTimeOptimizer.dynamicOptimization(physicalPlan, metricCollectionBarrierVertex); pubSubEventHandler.onNext(new UpdatePhysicalPlanEvent(newPlan, taskInfo)); diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java similarity index 97% rename from runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java rename to runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java index 888e6f58c..e251c983c 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RuntimeOptimizer.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java @@ -26,11 +26,11 @@ /** * Runtime optimizer class. */ -public final class RuntimeOptimizer { +public final class RunTimeOptimizer { /** * Private constructor. 
*/ - private RuntimeOptimizer() { + private RunTimeOptimizer() { } /** diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java index cf8f5b6e8..3218317dc 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java @@ -41,11 +41,11 @@ * this RuntimePass identifies a number of keys with big partition sizes(skewed key) * and evenly redistributes data via overwriting incoming edges of destination tasks. */ -public final class DataSkewRuntimePass implements RuntimePass<Pair<List<String>, Map<Integer, Long>>> { +public final class DataSkewRuntimePass extends RuntimePass<Pair<List<String>, Map<Integer, Long>>> { private static final Logger LOG = LoggerFactory.getLogger(DataSkewRuntimePass.class.getName()); private final Set<Class<? extends RuntimeEventHandler>> eventHandlers; // Skewed keys denote for top n keys in terms of partition size. 
- private static final int DEFAULT_NUM_SKEWED_KEYS = 3; + public static final int DEFAULT_NUM_SKEWED_KEYS = 3; private int numSkewedKeys = DEFAULT_NUM_SKEWED_KEYS; /** diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/RuntimePass.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/RuntimePass.java index cfa38a6d0..249a239b3 100644 --- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/RuntimePass.java +++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/RuntimePass.java @@ -16,21 +16,21 @@ package edu.snu.nemo.runtime.common.optimizer.pass.runtime; import edu.snu.nemo.common.eventhandler.RuntimeEventHandler; +import edu.snu.nemo.common.pass.Pass; import edu.snu.nemo.runtime.common.plan.PhysicalPlan; -import java.io.Serializable; import java.util.Set; import java.util.function.BiFunction; /** - * Interface for dynamic optimization passes, for dynamically optimizing a physical plan. + * Abstract class for dynamic optimization passes, for dynamically optimizing a physical plan. * It is a BiFunction that takes an original physical plan and metric data, to produce a new physical plan * after dynamic optimization. * @param <T> type of the metric data used for dynamic optimization. */ -public interface RuntimePass<T> extends BiFunction<PhysicalPlan, T, PhysicalPlan>, Serializable { +public abstract class RuntimePass<T> extends Pass implements BiFunction<PhysicalPlan, T, PhysicalPlan> { /** * @return the set of event handlers used with the runtime pass. */ - Set<Class<? extends RuntimeEventHandler>> getEventHandlerClasses(); + public abstract Set<Class<? 
extends RuntimeEventHandler>> getEventHandlerClasses(); } diff --git a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java index aad818527..52ea468a4 100644 --- a/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java +++ b/runtime/driver/src/main/java/edu/snu/nemo/driver/UserApplicationRunner.java @@ -18,12 +18,11 @@ import edu.snu.nemo.common.Pair; import edu.snu.nemo.common.dag.DAG; import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper; -import edu.snu.nemo.common.eventhandler.RuntimeEventHandler; +import edu.snu.nemo.common.exception.CompileTimeOptimizationException; import edu.snu.nemo.common.ir.edge.IREdge; import edu.snu.nemo.common.ir.vertex.IRVertex; import edu.snu.nemo.compiler.backend.Backend; import edu.snu.nemo.compiler.backend.nemo.NemoBackend; -import edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import edu.snu.nemo.compiler.optimizer.policy.Policy; import edu.snu.nemo.conf.JobConf; import edu.snu.nemo.runtime.common.plan.PhysicalPlan; @@ -87,20 +86,14 @@ public void run(final String dagString) { dag.storeJSON(dagDirectory, "ir", "IR before optimization"); final Policy optimizationPolicy = (Policy) Class.forName(optimizationPolicyCanonicalName).newInstance(); - final DAG<IRVertex, IREdge> optimizedDAG = CompiletimeOptimizer.optimize(dag, optimizationPolicy, dagDirectory); + if (optimizationPolicy == null) { + throw new CompileTimeOptimizationException("A policy name should be specified."); + } + final DAG<IRVertex, IREdge> optimizedDAG = optimizationPolicy.runCompileTimeOptimization(dag, dagDirectory); optimizedDAG.storeJSON(dagDirectory, "ir-" + optimizationPolicy.getClass().getSimpleName(), "IR optimized for " + optimizationPolicy.getClass().getSimpleName()); - optimizationPolicy.getRuntimePasses().forEach(runtimePass -> - runtimePass.getEventHandlerClasses().forEach(runtimeEventHandlerClass -> { - try { - 
final RuntimeEventHandler runtimeEventHandler = injector.getInstance(runtimeEventHandlerClass); - pubSubWrapper.getPubSubEventHandler() - .subscribe(runtimeEventHandler.getEventClass(), runtimeEventHandler); - } catch (final Exception e) { - throw new RuntimeException(e); - } - })); + optimizationPolicy.registerRunTimeOptimizations(injector, pubSubWrapper); final PhysicalPlan physicalPlan = backend.compile(optimizedDAG); diff --git a/runtime/test/src/main/java/edu/snu/nemo/runtime/common/plan/TestPlanGenerator.java b/runtime/test/src/main/java/edu/snu/nemo/runtime/common/plan/TestPlanGenerator.java index 93c7ad515..5742238ac 100644 --- a/runtime/test/src/main/java/edu/snu/nemo/runtime/common/plan/TestPlanGenerator.java +++ b/runtime/test/src/main/java/edu/snu/nemo/runtime/common/plan/TestPlanGenerator.java @@ -25,7 +25,6 @@ import edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty; import edu.snu.nemo.common.ir.vertex.transform.Transform; import edu.snu.nemo.common.test.EmptyComponents; -import edu.snu.nemo.compiler.optimizer.CompiletimeOptimizer; import edu.snu.nemo.compiler.optimizer.policy.BasicPullPolicy; import edu.snu.nemo.compiler.optimizer.policy.BasicPushPolicy; import edu.snu.nemo.compiler.optimizer.policy.Policy; @@ -94,7 +93,7 @@ public static PhysicalPlan generatePhysicalPlan(final PlanType planType, final b */ private static PhysicalPlan convertIRToPhysical(final DAG<IRVertex, IREdge> irDAG, final Policy policy) throws Exception { - final DAG<IRVertex, IREdge> optimized = CompiletimeOptimizer.optimize(irDAG, policy, EMPTY_DAG_DIRECTORY); + final DAG<IRVertex, IREdge> optimized = policy.runCompileTimeOptimization(irDAG, EMPTY_DAG_DIRECTORY); final DAG<Stage, StageEdge> physicalDAG = optimized.convert(PLAN_GENERATOR); return new PhysicalPlan("TestPlan", physicalDAG); } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
