This is an automated email from the ASF dual-hosted git repository.
sanha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-nemo.git
The following commit(s) were added to refs/heads/master by this push:
new 3dcff47 [NEMO-96] Modularize DataSkewPolicy to use MetricVertex and
BarrierVertex (#115)
3dcff47 is described below
commit 3dcff47351364b136c0dc33e07d1c557edeee0e7
Author: Jeongyoon Eo <[email protected]>
AuthorDate: Wed Sep 5 14:02:22 2018 +0900
[NEMO-96] Modularize DataSkewPolicy to use MetricVertex and BarrierVertex
(#115)
JIRA: [NEMO-96: Modularize DataSkewPolicy to use MetricVertex and
BarrierVertex](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-96)
[NEMO-98: Implement MetricVertex that collect metric used for dynamic
optimization](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-98)
[NEMO-99: Implement AggregationBarrierVertex for dynamic
optimization](https://issues.apache.org/jira/projects/NEMO/issues/NEMO-99)
**Major changes:**
- Handle dynamic optimization via `MetricCollectionVertex` and
`AggregationBarrierVertex` instead of `MetricCollectionBarrierVertex`
- For each shuffle edge with main output, a `MetricCollectionVertex` is
inserted at compile time at the end of the edge's source tasks to collect key
frequency data
- For each shuffle edge with main output, an `AggregationBarrierVertex` is
inserted at compile time. It aggregates the task-level key frequency data that
each `MetricCollectionVertex` collects and emits as an additional tagged
output
**Minor changes to note:**
- Added encoder/decoder factories needed for aggregating dynamic
optimization data (here, key frequency data)
- Modified `PipelineTranslator` to extract key encoders/decoders
- Modified `DataSkewRuntimePass` and the related code path to handle keys of
type `Object` instead of integer hash-index keys
**Tests for the changes:**
- N/A (unit tests for skew handling and `PerKeyMedianITCase` test the
changes)
**Other comments:**
- N/A
Closes #115
---
bin/json2dot.py | 4 +-
common/pom.xml | 6 +
.../java/edu/snu/nemo/common/KeyExtractor.java | 1 +
.../edu/snu/nemo/common/coder/DecoderFactory.java | 4 +-
.../edu/snu/nemo/common/coder/EncoderFactory.java | 2 +-
.../snu/nemo/common/coder/LongDecoderFactory.java | 71 +++++++++
.../snu/nemo/common/coder/LongEncoderFactory.java | 72 +++++++++
.../java/edu/snu/nemo/common/dag/DAGBuilder.java | 8 +-
.../edge/executionproperty/KeyDecoderProperty.java | 43 ++++++
.../edge/executionproperty/KeyEncoderProperty.java | 43 ++++++
.../ir/vertex/MetricCollectionBarrierVertex.java | 128 ----------------
.../vertex/transform/AggregateMetricTransform.java | 69 +++++++++
.../vertex/transform/MetricCollectTransform.java | 73 +++++++++
.../edu/snu/nemo/common/test/EmptyComponents.java | 66 ++++++++
.../compiler/frontend/beam/PipelineTranslator.java | 12 +-
.../frontend/beam/coder/BeamDecoderFactory.java | 6 +-
.../frontend/spark/core/rdd/PairRDDFunctions.scala | 6 +-
.../compiler/frontend/spark/core/rdd/RDD.scala | 1 -
.../compiletime/annotating/SkewDataStorePass.java | 58 -------
.../annotating/SkewMetricCollectionPass.java | 16 +-
.../annotating/SkewPartitionerPass.java | 10 +-
.../annotating/SkewResourceSkewedDataPass.java | 24 +--
.../compiletime/composite/SkewCompositePass.java | 9 +-
.../compiletime/reshaping/SkewReshapingPass.java | 167 +++++++++++++++++----
.../compiler/optimizer/policy/DataSkewPolicy.java | 9 +-
.../optimizer/policy/PolicyBuilderTest.java | 2 +-
.../compiler/optimizer/policy/PolicyImplTest.java | 9 +-
.../composite/SkewCompositePassTest.java | 47 +++---
.../runtime/common/optimizer/RunTimeOptimizer.java | 2 +-
.../pass/runtime/DataSkewRuntimePass.java | 61 ++++----
.../runtime/common/plan/PhysicalPlanGenerator.java | 3 +-
.../nemo/runtime/common/plan/StagePartitioner.java | 2 +-
runtime/common/src/main/proto/ControlMessage.proto | 3 +-
.../pass/runtime/DataSkewRuntimePassTest.java | 2 +-
.../runtime/executor/data/BlockManagerWorker.java | 26 ----
.../runtime/executor/data/block/FileBlock.java | 4 +-
.../data/partition/SerializedPartition.java | 4 +
.../executor/datatransfer/OutputCollectorImpl.java | 63 +++++---
.../executor/datatransfer/OutputWriter.java | 8 +-
.../nemo/runtime/executor/task/TaskExecutor.java | 134 ++++++++++++-----
.../nemo/runtime/executor/task/VertexHarness.java | 20 ++-
.../executor/datatransfer/DataTransferTest.java | 13 +-
.../runtime/master/DataSkewDynOptDataHandler.java | 10 +-
.../edu/snu/nemo/runtime/master/RuntimeMaster.java | 1 -
.../runtime/master/scheduler/BatchScheduler.java | 21 ++-
45 files changed, 905 insertions(+), 438 deletions(-)
diff --git a/bin/json2dot.py b/bin/json2dot.py
index f41146b..f3caf7d 100755
--- a/bin/json2dot.py
+++ b/bin/json2dot.py
@@ -157,9 +157,9 @@ class NormalVertex:
label += '<BR/>{}:{}'.format(transform_name, class_name)
except:
pass
- if ('class' in self.properties and self.properties['class'] ==
'MetricCollectionBarrierVertex'):
+ if ('class' in self.properties and self.properties['class'] ==
'AggregationBarrierVertex'):
shape = ', shape=box'
- label += '<BR/>MetricCollectionBarrier'
+ label += '<BR/>AggregationBarrier'
else:
shape = ''
try:
diff --git a/common/pom.xml b/common/pom.xml
index da5a48c..18ef10b 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -52,5 +52,11 @@ limitations under the License.
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-sdks-java-core</artifactId>
+ <version>${beam.version}</version>
+ </dependency>
+
</dependencies>
</project>
diff --git a/common/src/main/java/edu/snu/nemo/common/KeyExtractor.java
b/common/src/main/java/edu/snu/nemo/common/KeyExtractor.java
index be7cf59..23bc6ab 100644
--- a/common/src/main/java/edu/snu/nemo/common/KeyExtractor.java
+++ b/common/src/main/java/edu/snu/nemo/common/KeyExtractor.java
@@ -24,6 +24,7 @@ import java.io.Serializable;
public interface KeyExtractor extends Serializable {
/**
* Extracts key.
+ *
* @param element Element to get the key from.
* @return The extracted key of the element.
*/
diff --git a/common/src/main/java/edu/snu/nemo/common/coder/DecoderFactory.java
b/common/src/main/java/edu/snu/nemo/common/coder/DecoderFactory.java
index dc67ff3..16fa877 100644
--- a/common/src/main/java/edu/snu/nemo/common/coder/DecoderFactory.java
+++ b/common/src/main/java/edu/snu/nemo/common/coder/DecoderFactory.java
@@ -20,8 +20,8 @@ import java.io.InputStream;
import java.io.Serializable;
/**
- * A decoder factory object which generates decoders that decode values of
type {@code T} into byte streams.
- * To avoid to generate instance-based coder such as Spark serializer for
every decoding,
+ * A decoder factory object which generates decoders that decode byte streams
into values of type {@code T}.
+ * To avoid generating instance-based coder such as Spark serializer for every
decoding,
* user need to instantiate a decoder instance and use it.
*
* @param <T> element type.
diff --git a/common/src/main/java/edu/snu/nemo/common/coder/EncoderFactory.java
b/common/src/main/java/edu/snu/nemo/common/coder/EncoderFactory.java
index d63fafb..82c3730 100644
--- a/common/src/main/java/edu/snu/nemo/common/coder/EncoderFactory.java
+++ b/common/src/main/java/edu/snu/nemo/common/coder/EncoderFactory.java
@@ -46,7 +46,7 @@ public interface EncoderFactory<T> extends Serializable {
/**
* Encodes the given value onto the specified output stream.
- * It have to be able to encode the given stream consequently by calling
this method repeatedly.
+ * It has to be able to encode the given stream consequently by calling
this method repeatedly.
* Because the user can want to keep a single output stream and
continuously concatenate elements,
* the output stream should not be closed.
*
diff --git
a/common/src/main/java/edu/snu/nemo/common/coder/LongDecoderFactory.java
b/common/src/main/java/edu/snu/nemo/common/coder/LongDecoderFactory.java
new file mode 100644
index 0000000..4335413
--- /dev/null
+++ b/common/src/main/java/edu/snu/nemo/common/coder/LongDecoderFactory.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.coder;
+
+import java.io.DataInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * A {@link DecoderFactory} which is used for long.
+ */
+public final class LongDecoderFactory implements DecoderFactory<Long> {
+
+ private static final LongDecoderFactory LONG_DECODER_FACTORY = new
LongDecoderFactory();
+
+ /**
+ * A private constructor.
+ */
+ private LongDecoderFactory() {
+ // do nothing.
+ }
+
+ /**
+ * Static initializer of the coder.
+ */
+ public static LongDecoderFactory of() {
+ return LONG_DECODER_FACTORY;
+ }
+
+ @Override
+ public Decoder<Long> create(final InputStream inputStream) {
+ return new LongDecoder(inputStream);
+ }
+
+ /**
+ * LongDecoder.
+ */
+ private final class LongDecoder implements Decoder<Long> {
+ private final DataInputStream inputStream;
+
+ /**
+ * Constructor.
+ *
+ * @param inputStream the input stream to decode.
+ */
+ private LongDecoder(final InputStream inputStream) {
+ // If the inputStream is closed well in upper level, it is okay to not
close this stream
+ // because the DataInputStream itself will not contain any extra
information.
+ // (when we close this stream, the input will be closed together.)
+ this.inputStream = new DataInputStream(inputStream);
+ }
+
+ @Override
+ public Long decode() throws IOException {
+ return inputStream.readLong();
+ }
+ }
+}
diff --git
a/common/src/main/java/edu/snu/nemo/common/coder/LongEncoderFactory.java
b/common/src/main/java/edu/snu/nemo/common/coder/LongEncoderFactory.java
new file mode 100644
index 0000000..f6e5aaa
--- /dev/null
+++ b/common/src/main/java/edu/snu/nemo/common/coder/LongEncoderFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.coder;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * A {@link EncoderFactory} which is used for long.
+ */
+public final class LongEncoderFactory implements EncoderFactory<Long> {
+
+ private static final LongEncoderFactory LONG_ENCODER_FACTORY = new
LongEncoderFactory();
+
+ /**
+ * A private constructor.
+ */
+ private LongEncoderFactory() {
+ // do nothing.
+ }
+
+ /**
+ * Static initializer of the coder.
+ */
+ public static LongEncoderFactory of() {
+ return LONG_ENCODER_FACTORY;
+ }
+
+ @Override
+ public Encoder<Long> create(final OutputStream outputStream) {
+ return new LongEncoder(outputStream);
+ }
+
+ /**
+ * LongEncoder.
+ */
+ private final class LongEncoder implements Encoder<Long> {
+
+ private final DataOutputStream outputStream;
+
+ /**
+ * Constructor.
+ *
+ * @param outputStream the output stream to store the encoded bytes.
+ */
+ private LongEncoder(final OutputStream outputStream) {
+ // If the outputStream is closed well in upper level, it is okay to not
close this stream
+ // because the DataOutputStream itself will not contain any extra
information.
+ // (when we close this stream, the output will be closed together.)
+ this.outputStream = new DataOutputStream(outputStream);
+ }
+
+ @Override
+ public void encode(final Long value) throws IOException {
+ outputStream.writeLong(value);
+ }
+ }
+}
diff --git a/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
b/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
index f9b266e..442b7dd 100644
--- a/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
+++ b/common/src/main/java/edu/snu/nemo/common/dag/DAGBuilder.java
@@ -20,10 +20,7 @@ import edu.snu.nemo.common.ir.edge.IREdge;
import
edu.snu.nemo.common.ir.edge.executionproperty.BroadcastVariableIdProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DataFlowProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.OperatorVertex;
-import edu.snu.nemo.common.ir.vertex.SourceVertex;
-import edu.snu.nemo.common.ir.vertex.LoopVertex;
+import edu.snu.nemo.common.ir.vertex.*;
import edu.snu.nemo.common.exception.IllegalVertexOperationException;
import java.io.Serializable;
@@ -245,7 +242,8 @@ public final class DAGBuilder<V extends Vertex, E extends
Edge<V>> implements Se
.filter(v -> outgoingEdges.get(v).isEmpty())
.filter(v -> v instanceof IRVertex);
// They should either be OperatorVertex or LoopVertex
- if (verticesToObserve.get().anyMatch(v -> !(v instanceof OperatorVertex ||
v instanceof LoopVertex))) {
+ if (verticesToObserve.get().anyMatch(v ->
+ !(v instanceof OperatorVertex || v instanceof LoopVertex))) {
final String problematicVertices = verticesToObserve.get().filter(v ->
!(v instanceof OperatorVertex || v instanceof LoopVertex))
.map(V::getId).collect(Collectors.toList()).toString();
diff --git
a/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyDecoderProperty.java
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyDecoderProperty.java
new file mode 100644
index 0000000..0a567ec
--- /dev/null
+++
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyDecoderProperty.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.edge.executionproperty;
+
+import edu.snu.nemo.common.coder.DecoderFactory;
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+
+/**
+ * KeyDecoder ExecutionProperty.
+ */
+public final class KeyDecoderProperty extends
EdgeExecutionProperty<DecoderFactory> {
+ /**
+ * Constructor.
+ *
+ * @param value value of the execution property.
+ */
+ private KeyDecoderProperty(final DecoderFactory value) {
+ super(value);
+ }
+
+ /**
+ * Static method exposing the constructor.
+ *
+ * @param value value of the new execution property.
+ * @return the newly created execution property.
+ */
+ public static KeyDecoderProperty of(final DecoderFactory value) {
+ return new KeyDecoderProperty(value);
+ }
+}
diff --git
a/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyEncoderProperty.java
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyEncoderProperty.java
new file mode 100644
index 0000000..c767484
--- /dev/null
+++
b/common/src/main/java/edu/snu/nemo/common/ir/edge/executionproperty/KeyEncoderProperty.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.edge.executionproperty;
+
+import edu.snu.nemo.common.coder.EncoderFactory;
+import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty;
+
+/**
+ * KeyEncoder ExecutionProperty.
+ */
+public final class KeyEncoderProperty extends
EdgeExecutionProperty<EncoderFactory> {
+ /**
+ * Constructor.
+ *
+ * @param value value of the execution property.
+ */
+ private KeyEncoderProperty(final EncoderFactory value) {
+ super(value);
+ }
+
+ /**
+ * Static method exposing the constructor.
+ *
+ * @param value value of the new execution property.
+ * @return the newly created execution property.
+ */
+ public static KeyEncoderProperty of(final EncoderFactory value) {
+ return new KeyEncoderProperty(value);
+ }
+}
diff --git
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/MetricCollectionBarrierVertex.java
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/MetricCollectionBarrierVertex.java
deleted file mode 100644
index a4a03bf..0000000
---
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/MetricCollectionBarrierVertex.java
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Copyright (C) 2018 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.common.ir.vertex;
-
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.exception.DynamicOptimizationException;
-import edu.snu.nemo.common.ir.edge.IREdge;
-
-import java.util.*;
-
-/**
- * IRVertex that collects statistics to send them to the optimizer for dynamic
optimization.
- * This class is generated in the DAG through
- *
{edu.snu.nemo.compiler.optimizer.pass.compiletime.composite.DataSkewCompositePass}.
- * @param <K> type of the key of metric data.
- * @param <V> type of the value of metric data.
- */
-public final class MetricCollectionBarrierVertex<K, V> extends IRVertex {
- // Metric data used for dynamic optimization.
- private Map<K, V> metricData;
- private final List<String> blockIds;
-
- // This DAG snapshot is taken at the end of the DataSkewCompositePass, for
the vertex to know the state of the DAG at
- // its optimization, and to be able to figure out exactly where in the DAG
the vertex exists.
- private DAG<IRVertex, IREdge> dagSnapshot;
-
- /**
- * Constructor for dynamic optimization vertex.
- */
- public MetricCollectionBarrierVertex() {
- super();
- this.metricData = new HashMap<>();
- this.blockIds = new ArrayList<>();
- this.dagSnapshot = null;
- }
-
- /**
- * Constructor for dynamic optimization vertex.
- *
- * @param that the source object for copying
- */
- public MetricCollectionBarrierVertex(final MetricCollectionBarrierVertex<K,
V> that) {
- super(that);
- this.metricData = new HashMap<>();
- that.metricData.forEach(this.metricData::put);
- this.blockIds = new ArrayList<>();
- that.blockIds.forEach(this.blockIds::add);
- this.dagSnapshot = that.dagSnapshot;
- }
-
- @Override
- public MetricCollectionBarrierVertex getClone() {
- return new MetricCollectionBarrierVertex(this);
- }
-
- /**
- * This is to set the DAG snapshot at the end of the DataSkewCompositePass.
- * @param dag DAG to set on the vertex.
- */
- public void setDAGSnapshot(final DAG<IRVertex, IREdge> dag) {
- this.dagSnapshot = dag;
- }
-
- /**
- * Access the DAG snapshot when triggering dynamic optimization.
- * @return the DAG set to the vertex, or throws an exception otherwise.
- */
- public DAG<IRVertex, IREdge> getDAGSnapshot() {
- if (this.dagSnapshot == null) {
- throw new DynamicOptimizationException("MetricCollectionBarrierVertex
must have been set with a DAG.");
- }
- return this.dagSnapshot;
- }
-
- /**
- * Method for accumulating metrics in the vertex.
- * @param metric map of hash value of the key of the block to the block size.
- */
- public void setMetricData(final Map<K, V> metric) {
- metricData = metric;
- }
-
- /**
- * Method for retrieving metrics from the vertex.
- * @return the accumulated metric data.
- */
- public Map<K, V> getMetricData() {
- return metricData;
- }
-
- /**
- * Add block id that is needed for optimization in RuntimePass.
- * @param blockId the block id subjected to the optimization.
- */
- public void addBlockId(final String blockId) {
- blockIds.add(blockId);
- }
-
- /**
- * Retrieve block ids.
- * @return the block ids subjected to optimization.
- */
- public List<String> getBlockIds() {
- return blockIds;
- }
-
- @Override
- public String propertiesToJSON() {
- final StringBuilder sb = new StringBuilder();
- sb.append("{");
- sb.append(irVertexPropertiesToString());
- sb.append("}");
- return sb.toString();
- }
-}
diff --git
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/AggregateMetricTransform.java
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/AggregateMetricTransform.java
new file mode 100644
index 0000000..8b81728
--- /dev/null
+++
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/AggregateMetricTransform.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.vertex.transform;
+
+import edu.snu.nemo.common.ir.OutputCollector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.function.BiFunction;
+
+/**
+ * A {@link Transform} that aggregates stage-level statistics sent to the
master side optimizer
+ * for dynamic optimization.
+ *
+ * @param <I> input type.
+ * @param <O> output type.
+ */
+public final class AggregateMetricTransform<I, O> implements Transform<I, O> {
+ private static final Logger LOG =
LoggerFactory.getLogger(AggregateMetricTransform.class.getName());
+ private OutputCollector<O> outputCollector;
+ private O aggregatedDynOptData;
+ private final BiFunction<Object, O, O> dynOptDataAggregator;
+
+ /**
+ * Default constructor.
+ */
+ public AggregateMetricTransform(final O aggregatedDynOptData,
+ final BiFunction<Object, O, O>
dynOptDataAggregator) {
+ this.aggregatedDynOptData = aggregatedDynOptData;
+ this.dynOptDataAggregator = dynOptDataAggregator;
+ }
+
+ @Override
+ public void prepare(final Context context, final OutputCollector<O> oc) {
+ this.outputCollector = oc;
+ }
+
+ @Override
+ public void onData(final I element) {
+ aggregatedDynOptData = dynOptDataAggregator.apply(element,
aggregatedDynOptData);
+ }
+
+ @Override
+ public void close() {
+ outputCollector.emit(aggregatedDynOptData);
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder();
+ sb.append(AggregateMetricTransform.class);
+ sb.append(":");
+ sb.append(super.toString());
+ return sb.toString();
+ }
+}
diff --git
a/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/MetricCollectTransform.java
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/MetricCollectTransform.java
new file mode 100644
index 0000000..c696d23
--- /dev/null
+++
b/common/src/main/java/edu/snu/nemo/common/ir/vertex/transform/MetricCollectTransform.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.common.ir.vertex.transform;
+
+import edu.snu.nemo.common.ir.OutputCollector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.function.BiFunction;
+
+/**
+ * A {@link Transform} that collects task-level statistics used for dynamic
optimization.
+ * The collected statistics is sent to vertex with {@link
AggregateMetricTransform} as a tagged output
+ * when this transform is closed.
+ *
+ * @param <I> input type.
+ * @param <O> output type.
+ */
+public final class MetricCollectTransform<I, O> implements Transform<I, O> {
+ private static final Logger LOG =
LoggerFactory.getLogger(MetricCollectTransform.class.getName());
+ private OutputCollector<O> outputCollector;
+ private O dynOptData;
+ private final BiFunction<Object, O, O> dynOptDataCollector;
+ private final BiFunction<O, OutputCollector, O> closer;
+
+ /**
+ * MetricCollectTransform constructor.
+ */
+ public MetricCollectTransform(final O dynOptData,
+ final BiFunction<Object, O, O>
dynOptDataCollector,
+ final BiFunction<O, OutputCollector, O>
closer) {
+ this.dynOptData = dynOptData;
+ this.dynOptDataCollector = dynOptDataCollector;
+ this.closer = closer;
+ }
+
+ @Override
+ public void prepare(final Context context, final OutputCollector<O> oc) {
+ this.outputCollector = oc;
+ }
+
+ @Override
+ public void onData(final I element) {
+ dynOptData = dynOptDataCollector.apply(element, dynOptData);
+ }
+
+ @Override
+ public void close() {
+ closer.apply(dynOptData, outputCollector);
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder();
+ sb.append(MetricCollectTransform.class);
+ sb.append(":");
+ sb.append(super.toString());
+ return sb.toString();
+ }
+}
diff --git a/common/src/main/java/edu/snu/nemo/common/test/EmptyComponents.java
b/common/src/main/java/edu/snu/nemo/common/test/EmptyComponents.java
index 9c02ffd..e1ae497 100644
--- a/common/src/main/java/edu/snu/nemo/common/test/EmptyComponents.java
+++ b/common/src/main/java/edu/snu/nemo/common/test/EmptyComponents.java
@@ -15,16 +15,23 @@
*/
package edu.snu.nemo.common.test;
+import edu.snu.nemo.common.KeyExtractor;
+import edu.snu.nemo.common.coder.DecoderFactory;
+import edu.snu.nemo.common.coder.EncoderFactory;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.OutputCollector;
import edu.snu.nemo.common.ir.Readable;
import edu.snu.nemo.common.ir.edge.IREdge;
import
edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.DecoderProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.EncoderProperty;
+import edu.snu.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.OperatorVertex;
import edu.snu.nemo.common.ir.vertex.SourceVertex;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
+import org.apache.beam.sdk.values.KV;
import java.util.ArrayList;
import java.util.List;
@@ -38,6 +45,10 @@ public final class EmptyComponents {
private EmptyComponents() {
}
+ /**
+ * Builds dummy IR DAG for testing.
+ * @return the dummy IR DAG.
+ */
public static DAG<IRVertex, IREdge> buildEmptyDAG() {
DAGBuilder<IRVertex, IREdge> dagBuilder = new DAGBuilder<>();
final IRVertex s = new EmptyComponents.EmptySourceVertex<>("s");
@@ -61,6 +72,61 @@ public final class EmptyComponents {
}
/**
+ * Builds dummy IR DAG to test skew handling.
+ * For DataSkewPolicy, shuffle edges needs extra setting for
EncoderProperty, DecoderProperty
+ * and KeyExtractorProperty by default.
+ * @return the dummy IR DAG.
+ */
+ public static DAG<IRVertex, IREdge> buildEmptyDAGForSkew() {
+ DAGBuilder<IRVertex, IREdge> dagBuilder = new DAGBuilder<>();
+ final IRVertex s = new EmptyComponents.EmptySourceVertex<>("s");
+ final IRVertex t1 = new OperatorVertex(new
EmptyComponents.EmptyTransform("t1"));
+ final IRVertex t2 = new OperatorVertex(new
EmptyComponents.EmptyTransform("t2"));
+ final IRVertex t3 = new OperatorVertex(new
EmptyComponents.EmptyTransform("t3"));
+ final IRVertex t4 = new OperatorVertex(new
EmptyComponents.EmptyTransform("t4"));
+ final IRVertex t5 = new OperatorVertex(new
EmptyComponents.EmptyTransform("t5"));
+
+ final IREdge shuffleEdgeBetweenT1AndT2 = new
IREdge(CommunicationPatternProperty.Value.Shuffle, t1, t2);
+ shuffleEdgeBetweenT1AndT2.setProperty(KeyExtractorProperty.of(new
DummyBeamKeyExtractor()));
+ shuffleEdgeBetweenT1AndT2.setProperty(EncoderProperty.of(new
EncoderFactory.DummyEncoderFactory()));
+ shuffleEdgeBetweenT1AndT2.setProperty(DecoderProperty.of(new
DecoderFactory.DummyDecoderFactory()));
+
+ final IREdge shuffleEdgeBetweenT3AndT4 = new
IREdge(CommunicationPatternProperty.Value.Shuffle, t3, t4);
+ shuffleEdgeBetweenT3AndT4.setProperty(KeyExtractorProperty.of(new
DummyBeamKeyExtractor()));
+ shuffleEdgeBetweenT3AndT4.setProperty(EncoderProperty.of(new
EncoderFactory.DummyEncoderFactory()));
+ shuffleEdgeBetweenT3AndT4.setProperty(DecoderProperty.of(new
DecoderFactory.DummyDecoderFactory()));
+
+ dagBuilder.addVertex(s);
+ dagBuilder.addVertex(t1);
+ dagBuilder.addVertex(t2);
+ dagBuilder.addVertex(t3);
+ dagBuilder.addVertex(t4);
+ dagBuilder.addVertex(t5);
+ dagBuilder.connectVertices(new
IREdge(CommunicationPatternProperty.Value.OneToOne, s, t1));
+ dagBuilder.connectVertices(shuffleEdgeBetweenT1AndT2);
+ dagBuilder.connectVertices(new
IREdge(CommunicationPatternProperty.Value.OneToOne, t2, t3));
+ dagBuilder.connectVertices(shuffleEdgeBetweenT3AndT4);
+ dagBuilder.connectVertices(new
IREdge(CommunicationPatternProperty.Value.OneToOne, t2, t5));
+ return dagBuilder.build();
+ }
+
+ /**
+ * Dummy beam key extractor.
+ **/
+ static class DummyBeamKeyExtractor implements KeyExtractor {
+ @Override
+ public Object extractKey(final Object element) {
+ if (element instanceof KV) {
+ // Handle null keys, since Beam allows KV with null keys.
+ final Object key = ((KV) element).getKey();
+ return key == null ? 0 : key;
+ } else {
+ return element;
+ }
+ }
+ }
+
+ /**
* An empty transform.
*
* @param <I> input type.
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
index a744ae8..d334d15 100644
---
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
+++
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/PipelineTranslator.java
@@ -430,15 +430,25 @@ public final class PipelineTranslator
throw new RuntimeException(String.format("While adding an edge from
%s, to %s, coder for PValue %s cannot "
+ "be determined", src, dst, input));
}
+
+ edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
+
+ if (coder instanceof KvCoder) {
+ Coder keyCoder = ((KvCoder) coder).getKeyCoder();
+ edge.setProperty(KeyEncoderProperty.of(new
BeamEncoderFactory(keyCoder)));
+ edge.setProperty(KeyDecoderProperty.of(new
BeamDecoderFactory(keyCoder)));
+ }
edge.setProperty(EncoderProperty.of(new BeamEncoderFactory<>(coder)));
edge.setProperty(DecoderProperty.of(new BeamDecoderFactory<>(coder)));
+
if (pValueToTag.containsKey(input)) {
edge.setProperty(AdditionalOutputTagProperty.of(pValueToTag.get(input).getId()));
}
+
if (input instanceof PCollectionView) {
edge.setProperty(BroadcastVariableIdProperty.of((PCollectionView)
input));
}
- edge.setProperty(KeyExtractorProperty.of(new BeamKeyExtractor()));
+
builder.connectVertices(edge);
}
diff --git
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/coder/BeamDecoderFactory.java
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/coder/BeamDecoderFactory.java
index 7ebea38..f2aedaa 100644
---
a/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/coder/BeamDecoderFactory.java
+++
b/compiler/frontend/beam/src/main/java/edu/snu/nemo/compiler/frontend/beam/coder/BeamDecoderFactory.java
@@ -19,15 +19,19 @@ import edu.snu.nemo.common.coder.DecoderFactory;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.VoidCoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
/**
* {@link DecoderFactory} from {@link org.apache.beam.sdk.coders.Coder}.
- * @param <T> the type of element to encode.
+ * @param <T> the type of element to decode.
*/
public final class BeamDecoderFactory<T> implements DecoderFactory<T> {
+ private static final Logger LOG =
LoggerFactory.getLogger(BeamDecoderFactory.class);
+
private final Coder<T> beamCoder;
/**
diff --git
a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/PairRDDFunctions.scala
b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/PairRDDFunctions.scala
index 14f764a..e4f02a3 100644
---
a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/PairRDDFunctions.scala
+++
b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/PairRDDFunctions.scala
@@ -19,7 +19,7 @@ import java.util
import edu.snu.nemo.common.dag.DAGBuilder
import edu.snu.nemo.common.ir.edge.IREdge
-import edu.snu.nemo.common.ir.edge.executionproperty.{DecoderProperty,
EncoderProperty, KeyExtractorProperty}
+import edu.snu.nemo.common.ir.edge.executionproperty._
import edu.snu.nemo.common.ir.executionproperty.EdgeExecutionProperty
import edu.snu.nemo.common.ir.vertex.{IRVertex, LoopVertex, OperatorVertex}
import edu.snu.nemo.compiler.frontend.spark.SparkKeyExtractor
@@ -77,6 +77,10 @@ final class PairRDDFunctions[K: ClassTag, V: ClassTag]
protected[rdd] (
newEdge.setProperty(
DecoderProperty.of(new SparkDecoderFactory[Tuple2[K,
V]](self.serializer))
.asInstanceOf[EdgeExecutionProperty[_ <: Serializable]])
+ // For Tuple2 type data, set KeyEn(De)coderFactoryProperty
+ // in case it is subjected to dynamic optimization.
+ newEdge.setProperty(KeyEncoderProperty.of(new
SparkEncoderFactory[K](self.serializer)))
+ newEdge.setProperty(KeyDecoderProperty.of(new
SparkDecoderFactory[K](self.serializer)))
newEdge.setProperty(KeyExtractorProperty.of(new SparkKeyExtractor))
builder.connectVertices(newEdge)
diff --git
a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
index c253f17..abcebc8 100644
---
a/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
+++
b/compiler/frontend/spark/src/main/scala/edu/snu/nemo/compiler/frontend/spark/core/rdd/RDD.scala
@@ -30,7 +30,6 @@ import
edu.snu.nemo.compiler.frontend.spark.{SparkBroadcastVariables, SparkKeyEx
import edu.snu.nemo.compiler.frontend.spark.coder.{SparkDecoderFactory,
SparkEncoderFactory}
import edu.snu.nemo.compiler.frontend.spark.core.SparkFrontendUtils
import edu.snu.nemo.compiler.frontend.spark.transform._
-import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.io.WritableFactory
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.api.java.function.{FlatMapFunction, Function,
Function2}
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewDataStorePass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewDataStorePass.java
deleted file mode 100644
index 984e8df..0000000
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewDataStorePass.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (C) 2018 Seoul National University
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating;
-
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.ir.edge.IREdge;
-import edu.snu.nemo.common.ir.edge.executionproperty.DataStoreProperty;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
-
-/**
- * Pass to annotate the DAG for a job to perform data skew.
- * It specifies the incoming one-to-one edges to MetricCollectionVertices to
have either MemoryStore or LocalFileStore
- * as its DataStore ExecutionProperty.
- */
-@Annotates(DataStoreProperty.class)
-public final class SkewDataStorePass extends AnnotatingPass {
- /**
- * Default constructor.
- */
- public SkewDataStorePass() {
- super(SkewDataStorePass.class);
- }
-
- @Override
- public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
- dag.topologicalDo(v -> {
- // we only care about metric collection barrier vertices.
- if (v instanceof MetricCollectionBarrierVertex) {
- // We use memory for just a single inEdge, to make use of locality of
stages: {@link PhysicalPlanGenerator}.
- final IREdge edgeToUseMemory =
dag.getIncomingEdgesOf(v).stream().findFirst().orElseThrow(() ->
- new RuntimeException("This MetricCollectionBarrierVertex doesn't
have any incoming edges: " + v.getId()));
- dag.getIncomingEdgesOf(v).forEach(edge -> {
- // we want it to be in the same stage
- if (edge.equals(edgeToUseMemory)) {
-
edge.setPropertyPermanently(DataStoreProperty.of(DataStoreProperty.Value.MemoryStore));
- } else {
-
edge.setPropertyPermanently(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
- }
- });
- }
- });
- return dag;
- }
-}
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewMetricCollectionPass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewMetricCollectionPass.java
index ad635c6..95f1cd0 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewMetricCollectionPass.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewMetricCollectionPass.java
@@ -19,14 +19,17 @@ import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.edge.IREdge;
import
edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.MetricCollectTransform;
import edu.snu.nemo.compiler.optimizer.pass.compiletime.Requires;
/**
- * Pass to annotate the DAG for a job to perform data skew.
- * It specifies the outgoing Shuffle edges from MetricCollectionVertices with
a MetricCollection ExecutionProperty
- * which lets the edge to know what metric collection it should perform.
+ * Pass to annotate the IR DAG for skew handling.
+ *
+ * It specifies the target of dynamic optimization for skew handling
+ * by setting appropriate {@link MetricCollectionProperty} to
+ * outgoing shuffle edges from vertices with {@link MetricCollectTransform}.
*/
@Annotates(MetricCollectionProperty.class)
@Requires(CommunicationPatternProperty.class)
@@ -41,8 +44,9 @@ public final class SkewMetricCollectionPass extends
AnnotatingPass {
@Override
public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
dag.topologicalDo(v -> {
- // we only care about metric collection barrier vertices.
- if (v instanceof MetricCollectionBarrierVertex) {
+ // we only care about metric collection vertices.
+ if (v instanceof OperatorVertex
+ && ((OperatorVertex) v).getTransform() instanceof
MetricCollectTransform) {
dag.getOutgoingEdgesOf(v).forEach(edge -> {
// double checking.
if (edge.getPropertyValue(CommunicationPatternProperty.class).get()
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewPartitionerPass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewPartitionerPass.java
index 7ffd221..ab34fff 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewPartitionerPass.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewPartitionerPass.java
@@ -19,8 +19,9 @@ import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
import edu.snu.nemo.common.ir.edge.executionproperty.PartitionerProperty;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.AggregateMetricTransform;
import edu.snu.nemo.compiler.optimizer.pass.compiletime.Requires;
import java.util.List;
@@ -40,9 +41,10 @@ public final class SkewPartitionerPass extends
AnnotatingPass {
@Override
public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
- dag.getVertices().forEach(vertex -> {
- if (vertex instanceof MetricCollectionBarrierVertex) {
- final List<IREdge> outEdges = dag.getOutgoingEdgesOf(vertex);
+ dag.getVertices().forEach(v -> {
+ if (v instanceof OperatorVertex
+ && ((OperatorVertex) v).getTransform() instanceof
AggregateMetricTransform) {
+ final List<IREdge> outEdges = dag.getOutgoingEdgesOf(v);
outEdges.forEach(edge -> {
// double checking.
if (MetricCollectionProperty.Value.DataSkewRuntimePass
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewResourceSkewedDataPass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewResourceSkewedDataPass.java
index c3739af..806cef1 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewResourceSkewedDataPass.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/annotating/SkewResourceSkewedDataPass.java
@@ -18,15 +18,19 @@ package
edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.edge.IREdge;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
import
edu.snu.nemo.common.ir.vertex.executionproperty.DynamicOptimizationProperty;
import
edu.snu.nemo.common.ir.vertex.executionproperty.ResourceSkewedDataProperty;
+import edu.snu.nemo.common.ir.vertex.transform.MetricCollectTransform;
import java.util.List;
/**
- * Pass to annotate the DAG for a job to perform data skew.
- * It specifies which optimization to perform on the
MetricCollectionBarrierVertex.
+ * Pass to annotate the IR DAG for skew handling.
+ *
+ * It marks children and descendants of vertices with {@link MetricCollectTransform},
+ * which collects task-level statistics used for dynamic optimization,
+ * with {@link ResourceSkewedDataProperty} to perform skewness-aware
scheduling.
*/
@Annotates(DynamicOptimizationProperty.class)
public final class SkewResourceSkewedDataPass extends AnnotatingPass {
@@ -37,11 +41,12 @@ public final class SkewResourceSkewedDataPass extends
AnnotatingPass {
super(SkewResourceSkewedDataPass.class);
}
- private boolean hasMetricCollectionBarrierVertexAsParent(final DAG<IRVertex,
IREdge> dag,
- final IRVertex v) {
+ private boolean hasParentWithMetricCollectTransform(final DAG<IRVertex,
IREdge> dag,
+ final IRVertex v) {
List<IRVertex> parents = dag.getParents(v.getId());
for (IRVertex parent : parents) {
- if (parent instanceof MetricCollectionBarrierVertex) {
+      if (parent instanceof OperatorVertex
+          && ((OperatorVertex) parent).getTransform() instanceof MetricCollectTransform) {
return true;
}
}
@@ -51,12 +56,13 @@ public final class SkewResourceSkewedDataPass extends
AnnotatingPass {
@Override
public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
dag.getVertices().stream()
- .filter(v -> v instanceof MetricCollectionBarrierVertex)
- .forEach(v -> v.setProperty(DynamicOptimizationProperty
+ .filter(v -> v instanceof OperatorVertex
+ && ((OperatorVertex) v).getTransform() instanceof
MetricCollectTransform)
+ .forEach(v -> v.setProperty(DynamicOptimizationProperty
.of(DynamicOptimizationProperty.Value.DataSkewRuntimePass)));
dag.getVertices().stream()
- .filter(v -> hasMetricCollectionBarrierVertexAsParent(dag, v)
+ .filter(v -> hasParentWithMetricCollectTransform(dag, v)
&&
!v.getExecutionProperties().containsKey(ResourceSkewedDataProperty.class))
.forEach(childV -> {
childV.getExecutionProperties().put(ResourceSkewedDataProperty.of(true));
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePass.java
index 5e3d8fd..1be97fa 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePass.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePass.java
@@ -22,11 +22,6 @@ import java.util.Arrays;
/**
* Pass to modify the DAG for a job to perform data skew.
- * It adds a {@link
edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex} before Shuffle
edges,
- * to make a barrier before it, and to use the metrics to repartition the
skewed data.
- * NOTE: we currently put the SkewCompositePass at the end of the list for
each policies, as it needs to take a
- * snapshot at the end of the pass. This could be prevented by modifying other
passes to take the snapshot of the DAG
- * at the end of each passes for metricCollectionVertices.
*/
public final class SkewCompositePass extends CompositePass {
/**
@@ -36,9 +31,7 @@ public final class SkewCompositePass extends CompositePass {
super(Arrays.asList(
new SkewReshapingPass(),
new SkewResourceSkewedDataPass(),
- new SkewDataStorePass(),
- new SkewMetricCollectionPass(),
- new SkewPartitionerPass()
+ new SkewMetricCollectionPass()
));
}
}
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
index 571476d..fe12b95 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/reshaping/SkewReshapingPass.java
@@ -15,30 +15,44 @@
*/
package edu.snu.nemo.compiler.optimizer.pass.compiletime.reshaping;
+import edu.snu.nemo.common.KeyExtractor;
+import edu.snu.nemo.common.Pair;
+import edu.snu.nemo.common.coder.*;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.dag.DAGBuilder;
+import edu.snu.nemo.common.ir.OutputCollector;
import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.edge.executionproperty.*;
import
edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.DecoderProperty;
import edu.snu.nemo.common.ir.edge.executionproperty.EncoderProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.AggregateMetricTransform;
+import edu.snu.nemo.common.ir.vertex.transform.MetricCollectTransform;
import edu.snu.nemo.compiler.optimizer.pass.compiletime.Requires;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.Serializable;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
/**
- * Pass to modify the DAG for a job to perform data skew.
- * It adds a {@link MetricCollectionBarrierVertex} before Shuffle edges, to
make a barrier before it,
- * and to use the metrics to repartition the skewed data.
- * NOTE: we currently put the SkewCompositePass at the end of the list for
each policies, as it needs to take
- * a snapshot at the end of the pass. This could be prevented by modifying
other passes to take the snapshot of the
- * DAG at the end of each passes for metricCollectionVertices.
- */
+ * Pass to reshape the IR DAG for skew handling.
+ *
+ * This pass inserts vertices to perform two-step dynamic optimization for
skew handling.
+ * 1) Task-level statistic collection is done via vertex with {@link
MetricCollectTransform}
+ * 2) Stage-level statistic aggregation is done via vertex with {@link
AggregateMetricTransform}
+ * inserted before shuffle edges.
+ */
@Requires(CommunicationPatternProperty.class)
public final class SkewReshapingPass extends ReshapingPass {
+ private static final Logger LOG =
LoggerFactory.getLogger(SkewReshapingPass.class.getName());
+
/**
* Default constructor.
*/
@@ -49,33 +63,38 @@ public final class SkewReshapingPass extends ReshapingPass {
@Override
public DAG<IRVertex, IREdge> apply(final DAG<IRVertex, IREdge> dag) {
final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
- final List<MetricCollectionBarrierVertex> metricCollectionVertices = new
ArrayList<>();
+ final List<OperatorVertex> metricCollectVertices = new ArrayList<>();
dag.topologicalDo(v -> {
- // We care about OperatorVertices that have any incoming edges that are
of type Shuffle.
+ // We care about OperatorVertices that have shuffle incoming edges with
main output.
+ // TODO #210: Data-aware dynamic optimization at run-time
if (v instanceof OperatorVertex &&
dag.getIncomingEdgesOf(v).stream().anyMatch(irEdge ->
CommunicationPatternProperty.Value.Shuffle
-
.equals(irEdge.getPropertyValue(CommunicationPatternProperty.class).get()))) {
- final MetricCollectionBarrierVertex<Integer, Long>
metricCollectionBarrierVertex
- = new MetricCollectionBarrierVertex<>();
- metricCollectionVertices.add(metricCollectionBarrierVertex);
- builder.addVertex(v);
- builder.addVertex(metricCollectionBarrierVertex);
+
.equals(irEdge.getPropertyValue(CommunicationPatternProperty.class).get()))
+ && dag.getIncomingEdgesOf(v).stream().noneMatch(irEdge ->
+ irEdge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent()))
{
+
dag.getIncomingEdgesOf(v).forEach(edge -> {
- // we insert the metric collection vertex when we meet a shuffle edge
if (CommunicationPatternProperty.Value.Shuffle
.equals(edge.getPropertyValue(CommunicationPatternProperty.class).get())) {
- // We then insert the dynamicOptimizationVertex between the vertex
and incoming vertices.
- final IREdge newEdge = new
IREdge(CommunicationPatternProperty.Value.OneToOne,
- edge.getSrc(), metricCollectionBarrierVertex);
-
newEdge.setProperty(EncoderProperty.of(edge.getPropertyValue(EncoderProperty.class).get()));
-
newEdge.setProperty(DecoderProperty.of(edge.getPropertyValue(DecoderProperty.class).get()));
-
- final IREdge edgeToGbK = new IREdge(
- edge.getPropertyValue(CommunicationPatternProperty.class).get(),
metricCollectionBarrierVertex, v);
- edge.copyExecutionPropertiesTo(edgeToGbK);
- builder.connectVertices(newEdge);
- builder.connectVertices(edgeToGbK);
+ final OperatorVertex abv = generateMetricAggregationVertex();
+ final OperatorVertex mcv = generateMetricCollectVertex(edge, abv);
+ metricCollectVertices.add(mcv);
+ builder.addVertex(v);
+ builder.addVertex(mcv);
+ builder.addVertex(abv);
+
+ // We then insert the vertex with MetricCollectTransform and
vertex with AggregateMetricTransform
+ // between the vertex and incoming vertices.
+ final IREdge edgeToMCV = generateEdgeToMCV(edge, mcv);
+ final IREdge edgeToABV = generateEdgeToABV(edge, mcv, abv);
+ final IREdge edgeToOriginalDstV =
+ new
IREdge(edge.getPropertyValue(CommunicationPatternProperty.class).get(),
edge.getSrc(), v);
+ edge.copyExecutionPropertiesTo(edgeToOriginalDstV);
+
+ builder.connectVertices(edgeToMCV);
+ builder.connectVertices(edgeToABV);
+ builder.connectVertices(edgeToOriginalDstV);
} else {
builder.connectVertices(edge);
}
@@ -86,7 +105,97 @@ public final class SkewReshapingPass extends ReshapingPass {
}
});
final DAG<IRVertex, IREdge> newDAG = builder.build();
- metricCollectionVertices.forEach(v -> v.setDAGSnapshot(newDAG));
return newDAG;
}
+
+  /**
+   * Generates a vertex with {@link AggregateMetricTransform} that aggregates task-level key frequency data.
+   * Each incoming element is expected to be a {@link Pair} of (key, count); counts of equal keys are summed.
+   *
+   * @return the generated metric aggregation vertex.
+   */
+  private OperatorVertex generateMetricAggregationVertex() {
+    // Define a custom data aggregator for skew handling.
+    // Here, the aggregator gathers key frequency data used in shuffle data repartitioning.
+    final BiFunction<Object, Map<Object, Long>, Map<Object, Long>> dynOptDataAggregator =
+        (BiFunction<Object, Map<Object, Long>, Map<Object, Long>> & Serializable)
+        (element, aggregatedDynOptData) -> {
+          final Pair<Object, Long> keyAndCount = (Pair<Object, Long>) element;
+          // Insert the count for a new key, or add it to the running total of an existing key.
+          aggregatedDynOptData.merge(keyAndCount.left(), keyAndCount.right(), Long::sum);
+          return aggregatedDynOptData;
+        };
+    final AggregateMetricTransform<Pair<Object, Long>, Map<Object, Long>> abt =
+        new AggregateMetricTransform<>(new HashMap<>(), dynOptDataAggregator);
+    return new OperatorVertex(abt);
+  }
+
+  /**
+   * Generates a vertex with {@link MetricCollectTransform} that counts how often each key occurs
+   * in the data of the given shuffle edge, and emits the (key, count) pairs to {@code abv} on close.
+   *
+   * @param edge the shuffle edge whose {@link KeyExtractorProperty} is used to extract keys.
+   * @param abv  the vertex with {@link AggregateMetricTransform} that receives the emitted pairs.
+   * @return the generated metric collection vertex.
+   */
+  private OperatorVertex generateMetricCollectVertex(final IREdge edge, final OperatorVertex abv) {
+    final KeyExtractor keyExtractor = edge.getPropertyValue(KeyExtractorProperty.class).get();
+    // Capture only the destination id in the serializable closer below, not the whole vertex.
+    final String abvId = abv.getId();
+
+    // Define a custom data collector for skew handling.
+    // Here, the collector gathers key frequency data used in shuffle data repartitioning.
+    final BiFunction<Object, Map<Object, Object>, Map<Object, Object>> dynOptDataCollector =
+        (BiFunction<Object, Map<Object, Object>, Map<Object, Object>> & Serializable)
+        (element, dynOptData) -> {
+          final Object key = keyExtractor.extractKey(element);
+          // Counts are stored as Long values: start at 1 for a new key, otherwise increment.
+          dynOptData.merge(key, 1L, (existingCount, one) -> (long) existingCount + (long) one);
+          return dynOptData;
+        };
+
+    // Define a custom transform closer for skew handling.
+    // Here, each (key, frequency) pair is emitted to the aggregation vertex when closing the transform.
+    final BiFunction<Map<Object, Object>, OutputCollector, Map<Object, Object>> closer =
+        (BiFunction<Map<Object, Object>, OutputCollector, Map<Object, Object>> & Serializable)
+        (dynOptData, outputCollector) -> {
+          dynOptData.forEach((k, v) -> outputCollector.emit(abvId, Pair.of(k, v)));
+          return dynOptData;
+        };
+
+    final MetricCollectTransform mct =
+        new MetricCollectTransform(new HashMap<>(), dynOptDataCollector, closer);
+    return new OperatorVertex(mct);
+  }
+
+  /**
+   * Builds the one-to-one edge that routes the source of a shuffle edge into the metric collection vertex.
+   *
+   * @param edge the original shuffle edge; its encoder and decoder are reused.
+   * @param mcv  the vertex with {@link MetricCollectTransform}.
+   * @return the new one-to-one edge.
+   */
+  private IREdge generateEdgeToMCV(final IREdge edge, final OperatorVertex mcv) {
+    final IREdge toMcv = new IREdge(CommunicationPatternProperty.Value.OneToOne, edge.getSrc(), mcv);
+    final EncoderFactory encoderFactory = edge.getPropertyValue(EncoderProperty.class).get();
+    toMcv.setProperty(EncoderProperty.of(encoderFactory));
+    final DecoderFactory decoderFactory = edge.getPropertyValue(DecoderProperty.class).get();
+    toMcv.setProperty(DecoderProperty.of(decoderFactory));
+    return toMcv;
+  }
+
+  /**
+   * Builds the shuffle edge from the metric collection vertex to the metric aggregation vertex.
+   * The edge stores its data in local files, keeps it, is pulled, and carries the "DynOptData" tag.
+   * When the original edge defines key coders, (key, count) pairs are coded with the key coder and a
+   * long coder; otherwise the original edge's encoder and decoder are reused.
+   *
+   * @param edge the original shuffle edge.
+   * @param mcv  the vertex with {@link MetricCollectTransform}.
+   * @param abv  the vertex with {@link AggregateMetricTransform}.
+   * @return the new shuffle edge.
+   */
+  private IREdge generateEdgeToABV(final IREdge edge,
+                                   final OperatorVertex mcv,
+                                   final OperatorVertex abv) {
+    final IREdge toAbv = new IREdge(CommunicationPatternProperty.Value.Shuffle, mcv, abv);
+    toAbv.setProperty(DataStoreProperty.of(DataStoreProperty.Value.LocalFileStore));
+    toAbv.setProperty(DataPersistenceProperty.of(DataPersistenceProperty.Value.Keep));
+    toAbv.setProperty(DataFlowProperty.of(DataFlowProperty.Value.Pull));
+    toAbv.setProperty(KeyExtractorProperty.of(edge.getPropertyValue(KeyExtractorProperty.class).get()));
+    toAbv.setProperty(AdditionalOutputTagProperty.of("DynOptData"));
+
+    final boolean keyCodersAvailable = edge.getPropertyValue(KeyEncoderProperty.class).isPresent()
+        && edge.getPropertyValue(KeyDecoderProperty.class).isPresent();
+    final EncoderFactory encoderFactory;
+    final DecoderFactory decoderFactory;
+    if (keyCodersAvailable) {
+      // Statistics are (key, count) pairs: code the key with the edge's key coder and the count as a long.
+      encoderFactory = PairEncoderFactory.of(
+          edge.getPropertyValue(KeyEncoderProperty.class).get(), LongEncoderFactory.of());
+      decoderFactory = PairDecoderFactory.of(
+          edge.getPropertyValue(KeyDecoderProperty.class).get(), LongDecoderFactory.of());
+    } else {
+      // Without key coders, fall back to the coders of the given shuffle edge.
+      encoderFactory = edge.getPropertyValue(EncoderProperty.class).get();
+      decoderFactory = edge.getPropertyValue(DecoderProperty.class).get();
+    }
+    toAbv.setProperty(EncoderProperty.of(encoderFactory));
+    toAbv.setProperty(DecoderProperty.of(decoderFactory));
+    return toAbv;
+  }
}
diff --git
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java
index e8e7a7c..801cbed 100644
---
a/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java
+++
b/compiler/optimizer/src/main/java/edu/snu/nemo/compiler/optimizer/policy/DataSkewPolicy.java
@@ -31,10 +31,11 @@ import org.apache.reef.tang.Injector;
public final class DataSkewPolicy implements Policy {
public static final PolicyBuilder BUILDER =
new PolicyBuilder()
- .registerRuntimePass(new
DataSkewRuntimePass().setNumSkewedKeys(DataSkewRuntimePass.DEFAULT_NUM_SKEWED_KEYS),
- new SkewCompositePass())
- .registerCompileTimePass(new LoopOptimizationCompositePass())
- .registerCompileTimePass(new DefaultCompositePass());
+ .registerRuntimePass(new
DataSkewRuntimePass().setNumSkewedKeys(DataSkewRuntimePass.DEFAULT_NUM_SKEWED_KEYS),
+ new SkewCompositePass())
+ .registerCompileTimePass(new LoopOptimizationCompositePass())
+ .registerCompileTimePass(new DefaultCompositePass());
+
private final Policy policy;
/**
diff --git
a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java
b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java
index 79191ba..35ddbb2 100644
---
a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java
+++
b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyBuilderTest.java
@@ -38,7 +38,7 @@ public final class PolicyBuilderTest {
@Test
public void testDataSkewPolicy() {
- assertEquals(22, DataSkewPolicy.BUILDER.getCompileTimePasses().size());
+ assertEquals(20, DataSkewPolicy.BUILDER.getCompileTimePasses().size());
assertEquals(1, DataSkewPolicy.BUILDER.getRuntimePasses().size());
}
diff --git
a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImplTest.java
b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImplTest.java
index 2b2c7b7..66aa366 100644
---
a/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImplTest.java
+++
b/compiler/optimizer/src/test/java/edu/snu/nemo/compiler/optimizer/policy/PolicyImplTest.java
@@ -18,6 +18,9 @@ package edu.snu.nemo.compiler.optimizer.policy;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.exception.CompileTimeOptimizationException;
+import edu.snu.nemo.common.ir.edge.IREdge;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
import edu.snu.nemo.common.test.EmptyComponents;
import edu.snu.nemo.compiler.optimizer.pass.compiletime.CompileTimePass;
import edu.snu.nemo.runtime.common.optimizer.pass.runtime.RuntimePass;
@@ -31,10 +34,12 @@ import java.util.List;
public final class PolicyImplTest {
private DAG dag;
+ private DAG dagForSkew;
@Before
public void setUp() {
this.dag = EmptyComponents.buildEmptyDAG();
+ this.dagForSkew = EmptyComponents.buildEmptyDAGForSkew();
}
@Rule
@@ -55,7 +60,7 @@ public final class PolicyImplTest {
@Test
public void testDataSkewPolicy() throws Exception {
// this should run without an exception.
- DataSkewPolicy.BUILDER.build().runCompileTimeOptimization(dag,
DAG.EMPTY_DAG_DIRECTORY);
+ DataSkewPolicy.BUILDER.build().runCompileTimeOptimization(dagForSkew,
DAG.EMPTY_DAG_DIRECTORY);
}
@Test
@@ -110,6 +115,6 @@ public final class PolicyImplTest {
// This should throw an exception.
// DataSizeMetricCollection is not compatible with Push (All data have to
be stored before the data collection).
expectedException.expect(CompileTimeOptimizationException.class);
- combinedPolicy.runCompileTimeOptimization(dag, DAG.EMPTY_DAG_DIRECTORY);
+ combinedPolicy.runCompileTimeOptimization(dagForSkew,
DAG.EMPTY_DAG_DIRECTORY);
}
}
diff --git
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
index 1cd8996..b04b849 100644
---
a/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
+++
b/compiler/test/src/test/java/edu/snu/nemo/compiler/optimizer/pass/compiletime/composite/SkewCompositePassTest.java
@@ -18,12 +18,13 @@ package
edu.snu.nemo.compiler.optimizer.pass.compiletime.composite;
import edu.snu.nemo.client.JobLauncher;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.ir.edge.IREdge;
+import
edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
import
edu.snu.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
-import edu.snu.nemo.common.ir.edge.executionproperty.MetricCollectionProperty;
-import edu.snu.nemo.common.ir.edge.executionproperty.PartitionerProperty;
import edu.snu.nemo.common.ir.vertex.IRVertex;
-import edu.snu.nemo.common.ir.vertex.MetricCollectionBarrierVertex;
import edu.snu.nemo.common.ir.executionproperty.ExecutionProperty;
+import edu.snu.nemo.common.ir.vertex.OperatorVertex;
+import edu.snu.nemo.common.ir.vertex.transform.MetricCollectTransform;
+import edu.snu.nemo.common.ir.vertex.transform.AggregateMetricTransform;
import edu.snu.nemo.compiler.CompilerTestUtil;
import
edu.snu.nemo.common.ir.vertex.executionproperty.ResourceSkewedDataProperty;
import
edu.snu.nemo.compiler.optimizer.pass.compiletime.annotating.AnnotatingPass;
@@ -35,7 +36,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
import java.util.HashSet;
import java.util.List;
-import java.util.Optional;
import java.util.Set;
import static org.junit.Assert.assertEquals;
@@ -48,7 +48,7 @@ import static org.junit.Assert.assertTrue;
@PrepareForTest(JobLauncher.class)
public class SkewCompositePassTest {
private DAG<IRVertex, IREdge> mrDAG;
- private static final long NUM_OF_PASSES_IN_DATA_SKEW_PASS = 5;
+ private static final long NUM_OF_PASSES_IN_DATA_SKEW_PASS = 3;
@Before
public void setUp() throws Exception {
@@ -74,37 +74,30 @@ public class SkewCompositePassTest {
}
/**
- * Test for {@link SkewCompositePass} with MR workload. It must insert a
{@link MetricCollectionBarrierVertex}
- * before each shuffle edge.
+ * Test for {@link SkewCompositePass} with MR workload.
+ * It should have inserted vertex with {@link MetricCollectTransform}
+ * and vertex with {@link AggregateMetricTransform}
+ * before each shuffle edge with no additional output tags.
* @throws Exception exception on the way.
*/
@Test
public void testDataSkewPass() throws Exception {
mrDAG = CompilerTestUtil.compileWordCountDAG();
final Integer originalVerticesNum = mrDAG.getVertices().size();
- final Long numOfShuffleGatherEdges =
mrDAG.getVertices().stream().filter(irVertex ->
+ final Long numOfShuffleEdgesWithOutAdditionalOutputTag =
+ mrDAG.getVertices().stream().filter(irVertex ->
mrDAG.getIncomingEdgesOf(irVertex).stream().anyMatch(irEdge ->
- CommunicationPatternProperty.Value.Shuffle
-
.equals(irEdge.getPropertyValue(CommunicationPatternProperty.class).get())))
- .count();
+ CommunicationPatternProperty.Value.Shuffle
+
.equals(irEdge.getPropertyValue(CommunicationPatternProperty.class).get())
+ &&
!irEdge.getPropertyValue(AdditionalOutputTagProperty.class).isPresent()))
+ .count();
final DAG<IRVertex, IREdge> processedDAG = new
SkewCompositePass().apply(mrDAG);
+ assertEquals(originalVerticesNum +
numOfShuffleEdgesWithOutAdditionalOutputTag * 2,
+ processedDAG.getVertices().size());
- assertEquals(originalVerticesNum + numOfShuffleGatherEdges,
processedDAG.getVertices().size());
- processedDAG.getVertices().stream().map(processedDAG::getIncomingEdgesOf)
- .flatMap(List::stream)
- .filter(irEdge -> CommunicationPatternProperty.Value.Shuffle
-
.equals(irEdge.getPropertyValue(CommunicationPatternProperty.class).get()))
- .map(IREdge::getSrc)
- .forEach(irVertex -> assertTrue(irVertex instanceof
MetricCollectionBarrierVertex));
-
- processedDAG.getVertices().forEach(v ->
processedDAG.getOutgoingEdgesOf(v).stream()
- .filter(e ->
Optional.of(MetricCollectionProperty.Value.DataSkewRuntimePass)
- .equals(e.getPropertyValue(MetricCollectionProperty.class)))
- .forEach(e ->
assertEquals(PartitionerProperty.Value.DataSkewHashPartitioner,
- e.getPropertyValue(PartitionerProperty.class).get())));
-
- processedDAG.filterVertices(v -> v instanceof
MetricCollectionBarrierVertex)
- .forEach(metricV -> {
+ processedDAG.filterVertices(v -> v instanceof OperatorVertex
+ && ((OperatorVertex) v).getTransform() instanceof MetricCollectTransform)
+ .forEach(metricV -> {
List<IRVertex> reducerV = processedDAG.getChildren(metricV.getId());
reducerV.forEach(rV ->
assertTrue(rV.getPropertyValue(ResourceSkewedDataProperty.class).get()));
});
diff --git
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
index 0d4cfa5..0de585f 100644
---
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
+++
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/RunTimeOptimizer.java
@@ -46,7 +46,7 @@ public final class RunTimeOptimizer {
// is a map of <hash value, partition size>.
final PhysicalPlan physicalPlan =
new DataSkewRuntimePass()
- .apply(originalPlan, Pair.of(targetEdge, (Map<Integer, Long>)
dynOptData));
+ .apply(originalPlan, Pair.of(targetEdge, (Map<Object, Long>)
dynOptData));
return physicalPlan;
}
}
diff --git
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
index f277af1..a2b05f5 100644
---
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
+++
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePass.java
@@ -39,11 +39,11 @@ import java.util.stream.Collectors;
* this RuntimePass identifies a number of keys with big partition
sizes(skewed key)
* and evenly redistributes data via overwriting incoming edges of destination
tasks.
*/
-public final class DataSkewRuntimePass extends RuntimePass<Pair<StageEdge,
Map<Integer, Long>>> {
+public final class DataSkewRuntimePass extends RuntimePass<Pair<StageEdge,
Map<Object, Long>>> {
private static final Logger LOG =
LoggerFactory.getLogger(DataSkewRuntimePass.class.getName());
private final Set<Class<? extends RuntimeEventHandler>> eventHandlers;
// Skewed keys denote for top n keys in terms of partition size.
- public static final int DEFAULT_NUM_SKEWED_KEYS = 3;
+ public static final int DEFAULT_NUM_SKEWED_KEYS = 1;
private int numSkewedKeys;
/**
@@ -71,7 +71,7 @@ public final class DataSkewRuntimePass extends
RuntimePass<Pair<StageEdge, Map<I
@Override
public PhysicalPlan apply(final PhysicalPlan originalPlan,
- final Pair<StageEdge, Map<Integer, Long>>
metricData) {
+ final Pair<StageEdge, Map<Object, Long>>
metricData) {
final StageEdge targetEdge = metricData.left();
// Get number of evaluators of the next stage (number of blocks).
final Integer dstParallelism =
targetEdge.getDst().getPropertyValue(ParallelismProperty.class).
@@ -98,27 +98,29 @@ public final class DataSkewRuntimePass extends
RuntimePass<Pair<StageEdge, Map<I
return new PhysicalPlan(originalPlan.getPlanId(), stageDAG);
}
- public List<Integer> identifySkewedKeys(final Map<Integer, Long>
keyValToPartitionSizeMap) {
+ public List<Long> identifySkewedKeys(final List<Long> partitionSizeList) {
// Identify skewed keys.
- List<Map.Entry<Integer, Long>> sortedMetricData =
keyValToPartitionSizeMap.entrySet().stream()
- .sorted((e1, e2) -> e2.getValue().compareTo(e1.getValue()))
+ List<Long> sortedMetricData = partitionSizeList.stream()
+ .sorted(Comparator.reverseOrder())
.collect(Collectors.toList());
- List<Integer> skewedKeys = new ArrayList<>();
+ List<Long> skewedSizes = new ArrayList<>();
for (int i = 0; i < numSkewedKeys; i++) {
- skewedKeys.add(sortedMetricData.get(i).getKey());
- LOG.info("Skewed key: Key {} Size {}", sortedMetricData.get(i).getKey(),
sortedMetricData.get(i).getValue());
+ skewedSizes.add(sortedMetricData.get(i));
+ LOG.info("Skewed size: {}", sortedMetricData.get(i));
}
- return skewedKeys;
+ return skewedSizes;
}
- private boolean containsSkewedKey(final List<Integer> skewedKeys,
- final int startingKey, final int
finishingKey) {
- for (int k = startingKey; k < finishingKey; k++) {
- if (skewedKeys.contains(k)) {
+ private boolean containsSkewedSize(final List<Long> partitionSizeList,
+ final List<Long> skewedKeys,
+ final int startingKey, final int
finishingKey) {
+ for (int i = startingKey; i < finishingKey; i++) {
+ if (skewedKeys.contains(partitionSizeList.get(i))) {
return true;
}
}
+
return false;
}
@@ -133,24 +135,25 @@ public final class DataSkewRuntimePass extends
RuntimePass<Pair<StageEdge, Map<I
* @return the list of key ranges calculated.
*/
@VisibleForTesting
- public List<KeyRange> calculateKeyRanges(final Map<Integer, Long>
keyToPartitionSizeMap,
+ public List<KeyRange> calculateKeyRanges(final Map<Object, Long>
keyToPartitionSizeMap,
final Integer dstParallelism) {
- // Get the last key.
- final int lastKey = keyToPartitionSizeMap.keySet().stream()
- .max(Integer::compareTo)
- .get();
+ final List<Long> partitionSizeList = new ArrayList<>();
+ keyToPartitionSizeMap.forEach((k, v) -> partitionSizeList.add(v));
+
+ // Get the last index.
+ final int lastKey = partitionSizeList.size() - 1;
- // Identify skewed keys, which is top numSkewedKeys number of keys.
- List<Integer> skewedKeys = identifySkewedKeys(keyToPartitionSizeMap);
+ // Identify skewed sizes, which is top numSkewedKeys number of keys.
+ List<Long> skewedSizes = identifySkewedKeys(partitionSizeList);
// Calculate the ideal size for each destination task.
- final Long totalSize = keyToPartitionSizeMap.values().stream().mapToLong(n
-> n).sum(); // get total size
+ final Long totalSize = partitionSizeList.stream().mapToLong(n -> n).sum();
// get total size
final Long idealSizePerTask = totalSize / dstParallelism; // and derive
the ideal size per task
final List<KeyRange> keyRanges = new ArrayList<>(dstParallelism);
int startingKey = 0;
int finishingKey = 1;
- Long currentAccumulatedSize =
keyToPartitionSizeMap.getOrDefault(startingKey, 0L);
+ Long currentAccumulatedSize = partitionSizeList.get(startingKey);
Long prevAccumulatedSize = 0L;
for (int i = 1; i <= dstParallelism; i++) {
if (i != dstParallelism) {
@@ -158,21 +161,21 @@ public final class DataSkewRuntimePass extends
RuntimePass<Pair<StageEdge, Map<I
final Long idealAccumulatedSize = idealSizePerTask * i;
// By adding partition sizes, find the accumulated size nearest to the
given ideal size.
while (currentAccumulatedSize < idealAccumulatedSize) {
- currentAccumulatedSize +=
keyToPartitionSizeMap.getOrDefault(finishingKey, 0L);
+ currentAccumulatedSize += partitionSizeList.get(finishingKey);
finishingKey++;
}
final Long oneStepBack =
- currentAccumulatedSize -
keyToPartitionSizeMap.getOrDefault(finishingKey - 1, 0L);
+ currentAccumulatedSize - partitionSizeList.get(finishingKey - 1);
final Long diffFromIdeal = currentAccumulatedSize -
idealAccumulatedSize;
final Long diffFromIdealOneStepBack = idealAccumulatedSize -
oneStepBack;
// Go one step back if we came too far.
if (diffFromIdeal > diffFromIdealOneStepBack) {
finishingKey--;
- currentAccumulatedSize -=
keyToPartitionSizeMap.getOrDefault(finishingKey, 0L);
+ currentAccumulatedSize -= partitionSizeList.get(finishingKey);
}
- boolean isSkewedKey = containsSkewedKey(skewedKeys, startingKey,
finishingKey);
+ boolean isSkewedKey = containsSkewedSize(partitionSizeList,
skewedSizes, startingKey, finishingKey);
keyRanges.add(i - 1, HashRange.of(startingKey, finishingKey,
isSkewedKey));
LOG.debug("KeyRange {}~{}, Size {}", startingKey, finishingKey - 1,
currentAccumulatedSize - prevAccumulatedSize);
@@ -180,12 +183,12 @@ public final class DataSkewRuntimePass extends
RuntimePass<Pair<StageEdge, Map<I
prevAccumulatedSize = currentAccumulatedSize;
startingKey = finishingKey;
} else { // last one: we put the range of the rest.
- boolean isSkewedKey = containsSkewedKey(skewedKeys, startingKey,
lastKey + 1);
+ boolean isSkewedKey = containsSkewedSize(partitionSizeList,
skewedSizes, startingKey, lastKey + 1);
keyRanges.add(i - 1,
HashRange.of(startingKey, lastKey + 1, isSkewedKey));
while (finishingKey <= lastKey) {
- currentAccumulatedSize +=
keyToPartitionSizeMap.getOrDefault(finishingKey, 0L);
+ currentAccumulatedSize += partitionSizeList.get(finishingKey);
finishingKey++;
}
LOG.debug("KeyRange {}~{}, Size {}", startingKey, lastKey + 1,
diff --git
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
index b1fa4fb..6ecb7c4 100644
---
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
+++
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/PhysicalPlanGenerator.java
@@ -244,8 +244,7 @@ public final class PhysicalPlanGenerator implements
Function<DAG<IRVertex, IREdg
stage.getIRDAG().getVertices().forEach(irVertex -> {
// Check vertex type.
if (!(irVertex instanceof SourceVertex
- || irVertex instanceof OperatorVertex
- || irVertex instanceof MetricCollectionBarrierVertex)) {
+ || irVertex instanceof OperatorVertex)) {
throw new UnsupportedOperationException(irVertex.toString());
}
});
diff --git
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StagePartitioner.java
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StagePartitioner.java
index 2fd5c8b..588e5d4 100644
---
a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StagePartitioner.java
+++
b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StagePartitioner.java
@@ -75,7 +75,7 @@ public final class StagePartitioner implements
Function<DAG<IRVertex, IREdge>, M
}
// Get stage id of irVertex
final int stageId = vertexToStageIdMap.get(irVertex);
- // Step case: inductively assign stage ids based on mergability with
irVertex
+ // Step case: inductively assign stage ids based on mergeability with
irVertex
for (final IREdge edge : irDAG.getOutgoingEdgesOf(irVertex)) {
final IRVertex connectedIRVertex = edge.getDst();
// Skip if it already has been assigned stageId
diff --git a/runtime/common/src/main/proto/ControlMessage.proto
b/runtime/common/src/main/proto/ControlMessage.proto
index 3c6bb8e..0ad5a29 100644
--- a/runtime/common/src/main/proto/ControlMessage.proto
+++ b/runtime/common/src/main/proto/ControlMessage.proto
@@ -120,12 +120,11 @@ message BlockStateChangedMsg {
}
message DataSizeMetricMsg {
- // TODO #96: Modularize DataSkewPolicy to use MetricVertex and
BarrierVertex.
repeated PartitionSizeEntry partitionSize = 1;
}
message PartitionSizeEntry {
- required int32 key = 1;
+ required string key = 1;
required int64 size = 2;
}
diff --git
a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
index 04fca2b..d8948cf 100644
---
a/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
+++
b/runtime/common/src/test/java/edu/snu/nemo/runtime/common/optimizer/pass/runtime/DataSkewRuntimePassTest.java
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
* Test {@link DataSkewRuntimePass}.
*/
public class DataSkewRuntimePassTest {
- private final Map<Integer, Long> testMetricData = new HashMap<>();
+ private final Map<Object, Long> testMetricData = new HashMap<>();
@Before
public void setUp() {
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
index 692d01f..3bb8cee 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
@@ -232,15 +232,11 @@ public final class BlockManagerWorker {
*
* @param block the block to write.
* @param blockStore the store to save the block.
- * @param reportPartitionSizes whether report the size of partitions to
master or not.
- * @param partitionSizeMap the map of partition keys and sizes to report.
* @param expectedReadTotal the expected number of read for this block.
* @param persistence how to handle the used block.
*/
public void writeBlock(final Block block,
final DataStoreProperty.Value blockStore,
- final boolean reportPartitionSizes,
- final Map<Integer, Long> partitionSizeMap,
final int expectedReadTotal,
final DataPersistenceProperty.Value persistence) {
final String blockId = block.getId();
@@ -278,28 +274,6 @@ public final class BlockManagerWorker {
.setType(ControlMessage.MessageType.BlockStateChanged)
.setBlockStateChangedMsg(blockStateChangedMsgBuilder.build())
.build());
-
- if (reportPartitionSizes) {
- final List<ControlMessage.PartitionSizeEntry> partitionSizeEntries = new
ArrayList<>();
- partitionSizeMap.forEach((key, size) ->
- partitionSizeEntries.add(
- ControlMessage.PartitionSizeEntry.newBuilder()
- .setKey(key)
- .setSize(size)
- .build())
- );
-
- // TODO #4: Refactor metric aggregation for (general) run-rime
optimization.
-
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
- .send(ControlMessage.Message.newBuilder()
- .setId(RuntimeIdManager.generateMessageId())
-
.setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
- .setType(ControlMessage.MessageType.DataSizeMetric)
-
.setDataSizeMetricMsg(ControlMessage.DataSizeMetricMsg.newBuilder()
- .addAllPartitionSize(partitionSizeEntries)
- )
- .build());
- }
}
/**
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/block/FileBlock.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/block/FileBlock.java
index be42f6b..1cee827 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/block/FileBlock.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/block/FileBlock.java
@@ -26,6 +26,8 @@ import
edu.snu.nemo.runtime.executor.data.partition.SerializedPartition;
import edu.snu.nemo.runtime.executor.data.streamchainer.Serializer;
import edu.snu.nemo.runtime.executor.data.metadata.PartitionMetadata;
import edu.snu.nemo.runtime.executor.data.metadata.FileMetadata;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.NotThreadSafe;
import java.io.*;
@@ -41,7 +43,7 @@ import java.util.*;
*/
@NotThreadSafe
public final class FileBlock<K extends Serializable> implements Block<K> {
-
+ private static final Logger LOG =
LoggerFactory.getLogger(FileBlock.class.getName());
private final String id;
private final Map<K, SerializedPartition<K>> nonCommittedPartitionsMap;
private final Serializer serializer;
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/partition/SerializedPartition.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/partition/SerializedPartition.java
index 4015000..16c1259 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/partition/SerializedPartition.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/partition/SerializedPartition.java
@@ -18,6 +18,8 @@ package edu.snu.nemo.runtime.executor.data.partition;
import edu.snu.nemo.common.DirectByteArrayOutputStream;
import edu.snu.nemo.common.coder.EncoderFactory;
import edu.snu.nemo.runtime.executor.data.streamchainer.Serializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.io.IOException;
@@ -31,6 +33,8 @@ import static
edu.snu.nemo.runtime.executor.data.DataUtil.buildOutputStream;
* @param <K> the key type of its partitions.
*/
public final class SerializedPartition<K> implements Partition<byte[], K> {
+ private static final Logger LOG =
LoggerFactory.getLogger(SerializedPartition.class.getName());
+
private final K key;
private volatile byte[] serializedData;
private volatile int length;
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
index 0796b01..260f126 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputCollectorImpl.java
@@ -15,7 +15,10 @@
*/
package edu.snu.nemo.runtime.executor.datatransfer;
+import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.ir.OutputCollector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.*;
@@ -25,22 +28,26 @@ import java.util.*;
* @param <O> output type.
*/
public final class OutputCollectorImpl<O> implements OutputCollector<O> {
+ private static final Logger LOG =
LoggerFactory.getLogger(OutputCollectorImpl.class.getName());
private final Set<String> mainTagOutputChildren;
// Use ArrayList (not Queue) to allow 'null' values
private final ArrayList<O> mainTagElements;
- private final Map<String, ArrayList<Object>> additionalTagElementsMap;
+ // Key: Pair of tag and destination vertex id
+ // Value: data elements which will be input to the tagged destination vertex
+ private final Map<Pair<String, String>, ArrayList<Object>>
additionalTaggedChildToElementsMap;
/**
* Constructor of a new OutputCollectorImpl with tagged outputs.
* @param mainChildren main children vertices
- * @param taggedChildren additional children vertices
+ * @param tagToChildren additional children vertices
*/
public OutputCollectorImpl(final Set<String> mainChildren,
- final List<String> taggedChildren) {
+ final Map<String, String> tagToChildren) {
this.mainTagOutputChildren = mainChildren;
this.mainTagElements = new ArrayList<>(1);
- this.additionalTagElementsMap = new HashMap<>();
- taggedChildren.forEach(child -> this.additionalTagElementsMap.put(child,
new ArrayList<>(1)));
+ this.additionalTaggedChildToElementsMap = new HashMap<>();
+ tagToChildren.forEach((tag, child) ->
+ this.additionalTaggedChildToElementsMap.put(Pair.of(tag, child), new
ArrayList<>(1)));
}
@Override
@@ -55,10 +62,7 @@ public final class OutputCollectorImpl<O> implements
OutputCollector<O> {
emit((O) output);
} else {
// Note that String#hashCode() can be cached, thus accessing additional
output queues can be fast.
- final List<Object> dataElements =
this.additionalTagElementsMap.get(dstVertexId);
- if (dataElements == null) {
- throw new IllegalArgumentException("Wrong destination vertex id
passed!");
- }
+ final List<Object> dataElements =
getAdditionalTaggedDataFromDstVertexId(dstVertexId);
dataElements.add(output);
}
}
@@ -72,11 +76,7 @@ public final class OutputCollectorImpl<O> implements
OutputCollector<O> {
// This dstVertexId is for the main tag
return (Iterable<Object>) iterateMain();
} else {
- final List<Object> dataElements = this.additionalTagElementsMap.get(tag);
- if (dataElements == null) {
- throw new IllegalArgumentException("Wrong destination vertex id
passed!");
- }
- return dataElements;
+ return getAdditionalTaggedDataFromTag(tag);
}
}
@@ -90,10 +90,7 @@ public final class OutputCollectorImpl<O> implements
OutputCollector<O> {
clearMain();
} else {
// Note that String#hashCode() can be cached, thus accessing additional
output queues can be fast.
- final List<Object> dataElements = this.additionalTagElementsMap.get(tag);
- if (dataElements == null) {
- throw new IllegalArgumentException("Wrong destination vertex id
passed!");
- }
+ final List<Object> dataElements = getAdditionalTaggedDataFromTag(tag);
dataElements.clear();
}
}
@@ -106,11 +103,31 @@ public final class OutputCollectorImpl<O> implements
OutputCollector<O> {
if (this.mainTagOutputChildren.contains(dstVertexId)) {
return (List<Object>) this.mainTagElements;
} else {
- final List<Object> dataElements =
this.additionalTagElementsMap.get(dstVertexId);
- if (dataElements == null) {
- throw new IllegalArgumentException("Wrong destination vertex id
passed!");
- }
- return dataElements;
+ return getAdditionalTaggedDataFromDstVertexId(dstVertexId);
}
}
+
+ private List<Object> getAdditionalTaggedDataFromDstVertexId(final String
dstVertexId) {
+ final Pair<String, String> tagAndChild =
+ this.additionalTaggedChildToElementsMap.keySet().stream()
+ .filter(key -> key.right().equals(dstVertexId))
+ .findAny().orElseThrow(() -> new RuntimeException("Wrong destination
vertex id passed!"));
+ final List<Object> dataElements =
this.additionalTaggedChildToElementsMap.get(tagAndChild);
+ if (dataElements == null) {
+ throw new IllegalArgumentException("Wrong destination vertex id
passed!");
+ }
+ return dataElements;
+ }
+
+ private List<Object> getAdditionalTaggedDataFromTag(final String tag) {
+ final Pair<String, String> tagAndChild =
+ this.additionalTaggedChildToElementsMap.keySet().stream()
+ .filter(key -> key.left().equals(tag))
+ .findAny().orElseThrow(() -> new RuntimeException("Wrong tag " + tag +
" passed!"));
+ final List<Object> dataElements =
this.additionalTaggedChildToElementsMap.get(tagAndChild);
+ if (dataElements == null) {
+ throw new IllegalArgumentException("Wrong tag " + tag + " passed!");
+ }
+ return dataElements;
+ }
}
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
index 162f491..b2f2b9a 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/datatransfer/OutputWriter.java
@@ -62,7 +62,6 @@ public final class OutputWriter extends DataTransfer
implements AutoCloseable {
this.blockStoreValue =
runtimeEdge.getPropertyValue(DataStoreProperty.class).
orElseThrow(() -> new RuntimeException("No data store property on the
edge"));
-
// Setup partitioner
final int dstParallelism =
dstIrVertex.getPropertyValue(ParallelismProperty.class).
orElseThrow(() -> new RuntimeException("No parallelism property on the
destination vertex"));
@@ -125,8 +124,6 @@ public final class OutputWriter extends DataTransfer
implements AutoCloseable {
runtimeEdge.getPropertyValue(DataPersistenceProperty.class).
orElseThrow(() -> new RuntimeException("No data persistence
property on the edge"));
- final boolean isDataSizeMetricCollectionEdge =
Optional.of(MetricCollectionProperty.Value.DataSkewRuntimePass)
- .equals(runtimeEdge.getPropertyValue(MetricCollectionProperty.class));
final Optional<Map<Integer, Long>> partitionSizeMap =
blockToWrite.commit();
// Return the total size of the committed block.
if (partitionSizeMap.isPresent()) {
@@ -135,13 +132,10 @@ public final class OutputWriter extends DataTransfer
implements AutoCloseable {
blockSizeTotal += partitionSize;
}
this.writtenBytes = blockSizeTotal;
- blockManagerWorker.writeBlock(blockToWrite, blockStoreValue,
isDataSizeMetricCollectionEdge,
- partitionSizeMap.get(), getExpectedRead(), persistence);
} else {
this.writtenBytes = -1; // no written bytes info.
- blockManagerWorker.writeBlock(blockToWrite, blockStoreValue,
isDataSizeMetricCollectionEdge,
- Collections.emptyMap(), getExpectedRead(), persistence);
}
+ blockManagerWorker.writeBlock(blockToWrite, blockStoreValue,
getExpectedRead(), persistence);
}
/**
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index 4b156d9..ca5476a 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -22,6 +22,7 @@ import edu.snu.nemo.common.ir.Readable;
import
edu.snu.nemo.common.ir.edge.executionproperty.AdditionalOutputTagProperty;
import
edu.snu.nemo.common.ir.edge.executionproperty.BroadcastVariableIdProperty;
import edu.snu.nemo.common.ir.vertex.*;
+import edu.snu.nemo.common.ir.vertex.transform.AggregateMetricTransform;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
import edu.snu.nemo.runtime.common.RuntimeIdManager;
import edu.snu.nemo.runtime.common.comm.ControlMessage;
@@ -56,6 +57,7 @@ import javax.annotation.concurrent.NotThreadSafe;
public final class TaskExecutor {
private static final Logger LOG =
LoggerFactory.getLogger(TaskExecutor.class.getName());
private static final int NONE_FINISHED = -1;
+ private static final String NULL_KEY = "NULL";
// Essential information
private boolean isExecuted;
@@ -156,7 +158,15 @@ public final class TaskExecutor {
// Prepare data WRITE
// Child-task writes
final Map<String, String> additionalOutputMap =
- getAdditionalOutputMap(irVertex, task.getTaskOutgoingEdges(),
irVertexDag);
+ getAdditionalOutputMap(irVertex, task.getTaskOutgoingEdges(),
irVertexDag);
+
+ final List<Boolean> isToAdditionalTagOutputs = children.stream()
+ .map(harness -> harness.getIRVertex().getId())
+ .map(additionalOutputMap::containsValue)
+ .collect(Collectors.toList());
+
+ // Handle writes
+ // Main output children task writes
final List<OutputWriter> mainChildrenTaskWriters =
getMainChildrenTaskWriters(
irVertex, task.getTaskOutgoingEdges(), dataTransferFactory,
additionalOutputMap);
final Map<String, OutputWriter> additionalChildrenTaskWriters =
getAdditionalChildrenTaskWriters(
@@ -164,12 +174,8 @@ public final class TaskExecutor {
// Intra-task writes
final List<String> additionalOutputVertices = new
ArrayList<>(additionalOutputMap.values());
final Set<String> mainChildren =
- getMainOutputVertices(irVertex, irVertexDag,
task.getTaskOutgoingEdges(), additionalOutputVertices);
- final OutputCollectorImpl oci = new OutputCollectorImpl(mainChildren,
additionalOutputVertices);
- final List<Boolean> isToAdditionalTagOutputs = children.stream()
- .map(harness -> harness.getIRVertex().getId())
- .map(additionalOutputMap::containsValue)
- .collect(Collectors.toList());
+ getMainOutputVertices(irVertex, irVertexDag,
task.getTaskOutgoingEdges(), additionalOutputVertices);
+ final OutputCollectorImpl oci = new OutputCollectorImpl(mainChildren,
additionalOutputMap);
// Create VERTEX HARNESS
final VertexHarness vertexHarness = new VertexHarness(
@@ -231,27 +237,26 @@ public final class TaskExecutor {
private void processElementRecursively(final VertexHarness vertexHarness,
final Object dataElement) {
final IRVertex irVertex = vertexHarness.getIRVertex();
final OutputCollectorImpl outputCollector =
vertexHarness.getOutputCollector();
+
if (irVertex instanceof SourceVertex) {
outputCollector.emit(dataElement);
} else if (irVertex instanceof OperatorVertex) {
final Transform transform = ((OperatorVertex) irVertex).getTransform();
transform.onData(dataElement);
- } else if (irVertex instanceof MetricCollectionBarrierVertex) {
- outputCollector.emit(dataElement);
- setIRVertexPutOnHold((MetricCollectionBarrierVertex) irVertex);
} else {
throw new UnsupportedOperationException("This type of IRVertex is not
supported");
}
// Given a single input element, a vertex can produce many output elements.
- // Here, we recursively process all of the main oltput elements.
- outputCollector.iterateMain().forEach(element ->
handleMainOutputElement(vertexHarness, element)); // Recursion
+ // Here, we recursively process all of the main output elements.
+ outputCollector.iterateMain().forEach(element ->
+ handleMainOutputElement(vertexHarness, element)); // Recursion
outputCollector.clearMain();
// Recursively process all of the additional output elements.
-
vertexHarness.getContext().getTagToAdditionalChildren().values().forEach(tag ->
{
- outputCollector.iterateTag(tag).forEach(
- element -> handleAdditionalOutputElement(vertexHarness, element,
tag)); // Recursion
+ vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
+ outputCollector.iterateTag(tag).forEach(element ->
+ handleAdditionalOutputElement(vertexHarness, element, tag)); //
Recursion
outputCollector.clearTag(tag);
});
}
@@ -310,21 +315,67 @@ public final class TaskExecutor {
}
}
+ /**
+ * Send aggregated statistics for dynamic optimization to master.
+ *
+ * <p>Each (key, size) entry of the map becomes one {@code PartitionSizeEntry} of a
+ * {@code DataSizeMetric} message sent to {@code RUNTIME_MASTER_MESSAGE_LISTENER_ID}.
+ * A {@code null} key is encoded as {@code NULL_KEY}; any other key is encoded with
+ * {@link String#valueOf(Object)}.
+ *
+ * @param dynOptData the statistics to send; must be a {@code Map} from key to size
+ * (a non-{@code Map} argument causes a {@link ClassCastException}).
+ */
+ @SuppressWarnings("unchecked")
+ public void sendDynOptData(final Object dynOptData) {
+ // The aggregated data is handed over as a plain Object, so the map type is recovered by an
+ // unchecked cast; the warning is suppressed on this method only.
+ final Map<Object, Long> aggregatedDynOptData = (Map<Object, Long>) dynOptData;
+ final List<ControlMessage.PartitionSizeEntry> partitionSizeEntries =
+ new ArrayList<>(aggregatedDynOptData.size());
+ aggregatedDynOptData.forEach((key, size) ->
+ partitionSizeEntries.add(
+ ControlMessage.PartitionSizeEntry.newBuilder()
+ .setKey(key == null ? NULL_KEY : String.valueOf(key))
+ .setSize(size)
+ .build())
+ );
+
+ persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+ .send(ControlMessage.Message.newBuilder()
+ .setId(RuntimeIdManager.generateMessageId())
+ .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID)
+ .setType(ControlMessage.MessageType.DataSizeMetric)
+ .setDataSizeMetricMsg(ControlMessage.DataSizeMetricMsg.newBuilder()
+ .addAllPartitionSize(partitionSizeEntries)
+ )
+ .build());
+ }
+
private void finalizeVertex(final VertexHarness vertexHarness) {
closeTransform(vertexHarness);
- final OutputCollectorImpl outputCollector =
vertexHarness.getOutputCollector();
- // handle main outputs
- outputCollector.iterateMain().forEach(element ->
handleMainOutputElement(vertexHarness, element)); // Recursion
- outputCollector.clearMain();
-
- // handle additional tagged outputs
- vertexHarness.getAdditionalTagOutputChildren().keySet().forEach(tag -> {
- outputCollector.iterateTag(tag).forEach(
- element -> handleAdditionalOutputElement(vertexHarness, element,
tag)); // Recursion
- outputCollector.clearTag(tag);
- });
- finalizeOutputWriters(vertexHarness);
+ final OutputCollectorImpl outputCollector = vertexHarness.getOutputCollector();
+ final IRVertex v = vertexHarness.getIRVertex();
+ if (v instanceof OperatorVertex
+ && ((OperatorVertex) v).getTransform() instanceof AggregateMetricTransform) {
+ // The first main-output element of an AggregateMetricTransform vertex is the aggregated
+ // dynamic optimization data; it is sent to the master instead of to downstream vertices.
+ final java.util.Iterator<?> mainOutputs = outputCollector.iterateMain().iterator();
+ if (!mainOutputs.hasNext()) {
+ throw new IllegalStateException(
+ "Vertex " + v.getId() + " emitted no aggregated dynamic optimization data");
+ }
+ sendDynOptData(mainOutputs.next());
+ // The buffered result has been consumed; release it like the other branch does.
+ outputCollector.clearMain();
+ // set the id of this vertex to mark the corresponding stage as put on hold
+ setIRVertexPutOnHold(v);
+ // NOTE(review): output writers are not finalized on this path - confirm this is intended.
+ } else {
+ // handle main outputs
+ outputCollector.iterateMain().forEach(element ->
+ handleMainOutputElement(vertexHarness, element)); // Recursion
+ outputCollector.clearMain();
+
+ // handle intra-task additional tagged outputs
+ vertexHarness.getAdditionalTagOutputChildren().keySet()
+ .forEach(tag -> drainTaggedOutput(vertexHarness, outputCollector, tag));
+
+ // handle inter-task additional tagged outputs. getTagToAdditionalChildrenId() holds every tag,
+ // intra-task ones included, so skip the tags already drained above.
+ vertexHarness.getTagToAdditionalChildrenId().keySet().stream()
+ .filter(tag -> !vertexHarness.getAdditionalTagOutputChildren().containsKey(tag))
+ .forEach(tag -> drainTaggedOutput(vertexHarness, outputCollector, tag));
+
+ finalizeOutputWriters(vertexHarness);
+ }
}
+
+ /**
+ * Hands every buffered element of the given tag to additional-output handling (recursively),
+ * then clears that tag's buffer in the collector.
+ *
+ * @param vertexHarness harness of the vertex that produced the tagged output.
+ * @param outputCollector the vertex's output collector holding the buffered elements.
+ * @param tag the additional output tag to drain.
+ */
+ private void drainTaggedOutput(final VertexHarness vertexHarness,
+ final OutputCollectorImpl outputCollector,
+ final String tag) {
+ outputCollector.iterateTag(tag).forEach(
+ element -> handleAdditionalOutputElement(vertexHarness, element, tag)); // Recursion
+ outputCollector.clearTag(tag);
+ }
private void handleMainOutputElement(final VertexHarness harness, final
Object element) {
@@ -357,7 +408,6 @@ public final class TaskExecutor {
for (int i = 0; i < availableFetchers.size(); i++) {
final DataFetcher dataFetcher = availableFetchers.get(i);
final Object element;
-
try {
element = dataFetcher.fetchDataElement();
} catch (NoSuchElementException e) {
@@ -492,7 +542,7 @@ public final class TaskExecutor {
* @param outEdgesToChildrenTasks outgoing edges to child tasks
* @param dataTransferFactory dataTransferFactory
* @param taggedOutputs tag to vertex id map
- * @return additional children vertex id to OutputWriters map.
+ * @return additional tag to OutputWriters map.
*/
private Map<String, OutputWriter> getAdditionalChildrenTaskWriters(final
IRVertex irVertex,
final
List<StageEdge> outEdgesToChildrenTasks,
@@ -501,12 +551,17 @@ public final class TaskExecutor {
final Map<String, OutputWriter> additionalChildrenTaskWriters = new
HashMap<>();
outEdgesToChildrenTasks
- .stream()
- .filter(outEdge ->
outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
- .filter(outEdge ->
taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
- .forEach(outEdgeForThisVertex ->
-
additionalChildrenTaskWriters.put(outEdgeForThisVertex.getDstIRVertex().getId(),
- dataTransferFactory.createWriter(taskId,
outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex)));
+ .stream()
+ .filter(outEdge ->
outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
+ .filter(outEdge ->
taggedOutputs.containsValue(outEdge.getDstIRVertex().getId()))
+ .forEach(outEdgeForThisVertex -> {
+ final String tag = taggedOutputs.entrySet().stream()
+ .filter(e ->
e.getValue().equals(outEdgeForThisVertex.getDstIRVertex().getId()))
+ .findAny().orElseThrow(() -> new RuntimeException("Unexpected
error while finding tag"))
+ .getKey();
+ additionalChildrenTaskWriters.put(tag,
+ dataTransferFactory.createWriter(taskId,
outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex));
+ });
return additionalChildrenTaskWriters;
}
@@ -530,18 +585,21 @@ public final class TaskExecutor {
private void prepareTransform(final VertexHarness vertexHarness) {
+ // Prepare the transform of an operator vertex with this harness's context and output collector;
+ // other vertex types have no transform and need no preparation.
final IRVertex irVertex = vertexHarness.getIRVertex();
if (irVertex instanceof OperatorVertex) {
- final Transform transform = ((OperatorVertex) irVertex).getTransform();
+ // The transform is only needed inside this branch, so it is declared here.
+ final Transform transform = ((OperatorVertex) irVertex).getTransform();
transform.prepare(vertexHarness.getContext(),
vertexHarness.getOutputCollector());
}
}
private void closeTransform(final VertexHarness vertexHarness) {
final IRVertex irVertex = vertexHarness.getIRVertex();
+ final Transform transform;
if (irVertex instanceof OperatorVertex) {
- Transform transform = ((OperatorVertex) irVertex).getTransform();
+ transform = ((OperatorVertex) irVertex).getTransform();
transform.close();
}
+
vertexHarness.getContext().getSerializedData().ifPresent(data ->
persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send(
ControlMessage.Message.newBuilder()
@@ -554,7 +612,7 @@ public final class TaskExecutor {
////////////////////////////////////////////// Misc
- private void setIRVertexPutOnHold(final MetricCollectionBarrierVertex
irVertex) {
+ /**
+ * Records the id of the given vertex in {@code idOfVertexPutOnHold}, marking it as the vertex put on hold.
+ * Each call overwrites the previously stored id.
+ *
+ * @param irVertex the vertex whose id is stored.
+ */
+ private void setIRVertexPutOnHold(final IRVertex irVertex) {
idOfVertexPutOnHold = irVertex.getId();
}
diff --git
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
index 2ad8868..f8ea8e0 100644
---
a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
+++
b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/VertexHarness.java
@@ -19,6 +19,8 @@ import edu.snu.nemo.common.ir.vertex.IRVertex;
import edu.snu.nemo.common.ir.vertex.transform.Transform;
import edu.snu.nemo.runtime.executor.datatransfer.OutputCollectorImpl;
import edu.snu.nemo.runtime.executor.datatransfer.OutputWriter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.HashMap;
@@ -29,6 +31,8 @@ import java.util.Map;
* Captures the relationship between a non-source IRVertex's outputCollector,
and mainTagChildren vertices.
*/
final class VertexHarness {
+ private static final Logger LOG =
LoggerFactory.getLogger(VertexHarness.class.getName());
+
// IRVertex and transform-specific information
private final IRVertex irVertex;
private final OutputCollectorImpl outputCollector;
@@ -37,6 +41,7 @@ final class VertexHarness {
// These lists can be empty
private final Map<String, VertexHarness> additionalTagOutputChildren;
+ private final Map<String, String> tagToAdditionalChildrenId;
private final List<OutputWriter> writersToMainChildrenTasks;
private final Map<String, OutputWriter> writersToAdditionalChildrenTasks;
@@ -54,14 +59,18 @@ final class VertexHarness {
}
final Map<String, String> taggedOutputMap =
context.getTagToAdditionalChildren();
final Map<String, VertexHarness> tagged = new HashMap<>();
+
+ // Classify input type for intra-task children
for (int i = 0; i < children.size(); i++) {
final VertexHarness child = children.get(i);
if (isAdditionalTagOutputs.get(i)) {
taggedOutputMap.entrySet().stream()
- .filter(kv -> child.getIRVertex().getId().equals(kv.getValue()))
- .forEach(kv -> tagged.put(kv.getValue(), child));
+ .filter(kv -> child.getIRVertex().getId().equals(kv.getValue()))
+ .forEach(kv -> tagged.put(kv.getKey(), child));
}
}
+
+ this.tagToAdditionalChildrenId = context.getTagToAdditionalChildren();
this.additionalTagOutputChildren = tagged;
final List<VertexHarness> mainTagChildrenTmp = new ArrayList<>(children);
mainTagChildrenTmp.removeAll(additionalTagOutputChildren.values());
@@ -100,6 +109,13 @@ final class VertexHarness {
}
/**
+ * Returns every additional-output tag known to this vertex's context, mapped to the id of the child
+ * vertex that tag is routed to. Unlike the intra-task tagged children, this map is not filtered to
+ * children running in the same task. It is the same map instance the context exposes, so callers
+ * should not modify it.
+ *
+ * @return map of tag to additional children id.
+ */
+ Map<String, String> getTagToAdditionalChildrenId() {
+ return tagToAdditionalChildrenId;
+ }
+
+ /**
* @return OutputWriters for main outputs of this irVertex. (empty if none
exists)
*/
List<OutputWriter> getWritersToMainChildrenTasks() {
diff --git
a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
index 0d0a69b..01af83f 100644
---
a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
+++
b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/datatransfer/DataTransferTest.java
@@ -15,9 +15,7 @@
*/
package edu.snu.nemo.runtime.executor.datatransfer;
-import edu.snu.nemo.common.DataSkewMetricFactory;
-import edu.snu.nemo.common.HashRange;
-import edu.snu.nemo.common.KeyRange;
+import edu.snu.nemo.common.*;
import edu.snu.nemo.common.coder.*;
import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper;
import edu.snu.nemo.common.ir.edge.IREdge;
@@ -29,7 +27,6 @@ import
edu.snu.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import edu.snu.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty;
import edu.snu.nemo.common.test.EmptyComponents;
import edu.snu.nemo.conf.JobConf;
-import edu.snu.nemo.common.Pair;
import edu.snu.nemo.common.dag.DAG;
import edu.snu.nemo.common.dag.DAGBuilder;
import edu.snu.nemo.common.ir.executionproperty.ExecutionPropertyMap;
@@ -48,6 +45,7 @@ import edu.snu.nemo.runtime.executor.data.BlockManagerWorker;
import edu.snu.nemo.runtime.executor.data.SerializerManager;
import edu.snu.nemo.runtime.master.*;
import edu.snu.nemo.runtime.master.eventhandler.UpdatePhysicalPlanEventHandler;
+import org.apache.beam.sdk.values.KV;
import org.apache.commons.io.FileUtils;
import org.apache.reef.driver.evaluator.EvaluatorRequestor;
import org.apache.reef.io.network.naming.NameResolverConfiguration;
@@ -68,6 +66,7 @@ import org.powermock.modules.junit4.PowerMockRunner;
import java.io.File;
import java.io.IOException;
+import java.io.Serializable;
import java.util.*;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@@ -299,7 +298,7 @@ public final class DataTransferTest {
// Edge setup
final IREdge dummyIREdge = new IREdge(commPattern, srcVertex, dstVertex);
- dummyIREdge.setProperty(KeyExtractorProperty.of((element -> element)));
+ dummyIREdge.setProperty(KeyExtractorProperty.of(element -> element));
dummyIREdge.setProperty(CommunicationPatternProperty.of(commPattern));
dummyIREdge.setProperty(PartitionerProperty.of(PartitionerProperty.Value.HashPartitioner));
dummyIREdge.setProperty(DataStoreProperty.of(store));
@@ -387,7 +386,7 @@ public final class DataTransferTest {
final IREdge dummyIREdge = new IREdge(commPattern, srcVertex, dstVertex);
dummyIREdge.setProperty(EncoderProperty.of(ENCODER_FACTORY));
dummyIREdge.setProperty(DecoderProperty.of(DECODER_FACTORY));
- dummyIREdge.setProperty(KeyExtractorProperty.of((element -> element)));
+ dummyIREdge.setProperty(KeyExtractorProperty.of(element -> element));
dummyIREdge.setProperty(CommunicationPatternProperty.of(commPattern));
dummyIREdge.setProperty(PartitionerProperty.of(PartitionerProperty.Value.HashPartitioner));
dummyIREdge.setProperty(DuplicateEdgeGroupProperty.of(new
DuplicateEdgeGroupPropertyValue("dummy")));
@@ -532,5 +531,5 @@ public final class DataTransferTest {
stageExecutionProperty.put(ScheduleGroupProperty.of(0));
return new Stage(stageId, emptyDag, stageExecutionProperty,
Collections.emptyList());
}
-
}
+
diff --git
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/DataSkewDynOptDataHandler.java
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/DataSkewDynOptDataHandler.java
index 47671ed..439c4ff 100644
---
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/DataSkewDynOptDataHandler.java
+++
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/DataSkewDynOptDataHandler.java
@@ -25,7 +25,7 @@ import java.util.Map;
* Handler for aggregating data used in data skew dynamic optimization.
*/
public class DataSkewDynOptDataHandler implements DynOptDataHandler {
- private final Map<Integer, Long> aggregatedDynOptData;
+ private final Map<Object, Long> aggregatedDynOptData;
public DataSkewDynOptDataHandler() {
this.aggregatedDynOptData = new HashMap<>();
@@ -40,12 +40,12 @@ public class DataSkewDynOptDataHandler implements
DynOptDataHandler {
List<ControlMessage.PartitionSizeEntry> partitionSizeInfo
= (List<ControlMessage.PartitionSizeEntry>) dynOptData;
partitionSizeInfo.forEach(partitionSizeEntry -> {
- final int hashIndex = partitionSizeEntry.getKey();
+ final Object key = partitionSizeEntry.getKey();
final long partitionSize = partitionSizeEntry.getSize();
- if (aggregatedDynOptData.containsKey(hashIndex)) {
- aggregatedDynOptData.compute(hashIndex, (originalKey, originalValue)
-> originalValue + partitionSize);
+ if (aggregatedDynOptData.containsKey(key)) {
+ aggregatedDynOptData.compute(key, (originalKey, originalValue) ->
originalValue + partitionSize);
} else {
- aggregatedDynOptData.put(hashIndex, partitionSize);
+ aggregatedDynOptData.put(key, partitionSize);
}
});
}
diff --git
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
index 3129732..de6df92 100644
---
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
+++
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
@@ -378,7 +378,6 @@ public final class RuntimeMaster {
LOG.error(failedExecutorId + " failed, Stack Trace: ", exception);
throw new RuntimeException(exception);
case DataSizeMetric:
- // TODO #96: Modularize DataSkewPolicy to use MetricVertex and
BarrierVertex.
((BatchScheduler)
scheduler).updateDynOptData(message.getDataSizeMetricMsg().getPartitionSizeList());
break;
case MetricMessageReceived:
diff --git
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
index 5ccc2fc..21e196e 100644
---
a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
+++
b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchScheduler.java
@@ -419,7 +419,14 @@ public final class BatchScheduler implements Scheduler {
}
/**
- * @param taskId the metric collected task ID.
+ * Get the target edge of dynamic optimization.
+ * The edge is annotated with {@link MetricCollectionProperty}, which is an
outgoing edge of
+ * a parent of the stage put on hold.
+ *
+ * See {@link
edu.snu.nemo.compiler.optimizer.pass.compiletime.reshaping.SkewReshapingPass}
+ * for setting the target edge of dynamic optimization.
+ *
+ * @param taskId the task ID that sent stage-level aggregated metric for
dynamic optimization.
* @return the edge to optimize.
*/
private StageEdge getEdgeToOptimize(final String taskId) {
@@ -429,8 +436,18 @@ public final class BatchScheduler implements Scheduler {
.findFirst()
.orElseThrow(() -> new RuntimeException());
+ // Stage put on hold, i.e. stage with vertex containing
AggregateMetricTransform
+ // should have a parent stage whose outgoing edges contain the target edge
of dynamic optimization.
+ final List<Stage> parentStages =
planStateManager.getPhysicalPlan().getStageDAG()
+ .getParents(stagePutOnHold.getId());
+
+ if (parentStages.size() > 1) {
+ throw new RuntimeException("Error in setting target edge of dynamic
optimization!");
+ }
+
// Get outgoing edges of that stage with MetricCollectionProperty
- List<StageEdge> stageEdges =
planStateManager.getPhysicalPlan().getStageDAG().getOutgoingEdgesOf(stagePutOnHold);
+ List<StageEdge> stageEdges =
planStateManager.getPhysicalPlan().getStageDAG()
+ .getOutgoingEdgesOf(parentStages.get(0));
for (StageEdge edge : stageEdges) {
if
(edge.getExecutionProperties().containsKey(MetricCollectionProperty.class)) {
return edge;