HIVE-17338: Utilities.get*Tasks multiple methods duplicate code (Gergely Hajós via Zoltan Haindrich)
Signed-off-by: Zoltan Haindrich <k...@rxd.hu> Project: http://git-wip-us.apache.org/repos/asf/hive/repo Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/cb770077 Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/cb770077 Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/cb770077 Branch: refs/heads/master Commit: cb770077d47931ff33a5f272982133b03f8d2a75 Parents: 229a7cc Author: Gergely Hajós <rogoz...@gmail.com> Authored: Thu Sep 21 10:23:52 2017 +0200 Committer: Zoltan Haindrich <k...@rxd.hu> Committed: Thu Sep 21 10:23:52 2017 +0200 ---------------------------------------------------------------------- .../apache/hadoop/hive/ql/exec/Utilities.java | 106 +++++-------------- .../hadoop/hive/ql/exec/TestUtilities.java | 22 ++++ 2 files changed, 46 insertions(+), 82 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/hive/blob/cb770077/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java ---------------------------------------------------------------------- diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java index 4322cc6..ae70cba 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java @@ -2444,99 +2444,41 @@ public final class Utilities { } public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) { - List<TezTask> tezTasks = new ArrayList<TezTask>(); - if (tasks != null) { - Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>(); - while (!tasks.isEmpty()) { - tasks = getTezTasks(tasks, tezTasks, visited); - } - } - return tezTasks; - } - - private static List<Task<? extends Serializable>> getTezTasks( - List<Task<? extends Serializable>> tasks, - List<TezTask> tezTasks, - Set<Task<? extends Serializable>> visited) { - List<Task<? extends Serializable>> childTasks = new ArrayList<>(); - for (Task<? extends Serializable> task : tasks) { - if (visited.contains(task)) { - continue; - } - if (task instanceof TezTask && !tezTasks.contains(task)) { - tezTasks.add((TezTask) task); - } - - if (task.getDependentTasks() != null) { - childTasks.addAll(task.getDependentTasks()); - } - visited.add(task); - } - return childTasks; + return getTasks(tasks, TezTask.class); } public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) { - List<SparkTask> sparkTasks = new ArrayList<SparkTask>(); - if (tasks != null) { - Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>(); - while (!tasks.isEmpty()) { - tasks = getSparkTasks(tasks, sparkTasks, visited); - } - } - return sparkTasks; + return getTasks(tasks, SparkTask.class); } - private static List<Task<? extends Serializable>> getSparkTasks( - List<Task<? extends Serializable>> tasks, - List<SparkTask> sparkTasks, - Set<Task<? extends Serializable>> visited) { - List<Task<? extends Serializable>> childTasks = new ArrayList<>(); - for (Task<? extends Serializable> task : tasks) { - if (visited.contains(task)) { - continue; - } - if (task instanceof SparkTask && !sparkTasks.contains(task)) { - sparkTasks.add((SparkTask) task); - } - - if (task.getDependentTasks() != null) { - childTasks.addAll(task.getDependentTasks()); - } - visited.add(task); - } - return childTasks; + public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) { + return getTasks(tasks, ExecDriver.class); } - public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) { - List<ExecDriver> mrTasks = new ArrayList<ExecDriver>(); + @SuppressWarnings("unchecked") + public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) { + List<T> typeSpecificTasks = new ArrayList<>(); if (tasks != null) { - Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends Serializable>>(); + Set<Task<? extends Serializable>> visited = new HashSet<>(); while (!tasks.isEmpty()) { - tasks = getMRTasks(tasks, mrTasks, visited); - } - } - return mrTasks; - } - - private static List<Task<? extends Serializable>> getMRTasks( - List<Task<? extends Serializable>> tasks, - List<ExecDriver> mrTasks, - Set<Task<? extends Serializable>> visited) { - List<Task<? extends Serializable>> childTasks = new ArrayList<>(); - for (Task<? extends Serializable> task : tasks) { - if (visited.contains(task)) { - continue; - } - if (task instanceof ExecDriver && !mrTasks.contains(task)) { - mrTasks.add((ExecDriver) task); - } - - if (task.getDependentTasks() != null) { - childTasks.addAll(task.getDependentTasks()); + List<Task<? extends Serializable>> childTasks = new ArrayList<>(); + for (Task<? extends Serializable> task : tasks) { + if (visited.contains(task)) { + continue; + } + if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) { + typeSpecificTasks.add((T) task); + } + if (task.getDependentTasks() != null) { + childTasks.addAll(task.getDependentTasks()); + } + visited.add(task); + } + // start recursion + tasks = childTasks; } - visited.add(task); } - return childTasks; + return typeSpecificTasks; } /** http://git-wip-us.apache.org/repos/asf/hive/blob/cb770077/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java ---------------------------------------------------------------------- diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java index 1a464c8..9a22c54 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java @@ -72,6 +72,7 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; import org.apache.hadoop.hive.ql.plan.FileSinkDesc; import org.apache.hadoop.hive.ql.plan.MapWork; +import org.apache.hadoop.hive.ql.plan.MapredWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.PartitionDesc; import org.apache.hadoop.hive.ql.plan.TableDesc; @@ -909,4 +910,25 @@ public class TestUtilities { } + private static Task<MapredWork> getMapredWork() { + return TaskFactory.get(MapredWork.class, new HiveConf()); + } + + @Test + @SuppressWarnings("unchecked") + public void testGetTasksRecursion() { + + Task<MapredWork> rootTask = getMapredWork(); + Task<MapredWork> child1 = getMapredWork(); + Task<MapredWork> child2 = getMapredWork(); + Task<MapredWork> child11 = getMapredWork(); + + rootTask.addDependentTask(child1); + rootTask.addDependentTask(child2); + child1.addDependentTask(child11); + + assertEquals(Lists.newArrayList(rootTask, child1, child2, child11), + Utilities.getMRTasks(getTestDiamondTaskGraph(rootTask))); + + } }