Github user HanumathRao commented on a diff in the pull request:

    https://github.com/apache/drill/pull/1238#discussion_r183875553
  
    --- Diff: 
exec/java-exec/src/main/java/org/apache/drill/exec/store/TimedCallable.java ---
    @@ -0,0 +1,265 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + * http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.drill.exec.store;
    +
    +import java.io.IOException;
    +import java.util.List;
    +import java.util.Objects;
    +import java.util.concurrent.Callable;
    +import java.util.concurrent.CancellationException;
    +import java.util.concurrent.ExecutionException;
    +import java.util.concurrent.ExecutorService;
    +import java.util.concurrent.Executors;
    +import java.util.concurrent.Future;
    +import java.util.concurrent.RejectedExecutionException;
    +import java.util.concurrent.TimeUnit;
    +import java.util.function.Consumer;
    +import java.util.function.Function;
    +import java.util.stream.Collectors;
    +
    +import org.apache.drill.common.exceptions.UserException;
    +
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +
    +import com.google.common.base.Preconditions;
    +import com.google.common.base.Stopwatch;
    +import com.google.common.util.concurrent.MoreExecutors;
    +import com.google.common.util.concurrent.ThreadFactoryBuilder;
    +
    +/**
    + * Class used to allow parallel executions of tasks in a simplified way. 
Also maintains and reports timings of task completion.
    + * TODO: look at switching to fork join.
    + * @param <V> The time value that will be returned when the task is 
executed.
    + */
    +public abstract class TimedCallable<V> implements Callable<V> {
    +  private static final Logger logger = 
LoggerFactory.getLogger(TimedCallable.class);
    +
    +  private static long TIMEOUT_PER_RUNNABLE_IN_MSECS = 15000;
    +
    +  private volatile long startTime = 0;
    +  private volatile long executionTime = -1;
    +
    +  private static class FutureMapper<V> implements Function<Future<V>, V> {
    +    int count;
    +    Throwable throwable = null;
    +
    +    private void setThrowable(Throwable t) {
    +      if (throwable == null) {
    +        throwable = t;
    +      } else {
    +        throwable.addSuppressed(t);
    +      }
    +    }
    +
    +    @Override
    +    public V apply(Future<V> future) {
    +      Preconditions.checkState(future.isDone());
    +      if (!future.isCancelled()) {
    +        try {
    +          count++;
    +          return future.get();
    +        } catch (InterruptedException e) {
    +          // there is no wait as we are getting result from the 
completed/done future
    +          logger.error("Unexpected exception", e);
    +          throw UserException.internalError(e)
    +              .message("Unexpected exception")
    +              .build(logger);
    +        } catch (ExecutionException e) {
    +          setThrowable(e.getCause());
    +        }
    +      } else {
    +        setThrowable(new CancellationException());
    +      }
    +      return null;
    +    }
    +  }
    +
    +  private static class Statistics<V> implements Consumer<TimedCallable<V>> 
{
    +    final long start = System.nanoTime();
    +    final Stopwatch watch = Stopwatch.createStarted();
    +    long totalExecution;
    +    long maxExecution;
    +    int count;
    +    int startedCount;
    +    private int doneCount;
    +    // measure thread creation times
    +    long earliestStart;
    +    long latestStart;
    +    long totalStart;
    +
    +    @Override
    +    public void accept(TimedCallable<V> task) {
    +      count++;
    +      long threadStart = task.getStartTime(TimeUnit.NANOSECONDS) - start;
    +      if (threadStart >= 0) {
    +        startedCount++;
    +        earliestStart = Math.min(earliestStart, threadStart);
    +        latestStart = Math.max(latestStart, threadStart);
    +        totalStart += threadStart;
    +        long executionTime = task.getExecutionTime(TimeUnit.NANOSECONDS);
    +        if (executionTime != -1) {
    +          doneCount++;
    +          totalExecution += executionTime;
    +          maxExecution = Math.max(maxExecution, executionTime);
    +        } else {
    +          logger.info("Task {} started at {} did not finish", task, 
threadStart);
    +        }
    +      } else {
    +        logger.info("Task {} never commenced execution", task);
    +      }
    +    }
    +
    +    Statistics<V> collect(final List<TimedCallable<V>> tasks) {
    +      totalExecution = maxExecution = 0;
    +      count = startedCount = doneCount = 0;
    +      earliestStart = Long.MAX_VALUE;
    +      latestStart = totalStart = 0;
    +      tasks.forEach(this);
    +      return this;
    +    }
    +
    +    void info(final String activity, final Logger logger, int parallelism) 
{
    +      if (startedCount > 0) {
    +        logger.info("{}: started {} out of {} using {} threads. (start 
time: min {} ms, avg {} ms, max {} ms).",
    +            activity, startedCount, count, parallelism,
    +            TimeUnit.NANOSECONDS.toMillis(earliestStart),
    +            TimeUnit.NANOSECONDS.toMillis(totalStart) / startedCount,
    +            TimeUnit.NANOSECONDS.toMillis(latestStart));
    +      } else {
    +        logger.info("{}: started {} out of {} using {} threads.", 
activity, startedCount, count, parallelism);
    +      }
    +
    +      if (doneCount > 0) {
    +        logger.info("{}: completed {} out of {} using {} threads 
(execution time: total {} ms, avg {} ms, max {} ms).",
    +            activity, doneCount, count, parallelism, 
watch.elapsed(TimeUnit.MILLISECONDS),
    +            TimeUnit.NANOSECONDS.toMillis(totalExecution) / doneCount, 
TimeUnit.NANOSECONDS.toMillis(maxExecution));
    +      } else {
    +        logger.info("{}: completed {} out of {} using {} threads", 
activity, doneCount, count, parallelism);
    +      }
    +    }
    +  }
    +
    +  @Override
    +  public final V call() throws Exception {
    +    long start = System.nanoTime();
    +    startTime = start;
    +    try {
    +      logger.debug("Started execution of '{}' task at {} ms", this, 
TimeUnit.MILLISECONDS.convert(start, TimeUnit.NANOSECONDS));
    +      return runInner();
    +    } catch (InterruptedException e) {
    +      logger.warn("Task '{}' interrupted", this, e);
    +      throw e;
    +    } finally {
    +      long time = System.nanoTime() - start;
    +      if (logger.isWarnEnabled()) {
    +        long timeMillis = TimeUnit.MILLISECONDS.convert(time, 
TimeUnit.NANOSECONDS);
    +        if (timeMillis > TIMEOUT_PER_RUNNABLE_IN_MSECS) {
    +          logger.warn("Task '{}' execution time {} ms exceeds timeout {} 
ms.", this, timeMillis, TIMEOUT_PER_RUNNABLE_IN_MSECS);
    +        } else {
    +          logger.debug("Task '{}' execution time is {} ms", this, 
timeMillis);
    +        }
    +      }
    +      executionTime = time;
    +    }
    +  }
    +
    +  protected abstract V runInner() throws Exception;
    +
    +  private long getStartTime(TimeUnit unit) {
    +    return unit.convert(startTime, TimeUnit.NANOSECONDS);
    +  }
    +
    +  private long getExecutionTime(TimeUnit unit) {
    +    return unit.convert(executionTime, TimeUnit.NANOSECONDS);
    +  }
    +
    +
    +  /**
    +   * Execute the list of runnables with the given parallelization.  At 
end, return values and report completion time
    +   * stats to provided logger. Each runnable is allowed a certain timeout. 
If the timeout exceeds, existing/pending
    +   * tasks will be cancelled and a {@link UserException} is thrown.
    +   * @param activity Name of activity for reporting in logger.
    +   * @param logger The logger to use to report results.
    +   * @param tasks List of callable that should be executed and timed.  If 
this list has one item, task will be
    +   *                  completed in-thread. Each callable must handle 
{@link InterruptedException}s.
    +   * @param parallelism  The number of threads that should be run to 
complete this task.
    +   * @return The list of outcome objects.
    +   * @throws IOException All exceptions are coerced to IOException since 
this was build for storage system tasks initially.
    +   */
    +  public static <V> List<V> run(final String activity, final Logger 
logger, final List<TimedCallable<V>> tasks, int parallelism) throws IOException 
{
    +    
Preconditions.checkArgument(!Preconditions.checkNotNull(tasks).isEmpty(), "list 
of tasks is empty");
    +    parallelism = Math.min(parallelism, tasks.size());
    +    final ExecutorService threadPool = parallelism == 1 ? 
MoreExecutors.newDirectExecutorService()
    +        : Executors.newFixedThreadPool(parallelism, new 
ThreadFactoryBuilder().setNameFormat(activity + "-%d").build());
    +    final long timeout = (long)Math.ceil((TIMEOUT_PER_RUNNABLE_IN_MSECS * 
tasks.size())/parallelism);
    --- End diff --
    
    IMO, this part of the original logic should be changed to
`TIMEOUT_PER_RUNNABLE_IN_MSECS * (long)Math.ceil(tasks.size()/parallelism)`
instead of
`(long)Math.ceil((TIMEOUT_PER_RUNNABLE_IN_MSECS * tasks.size())/parallelism)`.


---

Reply via email to