HIVE-17338: Utilities.get*Tasks: multiple methods duplicate code (Gergely Hajós via Zoltan Haindrich)

Signed-off-by: Zoltan Haindrich <k...@rxd.hu>


Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/cb770077
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/cb770077
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/cb770077

Branch: refs/heads/master
Commit: cb770077d47931ff33a5f272982133b03f8d2a75
Parents: 229a7cc
Author: Gergely Hajós <rogoz...@gmail.com>
Authored: Thu Sep 21 10:23:52 2017 +0200
Committer: Zoltan Haindrich <k...@rxd.hu>
Committed: Thu Sep 21 10:23:52 2017 +0200

----------------------------------------------------------------------
 .../apache/hadoop/hive/ql/exec/Utilities.java   | 106 +++++--------------
 .../hadoop/hive/ql/exec/TestUtilities.java      |  22 ++++
 2 files changed, 46 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hive/blob/cb770077/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index 4322cc6..ae70cba 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -2444,99 +2444,41 @@ public final class Utilities {
   }
 
   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> 
tasks) {
-    List<TezTask> tezTasks = new ArrayList<TezTask>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends 
Serializable>>();
-      while (!tasks.isEmpty()) {
-        tasks = getTezTasks(tasks, tezTasks, visited);
-      }
-    }
-    return tezTasks;
-  }
-
-  private static List<Task<? extends Serializable>> getTezTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<TezTask> tezTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof TezTask && !tezTasks.contains(task)) {
-        tezTasks.add((TezTask) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
-      }
-      visited.add(task);
-    }
-    return childTasks;
+    return getTasks(tasks, TezTask.class);
   }
 
   public static List<SparkTask> getSparkTasks(List<Task<? extends 
Serializable>> tasks) {
-    List<SparkTask> sparkTasks = new ArrayList<SparkTask>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends 
Serializable>>();
-      while (!tasks.isEmpty()) {
-        tasks = getSparkTasks(tasks, sparkTasks, visited);
-      }
-    }
-    return sparkTasks;
+    return getTasks(tasks, SparkTask.class);
   }
 
-  private static List<Task<? extends Serializable>> getSparkTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<SparkTask> sparkTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof SparkTask && !sparkTasks.contains(task)) {
-        sparkTasks.add((SparkTask) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
-      }
-      visited.add(task);
-    }
-    return childTasks;
+  public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> 
tasks) {
+    return getTasks(tasks, ExecDriver.class);
   }
 
-  public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> 
tasks) {
-    List<ExecDriver> mrTasks = new ArrayList<ExecDriver>();
+  @SuppressWarnings("unchecked")
+  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, 
Class<T> requiredType) {
+    List<T> typeSpecificTasks = new ArrayList<>();
     if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<Task<? extends 
Serializable>>();
+      Set<Task<? extends Serializable>> visited = new HashSet<>();
       while (!tasks.isEmpty()) {
-        tasks = getMRTasks(tasks, mrTasks, visited);
-      }
-    }
-    return mrTasks;
-  }
-
-  private static List<Task<? extends Serializable>> getMRTasks(
-          List<Task<? extends Serializable>> tasks,
-          List<ExecDriver> mrTasks,
-          Set<Task<? extends Serializable>> visited) {
-    List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-    for (Task<? extends Serializable> task : tasks) {
-      if (visited.contains(task)) {
-        continue;
-      }
-      if (task instanceof ExecDriver && !mrTasks.contains(task)) {
-        mrTasks.add((ExecDriver) task);
-      }
-
-      if (task.getDependentTasks() != null) {
-        childTasks.addAll(task.getDependentTasks());
+        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
+        for (Task<? extends Serializable> task : tasks) {
+          if (visited.contains(task)) {
+            continue;
+          }
+          if (requiredType.isInstance(task) && 
!typeSpecificTasks.contains(task)) {
+            typeSpecificTasks.add((T) task);
+          }
+          if (task.getDependentTasks() != null) {
+            childTasks.addAll(task.getDependentTasks());
+          }
+          visited.add(task);
+        }
+        // start recursion
+        tasks = childTasks;
       }
-      visited.add(task);
     }
-    return childTasks;
+    return typeSpecificTasks;
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/hive/blob/cb770077/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java
index 1a464c8..9a22c54 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/TestUtilities.java
@@ -72,6 +72,7 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
 import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.MapredWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.PartitionDesc;
 import org.apache.hadoop.hive.ql.plan.TableDesc;
@@ -909,4 +910,25 @@ public class TestUtilities {
 
   }
 
+  private static Task<MapredWork> getMapredWork() {
+    return TaskFactory.get(MapredWork.class, new HiveConf());
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void testGetTasksRecursion() {
+
+    Task<MapredWork> rootTask = getMapredWork();
+    Task<MapredWork> child1 = getMapredWork();
+    Task<MapredWork> child2 = getMapredWork();
+    Task<MapredWork> child11 = getMapredWork();
+
+    rootTask.addDependentTask(child1);
+    rootTask.addDependentTask(child2);
+    child1.addDependentTask(child11);
+
+    assertEquals(Lists.newArrayList(rootTask, child1, child2, child11),
+        Utilities.getMRTasks(getTestDiamondTaskGraph(rootTask)));
+
+  }
 }

Reply via email to