johnyangk commented on a change in pull request #2: [NEMO-7] Intra-TaskGroup pipelining
URL: https://github.com/apache/incubator-nemo/pull/2#discussion_r172707585
 
 

 ##########
 File path: 
runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskGroupExecutor.java
 ##########
 @@ -145,303 +201,504 @@ private void initializeDataTransfer() {
         .collect(Collectors.toSet());
   }
 
-  // Helper functions to initializes stage-internal edges.
-  private void createLocalReader(final Task task, final RuntimeEdge<Task> 
internalEdge) {
-    final InputReader inputReader = 
channelFactory.createLocalReader(taskGroupIdx, internalEdge);
-    addInputReader(task, inputReader);
-  }
+  /**
+   * Add input pipes to each {@link Task}.
+   * Input pipe denotes all the pipes of intra-Stage parent tasks of this task.
+   *
+   * @param task the Task to add input pipes to.
+   */
+  private void addInputPipe(final Task task) {
+    List<LocalPipe> inputPipes = new ArrayList<>();
+    List<Task> parentTasks = taskGroupDag.getParents(task.getId());
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
 
-  private void createLocalWriter(final Task task, final RuntimeEdge<Task> 
internalEdge) {
-    final OutputWriter outputWriter = channelFactory.createLocalWriter(task, 
taskGroupIdx, internalEdge);
-    addOutputWriter(task, outputWriter);
+    if (parentTasks != null) {
+      parentTasks.forEach(parent -> {
+        final LocalPipe parentOutputPipe = taskToOutputPipeMap.get(parent);
+        inputPipes.add(parentOutputPipe);
+        LOG.info("log: Added Outputpipe of {} as InputPipe of {} {}",
+            getPhysicalTaskId(parent.getId()), taskGroupId, physicalTaskId);
+      });
+      taskToInputPipesMap.put(task, inputPipes);
+    }
   }
 
-  // Helper functions to add the initialized reader/writer to the maintained 
map.
-  private void addInputReader(final Task task, final InputReader inputReader) {
+  /**
+   * Add output pipes to each {@link Task}.
+   * Output pipe denotes the one and only one pipe of this task.
+   * Check the outgoing edges that will use this pipe,
+   * and set this pipe as side input if any one of the edges uses this pipe as 
side input.
+   *
+   * @param task the Task to add output pipes to.
+   */
+  private void addOutputPipe(final Task task) {
+    final LocalPipe outputPipe = new LocalPipe();
     final String physicalTaskId = getPhysicalTaskId(task.getId());
-    physicalTaskIdToInputReaderMap.computeIfAbsent(physicalTaskId, readerList 
-> new ArrayList<>());
-    physicalTaskIdToInputReaderMap.get(physicalTaskId).add(inputReader);
-  }
+    final List<RuntimeEdge<Task>> outEdges = 
taskGroupDag.getOutgoingEdgesOf(task);
+
+    outEdges.forEach(outEdge -> {
+      if (outEdge.isSideInput()) {
+        outputPipe.setSideInputRuntimeEdge(outEdge);
+        outputPipe.setAsSideInput(physicalTaskId);
+        LOG.info("log: {} {} Marked as accepting sideInput(edge {})",
+            taskGroupId, physicalTaskId, outEdge.getId());
+      }
+    });
 
-  private void addOutputWriter(final Task task, final OutputWriter 
outputWriter) {
-    final String physicalTaskId = getPhysicalTaskId(task.getId());
-    physicalTaskIdToOutputWriterMap.computeIfAbsent(physicalTaskId, readerList 
-> new ArrayList<>());
-    physicalTaskIdToOutputWriterMap.get(physicalTaskId).add(outputWriter);
+    taskToOutputPipeMap.put(task, outputPipe);
+    LOG.info("log: {} {} Added OutputPipe", taskGroupId, physicalTaskId);
   }
 
-  /**
-   * Executes the task group.
-   */
-  public void execute() {
-    LOG.info("{} Execution Started!", taskGroupId);
-    if (isExecutionRequested) {
-      throw new RuntimeException("TaskGroup {" + taskGroupId + "} execution 
called again!");
-    } else {
-      isExecutionRequested = true;
-    }
+  private boolean hasInputPipe(final Task task) {
+    return taskToInputPipesMap.containsKey(task);
+  }
 
-    taskGroupStateManager.onTaskGroupStateChanged(
-        TaskGroupState.State.EXECUTING, Optional.empty(), Optional.empty());
+  private boolean hasOutputWriter(final Task task) {
+    return taskToOutputWritersMap.containsKey(task);
+  }
 
-    taskGroupDag.topologicalDo(task -> {
-      final String physicalTaskId = getPhysicalTaskId(task.getId());
-      taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.EXECUTING, Optional.empty());
-      try {
-        if (task instanceof BoundedSourceTask) {
-          launchBoundedSourceTask((BoundedSourceTask) task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.COMPLETE, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else if (task instanceof OperatorTask) {
-          launchOperatorTask((OperatorTask) task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.COMPLETE, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else if (task instanceof MetricCollectionBarrierTask) {
-          launchMetricCollectionBarrierTask((MetricCollectionBarrierTask) 
task);
-          taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.ON_HOLD, Optional.empty());
-          LOG.info("{} Execution Complete!", taskGroupId);
-        } else {
-          throw new UnsupportedOperationException(task.toString());
-        }
-      } catch (final BlockFetchException ex) {
-        taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.FAILED_RECOVERABLE,
-            
Optional.of(TaskGroupState.RecoverableFailureCause.INPUT_READ_FAILURE));
-        LOG.warn("{} Execution Failed (Recoverable)! Exception: {}",
-            new Object[] {taskGroupId, ex.toString()});
-      } catch (final BlockWriteException ex2) {
-        taskGroupStateManager.onTaskStateChanged(physicalTaskId, 
TaskState.State.FAILED_RECOVERABLE,
-            
Optional.of(TaskGroupState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE));
-        LOG.warn("{} Execution Failed (Recoverable)! Exception: {}",
-            new Object[] {taskGroupId, ex2.toString()});
-      } catch (final Exception e) {
-        taskGroupStateManager.onTaskStateChanged(
-            physicalTaskId, TaskState.State.FAILED_UNRECOVERABLE, 
Optional.empty());
-        throw new RuntimeException(e);
-      }
-    });
+  private void setTaskPutOnHold(final MetricCollectionBarrierTask task) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
+    logicalTaskIdPutOnHold = 
RuntimeIdGenerator.getLogicalTaskIdIdFromPhysicalTaskId(physicalTaskId);
   }
 
-  /**
-   * Processes a BoundedSourceTask.
-   *
-   * @param boundedSourceTask the bounded source task to execute
-   * @throws Exception occurred during input read.
-   */
-  private void launchBoundedSourceTask(final BoundedSourceTask 
boundedSourceTask) throws Exception {
-    final String physicalTaskId = getPhysicalTaskId(boundedSourceTask.getId());
+  private void writeAndCloseOutputWriters(final Task task) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
+    final List<Long> writtenBytesList = new ArrayList<>();
     final Map<String, Object> metric = new HashMap<>();
     metricCollector.beginMeasurement(physicalTaskId, metric);
+    final long writeStartTime = System.currentTimeMillis();
 
-    final long readStartTime = System.currentTimeMillis();
-    final Readable readable = boundedSourceTask.getReadable();
-    final Iterable readData = readable.read();
-    final long readEndTime = System.currentTimeMillis();
-    metric.put("BoundedSourceReadTime(ms)", readEndTime - readStartTime);
-
-    final List<Long> writtenBytesList = new ArrayList<>();
-    for (final OutputWriter outputWriter : 
physicalTaskIdToOutputWriterMap.get(physicalTaskId)) {
-      outputWriter.write(readData);
+    taskToOutputWritersMap.get(task).forEach(outputWriter -> {
+      LOG.info("Write and close outputWriter of task {}", 
getPhysicalTaskId(task.getId()));
+      outputWriter.write();
       outputWriter.close();
       final Optional<Long> writtenBytes = outputWriter.getWrittenBytes();
       writtenBytes.ifPresent(writtenBytesList::add);
-    }
+    });
+
     final long writeEndTime = System.currentTimeMillis();
-    metric.put("OutputWriteTime(ms)", writeEndTime - readEndTime);
+    metric.put("OutputWriteTime(ms)", writeEndTime - writeStartTime);
     putWrittenBytesMetric(writtenBytesList, metric);
     metricCollector.endMeasurement(physicalTaskId, metric);
   }
 
-  /**
-   * Processes an OperatorTask.
-   * @param operatorTask to execute
-   */
-  private void launchOperatorTask(final OperatorTask operatorTask) {
-    final Map<Transform, Object> sideInputMap = new HashMap<>();
-    final List<DataUtil.IteratorWithNumBytes> sideInputIterators = new 
ArrayList<>();
-    final String physicalTaskId = getPhysicalTaskId(operatorTask.getId());
+  private void prepareInputFromSource() {
+    taskGroupDag.topologicalDo(task -> {
+      if (task instanceof BoundedSourceTask) {
+        try {
+          final Readable readable = ((BoundedSourceTask) task).getReadable();
+          final Iterable readData = readable.read();
+          numBoundedSources++;
+          numIterators++;
+
+          final String iteratorId = generateIteratorId();
+          final Iterator iterator = readData.iterator();
+          idToSrcIteratorMap.putIfAbsent(iteratorId, iterator);
+          srcIteratorIdToTasksMap.putIfAbsent(iteratorId, new ArrayList<>());
+          srcIteratorIdToTasksMap.get(iteratorId).add(task);
+        } catch (final BlockFetchException ex) {
+          
taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_RECOVERABLE,
+              Optional.empty(), 
Optional.of(TaskGroupState.RecoverableFailureCause.INPUT_READ_FAILURE));
+          LOG.info("{} Execution Failed (Recoverable: input read failure)! 
Exception: {}",
+              taskGroupId, ex.toString());
+        } catch (final Exception e) {
+          
taskGroupStateManager.onTaskGroupStateChanged(TaskGroupState.State.FAILED_UNRECOVERABLE,
+              Optional.empty(), Optional.empty());
+          LOG.info("{} Execution Failed! Exception: {}", taskGroupId, 
e.toString());
+          throw new RuntimeException(e);
+        }
+      }
+      // TODO #XXX: Support other types of source tasks, i. e. 
InitializedSourceTask
+    });
+  }
 
-    final Map<String, Object> metric = new HashMap<>();
-    metricCollector.beginMeasurement(physicalTaskId, metric);
-    long accumulatedBlockedReadTime = 0;
-    long accumulatedWriteTime = 0;
-    long accumulatedSerializedBlockSize = 0;
-    long accumulatedEncodedBlockSize = 0;
-    boolean blockSizeAvailable = true;
-
-    final long readStartTime = System.currentTimeMillis();
-    // Check for side inputs
-    
physicalTaskIdToInputReaderMap.get(physicalTaskId).stream().filter(InputReader::isSideInputReader)
-        .forEach(inputReader -> {
+  private void prepareInputFromOtherStages() {
+    inputReaders.stream().forEach(inputReader -> {
+      final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = 
inputReader.read();
+      numIterators += futures.size();
+      final long blockedReadStartTime = System.currentTimeMillis();
+
+      // Add consumers which will push iterator when the futures are complete.
+      futures.forEach(compFuture -> compFuture.whenComplete((iterator, 
exception) -> {
+        if (exception != null) {
+          throw new BlockFetchException(exception);
+        }
+
+        completedFutures.getAndIncrement();
+        final String iteratorId = generateIteratorId();
+        if (iteratorIdToTasksMap.containsKey(iteratorId)) {
+          throw new RuntimeException("iteratorIdToTasksMap already contains " 
+ iteratorId);
+        } else {
+          iteratorIdToTasksMap.computeIfAbsent(iteratorId, absentIteratorId -> 
{
+            final List<Task> list = new ArrayList<>();
+            list.addAll(inputReaderToTasksMap.get(inputReader));
+            return Collections.unmodifiableList(list);
+          });
           try {
-            if (!inputReader.isSideInputReader()) {
-              // Trying to get sideInput from a reader that does not handle 
sideInput.
-              // This is probably a bug. We're not trying to recover but 
ensure a hard fail.
-              throw new RuntimeException("Trying to get sideInput from 
non-sideInput reader");
-            }
-            final DataUtil.IteratorWithNumBytes sideInputIterator = 
inputReader.read().get(0).get();
-            final Object sideInput = getSideInput(sideInputIterator);
-
-            final RuntimeEdge inEdge = inputReader.getRuntimeEdge();
-            final Transform srcTransform;
-            if (inEdge instanceof PhysicalStageEdge) {
-              srcTransform = ((OperatorVertex) ((PhysicalStageEdge) 
inEdge).getSrcVertex())
-                  .getTransform();
-            } else {
-              srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
-            }
-            sideInputMap.put(srcTransform, sideInput);
-            sideInputIterators.add(sideInputIterator);
-          } catch (final InterruptedException | ExecutionException e) {
-            throw new BlockFetchException(e);
+            iteratorQueue.put(Pair.of(iteratorId, iterator));
+          } catch (InterruptedException e) {
+            throw new RuntimeException("Interrupted while receiving iterator " 
+ e);
           }
-        });
+        }
+        final long blockedReadEndTime = System.currentTimeMillis();
+        accumulatedBlockedReadTime += blockedReadEndTime - 
blockedReadStartTime;
+      }));
+    });
+  }
 
-    for (final DataUtil.IteratorWithNumBytes iterator : sideInputIterators) {
-      try {
-        accumulatedSerializedBlockSize += iterator.getNumSerializedBytes();
-        accumulatedEncodedBlockSize += iterator.getNumEncodedBytes();
-      } catch (final 
DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
-        blockSizeAvailable = false;
-        break;
-      }
+  private boolean finishedAllTasks() {
+    // Total size of this TaskGroup
+    int taskNum = taskToOutputPipeMap.keySet().size();
+    int finishedTaskNum = finishedTaskIds.size();
+
+    return finishedTaskNum == taskNum;
+  }
+
+  private void initializePipeToDstTasksMap() {
+    srcIteratorIdToTasksMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          LocalPipe pipe = taskToOutputPipeMap.get(task);
+          pipeIdToDstTasksMap.putIfAbsent(pipe.getId(), dstTasks);
+          LOG.info("{} pipeIdToDstTasksMap: [{}'s OutputPipe, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        }));
+    iteratorIdToTasksMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          LocalPipe pipe = taskToOutputPipeMap.get(task);
+          pipeIdToDstTasksMap.putIfAbsent(pipe.getId(), dstTasks);
+          LOG.info("{} pipeIdToDstTasksMap: [{}'s OutputPipe, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        }));
+  }
+
+  private void updatePipeToDstTasksMap() {
+    Map<String, List<Task>> currentMap = pipeIdToDstTasksMap;
+    Map<String, List<Task>> updatedMap = new HashMap<>();
+
+    currentMap.values().forEach(tasks ->
+        tasks.forEach(task -> {
+          final List<Task> dstTasks = taskGroupDag.getChildren(task.getId());
+          LocalPipe pipe = taskToOutputPipeMap.get(task);
+          updatedMap.putIfAbsent(pipe.getId(), dstTasks);
+          LOG.info("{} pipeIdToDstTasksMap: [{}, {}]",
+              taskGroupId, getPhysicalTaskId(task.getId()), dstTasks);
+        })
+    );
+
+    pipeIdToDstTasksMap = updatedMap;
+  }
+
+  private void closeTransform(final Task task) {
+    if (task instanceof OperatorTask) {
+      Transform transform = ((OperatorTask) task).getTransform();
+      transform.close();
+      LOG.info("{} {} Closed Transform {}!", taskGroupId, 
getPhysicalTaskId(task.getId()), transform);
     }
+  }
 
-    final Transform.Context transformContext = new ContextImpl(sideInputMap);
-    final OutputCollectorImpl outputCollector = new OutputCollectorImpl();
-
-    final Transform transform = operatorTask.getTransform();
-    transform.prepare(transformContext, outputCollector);
-
-    // Check for non-side inputs
-    // This blocking queue contains the pairs having data and source vertex 
ids.
-    final BlockingQueue<Pair<DataUtil.IteratorWithNumBytes, String>> dataQueue 
= new LinkedBlockingQueue<>();
-    final AtomicInteger sourceParallelism = new AtomicInteger(0);
-    
physicalTaskIdToInputReaderMap.get(physicalTaskId).stream().filter(inputReader 
-> !inputReader.isSideInputReader())
-        .forEach(inputReader -> {
-          final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures 
= inputReader.read();
-          final String srcIrVtxId = inputReader.getSrcIrVertexId();
-          sourceParallelism.getAndAdd(inputReader.getSourceParallelism());
-          // Add consumers which will push the data to the data queue when it 
ready to the futures.
-          futures.forEach(compFuture -> compFuture.whenComplete((data, 
exception) -> {
-            if (exception != null) {
-              throw new BlockFetchException(exception);
-            }
-            dataQueue.add(Pair.of(data, srcIrVtxId));
-          }));
-        });
-    final long readFutureEndTime = System.currentTimeMillis();
-    // Consumes all of the partitions from incoming edges.
-    for (int srcTaskNum = 0; srcTaskNum < sourceParallelism.get(); 
srcTaskNum++) {
+  private void sideInputFromOtherStages(final Task task, final Map<Transform, 
Object> sideInputMap) {
+    taskToSideInputReadersMap.get(task).forEach(sideInputReader -> {
       try {
-        // Because the data queue is a blocking queue, we may need to wait 
some available data to be pushed.
-        final long blockedReadStartTime = System.currentTimeMillis();
-        final Pair<DataUtil.IteratorWithNumBytes, String> availableData = 
dataQueue.take();
-        final long blockedReadEndTime = System.currentTimeMillis();
-        accumulatedBlockedReadTime += blockedReadEndTime - 
blockedReadStartTime;
-        transform.onData(availableData.left(), availableData.right());
-        if (blockSizeAvailable) {
-          try {
-            accumulatedSerializedBlockSize += 
availableData.left().getNumSerializedBytes();
-            accumulatedEncodedBlockSize += 
availableData.left().getNumEncodedBytes();
-          } catch (final 
DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
-            blockSizeAvailable = false;
-          }
+        final DataUtil.IteratorWithNumBytes sideInputIterator = 
sideInputReader.read().get(0).get();
+        final Object sideInput = getSideInput(sideInputIterator);
+        final RuntimeEdge inEdge = sideInputReader.getRuntimeEdge();
+        final Transform srcTransform;
+        if (inEdge instanceof PhysicalStageEdge) {
+          srcTransform = ((OperatorVertex) ((PhysicalStageEdge) 
inEdge).getSrcVertex()).getTransform();
+        } else {
+          srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
+        }
+        sideInputMap.put(srcTransform, sideInput);
+
+        // Collect metrics on block size if possible.
+        try {
+          serBlockSize += sideInputIterator.getNumSerializedBytes();
+        } catch (final 
DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          serBlockSize = -1;
         }
-      } catch (final InterruptedException e) {
+        try {
+          encodedBlockSize += sideInputIterator.getNumEncodedBytes();
+        } catch (final 
DataUtil.IteratorWithNumBytes.NumBytesNotSupportedException e) {
+          encodedBlockSize = -1;
+        }
+
+        LOG.info("log: {} {} read sideInput from InputReader {}",
+            taskGroupId, getPhysicalTaskId(task.getId()), sideInput);
+      } catch (final InterruptedException | ExecutionException e) {
         throw new BlockFetchException(e);
       }
+    });
+  }
+
+  private void sideInputFromThisStage(final Task task, final Map<Transform, 
Object> sideInputMap) {
+    final String physicalTaskId = getPhysicalTaskId(task.getId());
+
+    taskToInputPipesMap.get(task).forEach(inputPipe -> {
+      if (inputPipe.hasSideInputFor(physicalTaskId)) {
+        // because sideInput is only 1 element in the pipe
+        Object sideInput = inputPipe.remove();
+        final RuntimeEdge inEdge = inputPipe.getSideInputRuntimeEdge();
+        final Transform srcTransform;
+        if (inEdge instanceof PhysicalStageEdge) {
+          srcTransform = ((OperatorVertex) ((PhysicalStageEdge) 
inEdge).getSrcVertex()).getTransform();
+        } else {
+          srcTransform = ((OperatorTask) inEdge.getSrc()).getTransform();
+        }
+        sideInputMap.put(srcTransform, sideInput);
+        LOG.info("log: {} {} read sideInput from InputPipe {}", taskGroupId, 
physicalTaskId, sideInput);
+      }
+    });
+  }
+
+  private void prepareTransform(final Transform transform, final Task task) {
+    if (!preparedTransforms.contains(transform)) {
+      final Map<Transform, Object> sideInputMap = new HashMap<>();
+
+      // Check and collect side inputs.
+      if (taskToSideInputReadersMap.keySet().contains(task)) {
+        sideInputFromOtherStages(task, sideInputMap);
+      }
+      if (hasInputPipe(task)) {
+        sideInputFromThisStage(task, sideInputMap);
+      }
 
-      // Check whether there is any output data from the transform and write 
the output of this task to the writer.
-      final List output = outputCollector.collectOutputList();
-      if (!output.isEmpty() && 
physicalTaskIdToOutputWriterMap.containsKey(physicalTaskId)) {
-        final long writeStartTime = System.currentTimeMillis();
-        
physicalTaskIdToOutputWriterMap.get(physicalTaskId).forEach(outputWriter -> 
outputWriter.write(output));
-        final long writeEndTime = System.currentTimeMillis();
-        accumulatedWriteTime += writeEndTime - writeStartTime;
-      } // If else, this is a sink task.
+      final Transform.Context transformContext = new ContextImpl(sideInputMap);
+      final LocalPipe outputPipe = taskToOutputPipeMap.get(task);
+      transform.prepare(transformContext, outputPipe);
 
 Review comment:
   I'd prepare all transforms of a task group before processing any data.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to