javanna commented on code in PR #12689: URL: https://github.com/apache/lucene/pull/12689#discussion_r1362621063
########## lucene/core/src/java/org/apache/lucene/search/TaskExecutor.java: ########## @@ -64,64 +67,124 @@ public final class TaskExecutor { * @param <T> the return type of the task execution */ public <T> List<T> invokeAll(Collection<Callable<T>> callables) throws IOException { - List<Task<T>> tasks = new ArrayList<>(callables.size()); - boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; - for (Callable<T> callable : callables) { - Task<T> task = new Task<>(callable); - tasks.add(task); - if (runOnCallerThread) { - task.run(); - } else { - executor.execute(task); + TaskGroup<T> taskGroup = new TaskGroup<>(callables); + return taskGroup.invokeAll(executor); + } + + /** + * Holds all the sub-tasks that a certain operation gets split into as it gets parallelized and + * exposes the ability to invoke such tasks and wait for them all to complete their execution and + * provide their results. Ensures that each task does not get parallelized further: this is + * important to avoid a deadlock in situations where one executor thread waits on other executor + * threads to complete before it can progress. This happens in situations where for instance + * {@link Query#createWeight(IndexSearcher, ScoreMode, float)} is called as part of searching each + * slice, like {@link TopFieldCollector#populateScores(ScoreDoc[], IndexSearcher, Query)} does. + * Additionally, if one task throws an exception, all other tasks from the same group are + * cancelled, to avoid needless computation as their results would not be exposed anyways. 
Creates + * one {@link FutureTask} for each {@link Callable} provided + * + * @param <T> the return type of all the callables + */ + private static final class TaskGroup<T> { + private final Collection<RunnableFuture<T>> futures; + + TaskGroup(Collection<Callable<T>> callables) { + List<RunnableFuture<T>> tasks = new ArrayList<>(callables.size()); + for (Callable<T> callable : callables) { + tasks.add(createTask(callable)); } + this.futures = Collections.unmodifiableCollection(tasks); } - Throwable exc = null; - final List<T> results = new ArrayList<>(); - for (Future<T> future : tasks) { - try { - results.add(future.get()); - } catch (InterruptedException e) { - var newException = new ThreadInterruptedException(e); - if (exc == null) { - exc = newException; - } else { - exc.addSuppressed(newException); + private FutureTask<T> createTask(Callable<T> callable) { + AtomicBoolean started = new AtomicBoolean(false); + return new FutureTask<>(callable) { + @Override + public void run() { + if (started.compareAndSet(false, true)) { + try { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter + 1); + super.run(); + } finally { + Integer counter = numberOfRunningTasksInCurrentThread.get(); + numberOfRunningTasksInCurrentThread.set(counter - 1); + } + } else { + // task is cancelled hence it has no results to return. That's fine: they would be + // ignored anyway. + set(null); + } + } + + @Override + protected void setException(Throwable t) { + cancelAll(); + super.setException(t); } - } catch (ExecutionException e) { - if (exc == null) { - exc = e.getCause(); + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + assert mayInterruptIfRunning == false; + /* + Future#get (called in invokeAll) throws CancellationException for a cancelled task when invoked but leaves the task running. + We rather want to make sure that invokeAll does not leave any running tasks behind when it returns. 
+ Overriding cancel ensures that tasks that are already started will complete normally once cancelled, and Future#get will + wait for them to finish instead of throwing CancellationException. Tasks that are cancelled before they are started won't start. + */ + return started.compareAndSet(false, true); + } + }; + } + + List<T> invokeAll(Executor executor) throws IOException { + boolean runOnCallerThread = numberOfRunningTasksInCurrentThread.get() > 0; + for (Runnable runnable : futures) { + if (runOnCallerThread) { + runnable.run(); } else { - exc.addSuppressed(e.getCause()); + executor.execute(runnable); } } + Throwable exc = null; + List<T> results = new ArrayList<>(futures.size()); + for (Future<T> future : futures) { + try { + results.add(future.get()); + } catch (InterruptedException e) { + var newException = new ThreadInterruptedException(e); + if (exc == null) { + exc = newException; + } else { + exc.addSuppressed(newException); + } + } catch (ExecutionException e) { + if (exc == null) { + exc = e.getCause(); + } else { + exc.addSuppressed(e.getCause()); + } + } + } + assert assertAllFuturesCompleted() : "Some tasks are still running?"; Review Comment: This is new: we already have a test for it, but I thought an assertion would also be good, to verify that we don't leave running tasks behind before returning. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@lucene.apache.org For additional commands, e-mail: issues-h...@lucene.apache.org