This is an automated email from the ASF dual-hosted git repository.

ncole pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/ambari.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 1e51e6b  [AMBARI-25002] Allow Stack To Define Custom Rolling Orchestration Logic (#2695)
1e51e6b is described below

commit 1e51e6b89e7da26ae9ecdf5d7228d20d1dfe61b4
Author: ncole <nc...@hortonworks.com>
AuthorDate: Thu Dec 6 13:43:05 2018 -0500

    [AMBARI-25002] Allow Stack To Define Custom Rolling Orchestration Logic (#2695)
---
 .../ambari/spi/upgrade/OrchestrationOptions.java   |  42 ++++
 .../server/stack/upgrade/ColocatedGrouping.java    | 214 ++++++++++++---------
 .../ambari/server/stack/upgrade/Grouping.java      |  11 +-
 .../ambari/server/stack/upgrade/UpgradePack.java   |  12 ++
 .../upgrade/orchestrate/StageWrapperBuilder.java   |  27 ++-
 .../stack/upgrade/orchestrate/UpgradeContext.java  |  62 +++++-
 ambari-server/src/main/resources/upgrade-pack.xsd  |   1 +
 .../upgrade/orchestrate/UpgradeHelperTest.java     |  58 +++++-
 8 files changed, 324 insertions(+), 103 deletions(-)

diff --git 
a/ambari-server-spi/src/main/java/org/apache/ambari/spi/upgrade/OrchestrationOptions.java
 
b/ambari-server-spi/src/main/java/org/apache/ambari/spi/upgrade/OrchestrationOptions.java
new file mode 100644
index 0000000..83dcb1f
--- /dev/null
+++ 
b/ambari-server-spi/src/main/java/org/apache/ambari/spi/upgrade/OrchestrationOptions.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ambari.spi.upgrade;
+
+import org.apache.ambari.spi.ClusterInformation;
+
+/**
+ * Lets a stack supply orchestration options that tune parts of an upgrade, such as per-component host parallelism.
+ */
+public interface OrchestrationOptions {
+
+  /**
+   * Gets the number of hosts on which a component may be processed in parallel within a grouping.
+   * 
+   * @param cluster
+   *          the cluster information containing topology and configurations
+   * @param service
+   *          the name of the service containing the component
+   * @param component
+   *          the name of the component
+   *          
+   * @return the number of hosts that may be processed in parallel.  Returning a
+   *          value less than 1 defers to the grouping's parallel scheduler setting (1 host if none)
+   */
+  int getConcurrencyCount(ClusterInformation cluster, String service, String 
component);
+  
+}
diff --git 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/ColocatedGrouping.java
 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/ColocatedGrouping.java
index d70d1de..56a201b 100644
--- 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/ColocatedGrouping.java
+++ 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/ColocatedGrouping.java
@@ -22,12 +22,14 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Stream;
 
 import javax.xml.bind.annotation.XmlElement;
 import javax.xml.bind.annotation.XmlType;
@@ -40,6 +42,8 @@ import 
org.apache.ambari.server.stack.upgrade.orchestrate.StageWrapperBuilder;
 import org.apache.ambari.server.stack.upgrade.orchestrate.TaskWrapper;
 import org.apache.ambari.server.stack.upgrade.orchestrate.TaskWrapperBuilder;
 import org.apache.ambari.server.stack.upgrade.orchestrate.UpgradeContext;
+import org.apache.ambari.server.utils.SetUtils;
+import org.apache.ambari.spi.upgrade.OrchestrationOptions;
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang.StringUtils;
@@ -69,7 +73,7 @@ public class ColocatedGrouping extends Grouping {
     */
    @Override
    public StageWrapperBuilder getBuilder() {
-    return new MultiHomedBuilder(this, batch, performServiceCheck);
+    return new MultiHomedBuilder(this, batch, performServiceCheck, 
parallelScheduler);
    }

    private static class MultiHomedBuilder extends StageWrapperBuilder {
@@ -77,12 +81,13 @@ public class ColocatedGrouping extends Grouping {
     private Batch m_batch;
     private boolean m_serviceCheck = true;

-    // !!! host -> list of tasks
-    private Map<String, List<TaskProxy>> initialBatch = new LinkedHashMap<>();
-    private Map<String, List<TaskProxy>> finalBatches = new LinkedHashMap<>();
+    // task proxies for the first batch of hosts, then for the remaining hosts
+    private List<TaskProxy> initialBatch = new LinkedList<>();
+    private List<TaskProxy> finalBatches = new LinkedList<>();


-    private MultiHomedBuilder(Grouping grouping, Batch batch, boolean 
serviceCheck) {
+    private MultiHomedBuilder(Grouping grouping, Batch batch,
+        boolean serviceCheck, ParallelScheduler parallel) { // NOTE(review): 'parallel' is not referenced by any added line
       super(grouping);

       m_batch = batch;
@@ -93,75 +98,107 @@ public class ColocatedGrouping extends Grouping {
     public void add(UpgradeContext context, HostsType hostsType, String 
service,
         boolean clientOnly, ProcessingComponent pc, Map<String, String> 
params) {

+      // !!! the batch percent is the share of hosts that should be processed
+      // before pausing. This value has no bearing on how many should be in a 
single stage.
       int count = Double.valueOf(Math.ceil(
           (double) m_batch.percent / 100 * 
hostsType.getHosts().size())).intValue();

-      int i = 0;
-      for (String host : hostsType.getHosts()) {
-        // This class required inserting a single host into the collection
-        HostsType singleHostsType = HostsType.single(host);
+      LinkedHashSet<String> first = new LinkedHashSet<>();
+      LinkedHashSet<String> remaining = new LinkedHashSet<>();
+      hostsType.getHosts().forEach(hostName -> {
+        if (first.size() < count) {
+          first.add(hostName);
+        } else {
+          remaining.add(hostName);
+        }
+      });
+
+      // !!! resolve tasks once; they are shared by every batch of hosts
+      List<Task> preTasks = resolveTasks(context, true, pc);
+      Task task = resolveTask(context, pc);
+      List<Task> postTasks = resolveTasks(context, false, pc);
+      AtomicBoolean processInitial = new AtomicBoolean(true);
+
+      int parallelCount = -1;
+      OrchestrationOptions options = context.getOrchestrationOptions();
+      if (null != options) {
+        parallelCount = options.getConcurrencyCount(
+            context.getCluster().buildClusterInformation(), service, pc.name);
+      }

-        Map<String, List<TaskProxy>> targetMap = ((i++) < count) ? 
initialBatch : finalBatches;
-        List<TaskProxy> targetList = targetMap.get(host);
+      if (parallelCount < 1) { // no stack override: use the grouping's scheduler, else 1
+        parallelCount = getParallelHostCount(context, 1);
+      }

-        if (null == targetList) {
-          targetList = new ArrayList<>();
-          targetMap.put(host, targetList);
-        }
+      // !!! the lambdas below can only capture an effectively-final copy
+      final int hostCount = parallelCount;

-        TaskProxy proxy = null;
+      Stream.of(first, remaining).forEach(hosts -> {

-        List<Task> tasks = resolveTasks(context, true, pc);
+        List<TaskProxy> targetList = processInitial.get() ? initialBatch : 
finalBatches;

-        if (null != tasks && tasks.size() > 0) {
-          // Our assumption is that all of the tasks in the StageWrapper are of
-          // the same type.
-          StageWrapper.Type type = tasks.get(0).getStageWrapperType();
+        List<Set<String>> hostSplit = SetUtils.split(hosts, hostCount);

-          proxy = new TaskProxy();
-          proxy.clientOnly = clientOnly;
-          proxy.message = getStageText("Preparing",
-              context.getComponentDisplay(service, pc.name), 
Collections.singleton(host));
-          proxy.tasks.addAll(TaskWrapperBuilder.getTaskList(service, pc.name, 
singleHostsType, tasks, params));
-          proxy.service = service;
-          proxy.component = pc.name;
-          proxy.type = type;
-          targetList.add(proxy);
-        }
+        hostSplit.forEach(hostSet -> {

-        // !!! FIXME upgrade definition have only one step, and it better be a 
restart
-        Task t = resolveTask(context, pc);
-        if (null != t && RestartTask.class.isInstance(t)) {
-          proxy = new TaskProxy();
-          proxy.clientOnly = clientOnly;
-          proxy.tasks.add(new TaskWrapper(service, pc.name, 
Collections.singleton(host), params, t));
-          proxy.restart = true;
-          proxy.service = service;
-          proxy.component = pc.name;
-          proxy.type = Type.RESTART;
-          proxy.message = getStageText("Restarting",
-              context.getComponentDisplay(service, pc.name), 
Collections.singleton(host));
-          targetList.add(proxy);
-        }
+          List<Task> tasks = preTasks; // pre-tasks first; reassigned to the post-tasks below
+
+          TaskProxy proxy;
+          if (CollectionUtils.isNotEmpty(tasks)) {
+            // Our assumption is that all of the tasks in the StageWrapper are 
of
+            // the same type.
+            StageWrapper.Type type = tasks.get(0).getStageWrapperType();
+
+            proxy = new TaskProxy();
+            proxy.clientOnly = clientOnly;
+            proxy.message = getStageText("Preparing",
+                context.getComponentDisplay(service, pc.name), hostSet);
+            proxy.tasks.addAll(TaskWrapperBuilder.getTaskList(service, pc.name,
+                HostsType.normal(new LinkedHashSet<>(hostSet)), tasks, 
params));
+            proxy.service = service;
+            proxy.component = pc.name;
+            proxy.type = type;
+
+            targetList.add(proxy);
+          }
+
+          if (null != task && RestartTask.class.isInstance(task)) {
+            proxy = new TaskProxy();
+            proxy.clientOnly = clientOnly;
+            proxy.tasks.add(new TaskWrapper(service, pc.name, hostSet, params, 
task));
+            proxy.restart = true;
+            proxy.service = service;
+            proxy.component = pc.name;
+            proxy.type = Type.RESTART;
+            proxy.message = getStageText("Restarting",
+                context.getComponentDisplay(service, pc.name), hostSet);
+
+            targetList.add(proxy);
+          }
+
+          tasks = postTasks;
+          if (CollectionUtils.isNotEmpty(tasks)) {
+            // Our assumption is that all of the tasks in the StageWrapper are 
of
+            // the same type.
+            StageWrapper.Type type = tasks.get(0).getStageWrapperType();
+
+            proxy = new TaskProxy();
+            proxy.clientOnly = clientOnly;
+            proxy.message = getStageText("Completing",
+                context.getComponentDisplay(service, pc.name), hostSet);
+            proxy.tasks.addAll(TaskWrapperBuilder.getTaskList(service, pc.name,
+                HostsType.normal(new LinkedHashSet<>(hostSet)), tasks, 
params));
+            proxy.service = service;
+            proxy.component = pc.name;
+            proxy.type = type;
+
+            targetList.add(proxy);
+          }
+        });
+
+        processInitial.set(false);
+      });

-        tasks = resolveTasks(context, false, pc);
-
-        if (null != tasks && tasks.size() > 0) {
-          // Our assumption is that all of the tasks in the StageWrapper are of
-          // the same type.
-          StageWrapper.Type type = tasks.get(0).getStageWrapperType();
-
-          proxy = new TaskProxy();
-          proxy.clientOnly = clientOnly;
-          proxy.component = pc.name;
-          proxy.service = service;
-          proxy.type = type;
-          proxy.tasks.addAll(TaskWrapperBuilder.getTaskList(service, pc.name, 
singleHostsType, tasks, params));
-          proxy.message = getStageText("Completing",
-              context.getComponentDisplay(service, pc.name), 
Collections.singleton(host));
-          targetList.add(proxy);
-        }
-      }
     }
 
 
@@ -222,47 +259,45 @@ public class ColocatedGrouping extends Grouping {
     }

     private List<StageWrapper> fromProxies(Direction direction,
-        Map<String, List<TaskProxy>> wrappers, Predicate<Task> predicate) {
+        List<TaskProxy> proxies, Predicate<Task> predicate) {

       List<StageWrapper> results = new ArrayList<>();

       Set<String> serviceChecks = new HashSet<>();

-      for (Entry<String, List<TaskProxy>> entry : wrappers.entrySet()) {

-        // !!! stage per host, per type
+      proxies.forEach(proxy -> { // at most one stage per proxy, in list order
         StageWrapper wrapper = null;
         List<StageWrapper> execwrappers = new ArrayList<>();

-        for (TaskProxy t : entry.getValue()) {
-          if (!t.clientOnly) {
-            serviceChecks.add(t.service);
-          }
-
-          if (!t.restart) {
-            if (null == wrapper) {
-              TaskWrapper[] tasks = t.getTasksArray(predicate);
-
-              if (LOG.isDebugEnabled()) {
-                for (TaskWrapper tw : tasks) {
-                  LOG.debug("{}", tw);
-                }
-              }
+        if (!proxy.clientOnly) {
+          serviceChecks.add(proxy.service);
+        }

-              if (ArrayUtils.isNotEmpty(tasks)) {
-                wrapper = new StageWrapper(t.type, t.message, tasks);
-              }
-            }
-          } else {
-            TaskWrapper[] tasks = t.getTasksArray(null);
+        if (!proxy.restart) {
+          if (null == wrapper) { // NOTE(review): always true now that wrapper is declared per proxy
+            TaskWrapper[] tasks = proxy.getTasksArray(predicate);

             if (LOG.isDebugEnabled()) {
               for (TaskWrapper tw : tasks) {
                 LOG.debug("{}", tw);
               }
             }
-            execwrappers.add(new StageWrapper(StageWrapper.Type.RESTART, 
t.message, tasks));
+
+            if (ArrayUtils.isNotEmpty(tasks)) {
+              wrapper = new StageWrapper(proxy.type, proxy.message, tasks);
+            }
           }
+        } else {
+          TaskWrapper[] tasks = proxy.getTasksArray(null);
+
+          if (LOG.isDebugEnabled()) {
+            for (TaskWrapper tw : tasks) {
+              LOG.debug("{}", tw);
+            }
+          }
+          // !!! TODO check parallelism values (add() already batches hosts by the parallel count)
+          execwrappers.add(new StageWrapper(StageWrapper.Type.RESTART, 
proxy.message, tasks));
         }

         if (null != wrapper) {
@@ -272,8 +307,7 @@ public class ColocatedGrouping extends Grouping {
         if (execwrappers.size() > 0) {
           results.addAll(execwrappers);
         }
-
-      }
+      });

       if (direction.isUpgrade() && m_serviceCheck && serviceChecks.size() > 0) 
{
         // !!! add the service check task
diff --git 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/Grouping.java
 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/Grouping.java
index f5e7d08..b2c02a6 100644
--- 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/Grouping.java
+++ 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/Grouping.java
@@ -229,15 +229,8 @@ public class Grouping {
       for (TaskWrapper tw : tasks) {
         List<Set<String>> hostSets = null;

-        if (m_grouping.parallelScheduler != null) {
-          int taskParallelism = 
m_grouping.parallelScheduler.maxDegreeOfParallelism;
-          if (taskParallelism == Integer.MAX_VALUE) {
-            taskParallelism = ctx.getDefaultMaxDegreeOfParallelism();
-          }
-          hostSets = SetUtils.split(tw.getHosts(), taskParallelism);
-        } else {
-          hostSets = SetUtils.split(tw.getHosts(), 1);
-        }
+        int parallel = getParallelHostCount(ctx, 1); // scheduler's degree of parallelism, else 1 host
+        hostSets = SetUtils.split(tw.getHosts(), parallel); // NOTE(review): helper tests DEFAULT_MAX_DEGREE_OF_PARALLELISM where this used Integer.MAX_VALUE; confirm equal

         int numBatchesNeeded = hostSets.size();
         int batchNum = 0;
diff --git 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/UpgradePack.java
 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/UpgradePack.java
index 075ff06..a47606e 100644
--- 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/UpgradePack.java
+++ 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/UpgradePack.java
@@ -82,6 +82,9 @@ public class UpgradePack {
   @XmlElement(name="prerequisite-checks")
   private PrerequisiteChecks prerequisiteChecks;

+  @XmlElement(name="orchestration-options-class")
+  private String orchestrationOptionsClass; // optional FQCN of an OrchestrationOptions implementation
+
   /**
    * In the case of a rolling upgrade, will specify processing logic for a 
particular component.
    * NonRolling upgrades are simpler so the "processing" is embedded into the  
group's "type", which is a function like
@@ -882,4 +885,13 @@ public class UpgradePack {
   public StackId getOwnerStackId() {
     return ownerStackId;
   }
+
+  /**
+   * @return
+   *      the fully-qualified OrchestrationOptions class name, or {@code null} if not specified
+   */
+  public String getOrchestrationOptions() {
+    return orchestrationOptionsClass;
+  }
+
 }
diff --git 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/StageWrapperBuilder.java
 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/StageWrapperBuilder.java
index 9214490..d7e9b50 100644
--- 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/StageWrapperBuilder.java
+++ 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/StageWrapperBuilder.java
@@ -26,6 +26,7 @@ import java.util.Set;
 import 
org.apache.ambari.server.serveraction.upgrades.AutoSkipFailedSummaryAction;
 import org.apache.ambari.server.stack.HostsType;
 import org.apache.ambari.server.stack.upgrade.Grouping;
+import org.apache.ambari.server.stack.upgrade.ParallelScheduler;
 import org.apache.ambari.server.stack.upgrade.ServerActionTask;
 import org.apache.ambari.server.stack.upgrade.ServiceCheckGrouping;
 import org.apache.ambari.server.stack.upgrade.Task;
@@ -199,7 +200,7 @@ public abstract class StageWrapperBuilder {
   }

   /**
-   * Determine the list of tasks given these rules
+   * Determine the list of pre- or post-tasks given these rules
    * <ul>
    *   <li>When performing an upgrade, only use upgrade tasks</li>
    *   <li>When performing a downgrade, use the downgrade tasks if they are 
defined</li>
@@ -263,4 +264,28 @@ public abstract class StageWrapperBuilder {

     return null;
   }
+
+  /**
+   * Gets the number of hosts to process in parallel for this builder's grouping.
+   *
+   * @param ctx
+   *          the upgrade context
+   * @param defaultValue
+   *          returned when the grouping defines no parallel scheduler
+   * @return
+   *          the scheduler's max degree of parallelism (the context default when it is the scheduler default), else {@code defaultValue}
+   */
+  protected int getParallelHostCount(UpgradeContext ctx, int defaultValue) {
+
+    if (m_grouping.parallelScheduler != null) {
+      int taskParallelism = 
m_grouping.parallelScheduler.maxDegreeOfParallelism;
+      if (taskParallelism == 
ParallelScheduler.DEFAULT_MAX_DEGREE_OF_PARALLELISM) {
+        taskParallelism = ctx.getDefaultMaxDegreeOfParallelism();
+      }
+      return taskParallelism;
+    }
+
+    return defaultValue;
+  }
+
 }
diff --git 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeContext.java
 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeContext.java
index dc1cf8e..56094c0 100644
--- 
a/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeContext.java
+++ 
b/ambari-server/src/main/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeContext.java
@@ -83,10 +83,12 @@ import org.apache.ambari.server.state.SecurityType;
 import org.apache.ambari.server.state.Service;
 import org.apache.ambari.server.state.ServiceComponentHost;
 import org.apache.ambari.server.state.StackId;
+import org.apache.ambari.server.state.StackInfo;
 import org.apache.ambari.server.state.repository.ClusterVersionSummary;
 import org.apache.ambari.server.state.repository.VersionDefinitionXml;
 import org.apache.ambari.spi.RepositoryType;
 import org.apache.ambari.spi.RepositoryVersion;
+import org.apache.ambari.spi.upgrade.OrchestrationOptions;
 import org.apache.ambari.spi.upgrade.UpgradeCheckStatus;
 import org.apache.ambari.spi.upgrade.UpgradeInformation;
 import org.apache.ambari.spi.upgrade.UpgradeType;
@@ -287,6 +289,8 @@ public class UpgradeContext {
   @Inject
   private Configuration configuration;

+  private OrchestrationOptions m_orchestrationOptions; // from the upgrade pack's orchestration-options-class; may be null
+
   /**
    * Reading upgrade type from provided request  or if nothing were provided,
    * from previous upgrade for downgrade direction.
@@ -400,6 +404,8 @@ public class UpgradeContext {
       m_direction = Direction.DOWNGRADE;
       m_orchestration = revertUpgrade.getOrchestration();
       m_upgradePack = getUpgradePack(revertUpgrade);
+      m_orchestrationOptions = getOrchestrationOptions(metaInfo, 
m_upgradePack);
+
     } else {

       // determine direction
@@ -446,6 +452,8 @@ public class UpgradeContext {
               upgradeFromRepositoryVersion.getStackId(), 
m_repositoryVersion.getStackId(), m_direction,
               m_type, preferredUpgradePackName);

+          m_orchestrationOptions = getOrchestrationOptions(metaInfo, 
m_upgradePack);
+
           break;
         }
         case DOWNGRADE:{
@@ -464,6 +472,7 @@ public class UpgradeContext {
           }

           m_upgradePack = getUpgradePack(upgrade);
+          m_orchestrationOptions = getOrchestrationOptions(metaInfo, 
m_upgradePack);

           break;
         }
@@ -561,6 +570,8 @@ public class UpgradeContext {

     m_isRevert = upgradeEntity.getOrchestration().isRevertable()
         && upgradeEntity.getDirection() == Direction.DOWNGRADE;
+
+    m_orchestrationOptions = getOrchestrationOptions(ambariMetaInfo, 
m_upgradePack);
   }

   /**
@@ -1054,6 +1065,13 @@ public class UpgradeContext {
   }

   /**
+   * @return the stack-provided orchestration options, or {@code null} if the upgrade pack defines none or they could not be loaded
+   */
+  public OrchestrationOptions getOrchestrationOptions() {
+    return m_orchestrationOptions;
+  }
+
+  /**
    * Gets the set of services which will participate in the upgrade. The
    * services available in the repository are compared against those installed
    * in the cluster to arrive at the final subset.
@@ -1451,7 +1469,7 @@ public class UpgradeContext {
    * @return
    *          the upgrade pack.  May be {@code null} if it doesn't exist
    */
-  UpgradePack getUpgradePack(UpgradeEntity upgrade) {
+  private UpgradePack getUpgradePack(UpgradeEntity upgrade) {
     StackId stackId = upgrade.getUpgradePackStackId();

     Map<String, UpgradePack> packs = m_metaInfo.getUpgradePacks(
@@ -1499,4 +1517,44 @@ public class UpgradeContext {

     return upgradeInformation;
   }
-}
\ No newline at end of file
+
+  /**
+   * Instantiates the {@link OrchestrationOptions} class named by the upgrade pack.
+   *
+   * @param metaInfo
+   *          the ambari meta-info used to find the owning stack's class loader
+   * @param pack
+   *          the upgrade pack
+   * @return
+   *          the options instance, or {@code null} if none is defined or it cannot be loaded
+   */
+  private OrchestrationOptions getOrchestrationOptions(AmbariMetaInfo 
metaInfo, UpgradePack pack) {
+
+    // !!! a null pack is only expected in testing
+    if (null == pack) {
+      return null;
+    }
+
+    String className = pack.getOrchestrationOptions();
+
+    if (null == className) {
+      return null;
+    }
+
+    StackId stackId = pack.getOwnerStackId();
+
+    try {
+      StackInfo stack = metaInfo.getStack(stackId);
+      ClassLoader cl = stack.getLibraryClassLoader();
+
+      Class<?> clazz = (null == cl) ? Class.forName(className) :
+        cl.loadClass(className);
+
+      return (OrchestrationOptions) clazz.getDeclaredConstructor().newInstance();
+    } catch (Exception e) {
+      LOG.error(String.format("Could not load orchestration options %s for stack 
%s: %s",
+          className, stackId, e.getMessage()), e);
+      return null;
+    }
+  }
+}
diff --git a/ambari-server/src/main/resources/upgrade-pack.xsd 
b/ambari-server/src/main/resources/upgrade-pack.xsd
index b743095..d8f952e 100644
--- a/ambari-server/src/main/resources/upgrade-pack.xsd
+++ b/ambari-server/src/main/resources/upgrade-pack.xsd
@@ -478,6 +478,7 @@
           <xs:element name="source-stack" type="xs:string" />
           <xs:element name="target-stack" type="xs:string" />
         </xs:choice>
+        <xs:element name="orchestration-options-class" minOccurs="0" 
type="xs:string" /> <!-- optional: fully-qualified name of an org.apache.ambari.spi.upgrade.OrchestrationOptions implementation -->
         <xs:element name="skip-failures" minOccurs="0" type="xs:boolean" />
         <xs:element name="skip-service-check-failures" minOccurs="0" 
type="xs:boolean" />
         <xs:element name="downgrade-allowed" minOccurs="0" type="xs:boolean" />
diff --git 
a/ambari-server/src/test/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeHelperTest.java
 
b/ambari-server/src/test/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeHelperTest.java
index e78af4c..ebe69c7 100644
--- 
a/ambari-server/src/test/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeHelperTest.java
+++ 
b/ambari-server/src/test/java/org/apache/ambari/server/stack/upgrade/orchestrate/UpgradeHelperTest.java
@@ -48,6 +48,7 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import org.apache.ambari.annotations.Experimental;
 import org.apache.ambari.annotations.ExperimentalFeature;
@@ -113,7 +114,9 @@ import org.apache.ambari.server.state.ServiceInfo;
 import org.apache.ambari.server.state.StackId;
 import org.apache.ambari.server.state.UpgradeState;
 import org.apache.ambari.server.utils.EventBusSynchronizer;
+import org.apache.ambari.spi.ClusterInformation;
 import org.apache.ambari.spi.RepositoryType;
+import org.apache.ambari.spi.upgrade.OrchestrationOptions;
 import org.apache.ambari.spi.upgrade.UpgradeType;
 import org.easymock.Capture;
 import org.easymock.EasyMock;
@@ -712,7 +715,7 @@ public class UpgradeHelperTest extends EasyMockSupport {
     assertEquals(4, groups.get(0).items.size());
     assertEquals(8, groups.get(1).items.size());
     assertEquals(6, groups.get(2).items.size());
-    assertEquals(7, groups.get(3).items.size());
+    assertEquals(6, groups.get(3).items.size());
     assertEquals(8, groups.get(4).items.size());
   }

@@ -2870,6 +2873,46 @@ public class UpgradeHelperTest extends EasyMockSupport {
     assertEquals(2, taskWrapper.getHosts().size());
   }

+  @Test
+  public void testOrchestrationOptions() throws Exception {
+
+    Map<String, UpgradePack> upgrades = ambariMetaInfo.getUpgradePacks("HDP", 
"2.2.0");
+    assertTrue(upgrades.containsKey("upgrade_from_211"));
+
+    UpgradePack upgrade = upgrades.get("upgrade_from_211");
+    assertNotNull(upgrade);
+
+    Cluster cluster = makeCluster();
+
+    UpgradeContext context = getMockUpgradeContext(cluster, Direction.UPGRADE, 
UpgradeType.ROLLING, false);
+
+    SimpleOrchestrationOptions options = new SimpleOrchestrationOptions(1);
+
+    expect(context.getOrchestrationOptions()).andReturn(options).anyTimes();
+    replay(context);
+
+    List<UpgradeGroupHolder> groups = m_upgradeHelper.createSequence(upgrade, 
context);
+    groups = groups.stream().filter(g -> 
g.name.equals("CORE_SLAVES")).collect(Collectors.toList());
+    assertEquals(1, groups.size());
+
+    List<StageWrapper> restarts = groups.get(0).items.stream().filter(sw ->
+        sw.getType() == StageWrapper.Type.RESTART && 
sw.getText().contains("DataNode"))
+        .collect(Collectors.toList());
+
+    assertEquals("Expecting wrappers for each of 3 hosts", 3, restarts.size());
+
+    options.m_count = 2; // 3 DataNode hosts -> batches of 2 and 1
+    groups = m_upgradeHelper.createSequence(upgrade, context);
+    groups = groups.stream().filter(g -> 
g.name.equals("CORE_SLAVES")).collect(Collectors.toList());
+    assertEquals(1, groups.size());
+
+    restarts = groups.get(0).items.stream().filter(sw ->
+        sw.getType() == StageWrapper.Type.RESTART && 
sw.getText().contains("DataNode"))
+        .collect(Collectors.toList());
+
+    assertEquals("Expecting a wrapper for each batch of up to 2 hosts", 2, restarts.size());
+  }
+
   /**
    * Builds a mock upgrade context using the following parameters:
    * <ul>
@@ -3214,4 +3257,17 @@ public class UpgradeHelperTest extends EasyMockSupport {
     }

   }
+
+  private static class SimpleOrchestrationOptions implements 
OrchestrationOptions {
+    private int m_count; // mutable so the test can change the count between runs
+
+    private SimpleOrchestrationOptions(int count) {
+      m_count = count;
+    }
+
+    @Override
+    public int getConcurrencyCount(ClusterInformation cluster, String service, 
String component) {
+      return m_count;
+    }
+  }
 }

Reply via email to