This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 92096577382 Add Adaptive Routing MSE inflight reqs stats (#18553)
92096577382 is described below

commit 92096577382f0e4b828f1792a3ac494499414651
Author: Timothy Elgersma <[email protected]>
AuthorDate: Fri May 29 18:18:55 2026 -0400

    Add Adaptive Routing MSE inflight reqs stats (#18553)
    
    ---------
    
    Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
---
 .../broker/api/resources/PinotBrokerDebug.java     |   8 +-
 .../broker/broker/helix/BaseBrokerStarter.java     |   9 +-
 .../MultiStageBrokerRequestHandler.java            |  18 ++-
 .../apache/pinot/common/metrics/BrokerGauge.java   |   7 +-
 .../routing/stats/ServerRoutingStatsManager.java   | 130 +++++++++++++++------
 .../stats/ServerRoutingStatsManagerTest.java       | 106 +++++++++++++++++
 .../query/service/dispatch/QueryDispatcher.java    |  36 ++++++
 .../service/dispatch/QueryDispatcherTest.java      | 100 ++++++++++++++++
 8 files changed, 370 insertions(+), 44 deletions(-)

diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java b/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
index 69248672bb5..0e33145ba40 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/api/resources/PinotBrokerDebug.java
@@ -64,6 +64,7 @@ import org.apache.pinot.spi.accounting.QueryResourceTracker;
 import org.apache.pinot.spi.accounting.ThreadAccountant;
 import org.apache.pinot.spi.accounting.ThreadResourceTracker;
 import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.query.QueryExecutionContext.QueryType;
 import org.apache.pinot.spi.utils.builder.TableNameBuilder;
 import org.apache.pinot.sql.parsers.CalciteSqlCompiler;
 
@@ -277,9 +278,12 @@ public class PinotBrokerDebug {
       @ApiResponse(code = 404, message = "Server routing Stats not found"),
       @ApiResponse(code = 500, message = "Internal server error")
   })
-  public Map<String, ServerRoutingStatsEntry> getServerRoutingStats() {
+  public Map<String, ServerRoutingStatsEntry> getServerRoutingStats(
+      @ApiParam(value = "Query engine type (SSE or MSE)", allowableValues = 
"SSE, MSE")
+      @QueryParam("queryType") String queryTypeStr) {
     if (_serverRoutingStatsManager.isEnabled()) {
-      return _serverRoutingStatsManager.getServerRoutingStats();
+      QueryType queryType = queryTypeStr == null ? QueryType.SSE : 
QueryType.valueOf(queryTypeStr);
+      return _serverRoutingStatsManager.getServerRoutingStats(queryType);
     } else {
       throw new WebApplicationException("Server routing stats is not enabled", 
Response.Status.NOT_FOUND);
     }
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
index 7fdd3154b4e..ec18f45ec88 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
@@ -310,10 +310,12 @@ public abstract class BaseBrokerStarter implements ServiceStartable {
       QueryQuotaManager queryQuotaManager, TableCache tableCache,
       MultiStageQueryThrottler multiStageQueryThrottler, FailureDetector 
failureDetector,
       ThreadAccountant threadAccountant, MultiClusterRoutingContext 
multiClusterRoutingContext,
-      WorkerManager workerManager, WorkerManager multiClusterWorkerManager) {
+      WorkerManager workerManager, WorkerManager multiClusterWorkerManager,
+      ServerRoutingStatsManager serverRoutingStatsManager) {
     return new MultiStageBrokerRequestHandler(config, brokerId, 
requestIdGenerator, routingManager,
         accessControlFactory, queryQuotaManager, tableCache, 
multiStageQueryThrottler, failureDetector,
-        threadAccountant, multiClusterRoutingContext, workerManager, 
multiClusterWorkerManager);
+        threadAccountant, multiClusterRoutingContext, workerManager, 
multiClusterWorkerManager,
+        serverRoutingStatsManager);
   }
 
   private void setupHelixSystemProperties() {
@@ -559,7 +561,8 @@ public abstract class BaseBrokerStarter implements ServiceStartable {
       multiStageBrokerRequestHandler =
           createMultiStageBrokerRequestHandler(_brokerConf, brokerId, 
requestIdGenerator, _routingManager,
               _accessControlFactory, _queryQuotaManager, _tableCache, 
_multiStageQueryThrottler, _failureDetector,
-              _threadAccountant, multiClusterRoutingContext, workerManager, 
multiClusterWorkerManager);
+              _threadAccountant, multiClusterRoutingContext, workerManager, 
multiClusterWorkerManager,
+              _serverRoutingStatsManager);
       MultiStageBrokerRequestHandler finalHandler = 
multiStageBrokerRequestHandler;
       _routingManager.setServerReenableCallback(
           serverInstance -> 
finalHandler.getQueryDispatcher().resetClientConnectionBackoff(serverInstance));
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
index 25c268a6bac..1749d13d6e7 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java
@@ -76,6 +76,7 @@ import org.apache.pinot.common.utils.tls.TlsUtils;
 import org.apache.pinot.core.routing.MultiClusterRoutingContext;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.transport.ServerInstance;
+import 
org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager;
 import org.apache.pinot.query.ImmutableQueryEnvironment;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.mailbox.MailboxService;
@@ -135,6 +136,8 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
   private final WorkerManager _multiClusterWorkerManager;
   private final MailboxService _mailboxService;
   private final QueryDispatcher _queryDispatcher;
+  @Nullable
+  private final ServerRoutingStatsManager _serverRoutingStatsManager;
   private final boolean _explainAskingServerDefault;
   private final MultiStageQueryThrottler _queryThrottler;
   private final ExecutorService _queryCompileExecutor;
@@ -155,8 +158,21 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
       MultiStageQueryThrottler queryThrottler, FailureDetector 
failureDetector, ThreadAccountant threadAccountant,
       MultiClusterRoutingContext multiClusterRoutingContext,
       WorkerManager workerManager, WorkerManager multiClusterWorkerManager) {
+    this(config, brokerId, requestIdGenerator, routingManager, 
accessControlFactory, queryQuotaManager, tableCache,
+        queryThrottler, failureDetector, threadAccountant, 
multiClusterRoutingContext, workerManager,
+        multiClusterWorkerManager, null);
+  }
+
+  public MultiStageBrokerRequestHandler(PinotConfiguration config, String 
brokerId,
+      BrokerRequestIdGenerator requestIdGenerator, RoutingManager 
routingManager,
+      AccessControlFactory accessControlFactory, QueryQuotaManager 
queryQuotaManager, TableCache tableCache,
+      MultiStageQueryThrottler queryThrottler, FailureDetector 
failureDetector, ThreadAccountant threadAccountant,
+      MultiClusterRoutingContext multiClusterRoutingContext,
+      WorkerManager workerManager, WorkerManager multiClusterWorkerManager,
+      @Nullable ServerRoutingStatsManager statsManager) {
     super(config, brokerId, requestIdGenerator, routingManager, 
accessControlFactory, queryQuotaManager, tableCache,
         threadAccountant, multiClusterRoutingContext);
+    _serverRoutingStatsManager = statsManager;
     String hostname = 
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME);
     int port = 
Integer.parseInt(config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT));
 
@@ -707,7 +723,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler {
         _opchainsStartedMeter.mark(opChainCount);
         try {
           queryResults = _queryDispatcher.submitAndReduce(requestContext, 
dispatchableSubPlan,
-              timer.getRemainingTimeMs(), query.getOptions());
+              timer.getRemainingTimeMs(), query.getOptions(), 
_serverRoutingStatsManager);
         } catch (QueryException e) {
           throw e;
         } catch (Throwable t) {
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java b/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java
index b730aa0e9d5..9e82214def1 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/metrics/BrokerGauge.java
@@ -65,12 +65,17 @@ public enum BrokerGauge implements AbstractMetrics.Gauge {
   ADAPTIVE_SERVER_SELECTOR_TYPE("adaptiveServerSelectorType", true),
 
   /**
-   * Per-server adaptive routing stats exported as metrics.
+   * Per-server adaptive routing stats exported as metrics (SSE / single-stage 
engine).
    */
   ADAPTIVE_SERVER_NUM_IN_FLIGHT_REQUESTS("adaptiveServerNumInFlightRequests", 
false),
   ADAPTIVE_SERVER_LATENCY_EMA("adaptiveServerLatencyEma", false),
   ADAPTIVE_SERVER_HYBRID_SCORE("adaptiveServerHybridScore", false),
 
+  /**
+   * Per-server adaptive routing stats exported as metrics (MSE / multi-stage 
engine).
+   */
+  
ADAPTIVE_SERVER_MSE_NUM_IN_FLIGHT_REQUESTS("adaptiveServerMseNumInFlightRequests",
 false),
+
   /**
    * The queue size of ServerRoutingStatsManager main executor service.
    */
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java
index 6c1bea37f14..fe85d62f761 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManager.java
@@ -29,11 +29,14 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.metrics.BrokerGauge;
 import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.query.QueryExecutionContext.QueryType;
+import org.apache.pinot.spi.query.QueryThreadContext;
 import 
org.apache.pinot.spi.utils.CommonConstants.Broker.AdaptiveServerSelector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -52,6 +55,7 @@ public class ServerRoutingStatsManager {
   private final BrokerMetrics _brokerMetrics;
   private volatile boolean _isEnabled;
   private ConcurrentHashMap<String, ServerRoutingStatsEntry> 
_serverQueryStatsMap;
+  private ConcurrentHashMap<String, ServerRoutingStatsEntry> 
_mseServerQueryStatsMap;
 
   // Main executor service for collecting and aggregating stats for all 
servers.
   private ExecutorService _executorService;
@@ -104,6 +108,7 @@ public class ServerRoutingStatsManager {
     // Entries in this map are never deleted unless the broker process 
restarts. This is okay for now because the
     // number of servers will be finite and should not cause memory bloat.
     _serverQueryStatsMap = new ConcurrentHashMap<>();
+    _mseServerQueryStatsMap = new ConcurrentHashMap<>();
 
     _enableStatsMetricExport = 
_config.getProperty(AdaptiveServerSelector.CONFIG_OF_ENABLE_STATS_METRIC_EXPORT,
         AdaptiveServerSelector.DEFAULT_ENABLE_STATS_METRIC_EXPORT);
@@ -149,6 +154,29 @@ public class ServerRoutingStatsManager {
     return tpe.getCompletedTaskCount();
   }
 
+  private ConcurrentHashMap<String, ServerRoutingStatsEntry> 
getStatsMap(QueryType queryType) {
+    switch (queryType) {
+      case SSE:
+        return _serverQueryStatsMap;
+      case MSE:
+        return _mseServerQueryStatsMap;
+      default:
+        LOGGER.warn("Unsupported query type for adaptive server routing: {}; 
defaulting to SSE stats map", queryType);
+        return _serverQueryStatsMap;
+    }
+  }
+
+  /// Callers are expected to run on a thread with a {@link 
QueryThreadContext} set.
+  /// If absent, defaults to the SSE stats map with a warning.
+  private ConcurrentHashMap<String, ServerRoutingStatsEntry> getStatsMap() {
+    QueryThreadContext qtc = QueryThreadContext.getIfAvailable();
+    if (qtc == null) {
+      LOGGER.warn("QueryThreadContext is null; defaulting to SSE stats map");
+      return getStatsMap(QueryType.SSE);
+    }
+    return getStatsMap(qtc.getExecutionContext().getQueryType());
+  }
+
   /**
    * Called just before submitting a query to a server. Updates stats 
corresponding to query submission.
    */
@@ -157,18 +185,20 @@ public class ServerRoutingStatsManager {
       return;
     }
 
+    ConcurrentHashMap<String, ServerRoutingStatsEntry> statsMap = 
getStatsMap();
     _executorService.execute(() -> {
       try {
         recordQueueSizeMetrics();
-        updateStatsAfterQuerySubmission(serverInstanceId);
+        updateStatsAfterQuerySubmission(serverInstanceId, statsMap);
       } catch (Exception e) {
         LOGGER.error("Exception caught while updating stats. requestId={}, 
exception={}", requestId, e);
       }
     });
   }
 
-  private void updateStatsAfterQuerySubmission(String serverInstanceId) {
-    ServerRoutingStatsEntry stats = 
_serverQueryStatsMap.computeIfAbsent(serverInstanceId,
+  private void updateStatsAfterQuerySubmission(String serverInstanceId,
+      ConcurrentHashMap<String, ServerRoutingStatsEntry> statsMap) {
+    ServerRoutingStatsEntry stats = statsMap.computeIfAbsent(serverInstanceId,
         k -> new ServerRoutingStatsEntry(serverInstanceId, _alpha, 
_autoDecayWindowMs, _warmupDurationMs,
             _avgInitializationVal, _hybridScoreExponent, 
_hybridScoreQueueFloor, _periodicTaskExecutor));
 
@@ -188,17 +218,19 @@ public class ServerRoutingStatsManager {
       return;
     }
 
+    ConcurrentHashMap<String, ServerRoutingStatsEntry> statsMap = 
getStatsMap();
     _executorService.execute(() -> {
       try {
-        updateStatsUponResponseArrival(serverInstanceId, latency);
+        updateStatsUponResponseArrival(serverInstanceId, latency, statsMap);
       } catch (Exception e) {
         LOGGER.error("Exception caught while updating stats. requestId={}, 
exception={}", requestId, e);
       }
     });
   }
 
-  private void updateStatsUponResponseArrival(String serverInstanceId, long 
latencyMs) {
-    ServerRoutingStatsEntry stats = 
_serverQueryStatsMap.computeIfAbsent(serverInstanceId,
+  private void updateStatsUponResponseArrival(String serverInstanceId, long 
latencyMs,
+      ConcurrentHashMap<String, ServerRoutingStatsEntry> statsMap) {
+    ServerRoutingStatsEntry stats = statsMap.computeIfAbsent(serverInstanceId,
         k -> new ServerRoutingStatsEntry(serverInstanceId, _alpha, 
_autoDecayWindowMs, _warmupDurationMs,
             _avgInitializationVal, _hybridScoreExponent, 
_hybridScoreQueueFloor, _periodicTaskExecutor));
 
@@ -217,10 +249,18 @@ public class ServerRoutingStatsManager {
     return _serverQueryStatsMap;
   }
 
+  public Map<String, ServerRoutingStatsEntry> getServerRoutingStats(QueryType 
queryType) {
+    return getStatsMap(queryType);
+  }
+
   /**
    * Returns ServerRoutingStatsStr for debugging/logging.
    */
   public String getServerRoutingStatsStr() {
+    return getServerRoutingStatsStr(QueryType.SSE);
+  }
+
+  public String getServerRoutingStatsStr(QueryType queryType) {
     if (!_isEnabled) {
       return "";
     }
@@ -228,7 +268,7 @@ public class ServerRoutingStatsManager {
     StringBuilder stringBuilder =
         new 
StringBuilder("(Server=NumInFlightRequests,NumInFlightRequestsEMA,LatencyEMA," 
+ "Score)");
 
-    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
_serverQueryStatsMap.entrySet()) {
+    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
getStatsMap(queryType).entrySet()) {
       String server = entry.getKey();
       Preconditions.checkState(entry.getValue() != null, "Server stats is 
null");
       ServerRoutingStatsEntry stats = entry.getValue();
@@ -274,7 +314,7 @@ public class ServerRoutingStatsManager {
       return response;
     }
 
-    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
_serverQueryStatsMap.entrySet()) {
+    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
getStatsMap().entrySet()) {
       String server = entry.getKey();
       Preconditions.checkState(entry.getValue() != null, "Server stats is 
null");
       ServerRoutingStatsEntry stats = entry.getValue();
@@ -297,7 +337,7 @@ public class ServerRoutingStatsManager {
       return null;
     }
 
-    ServerRoutingStatsEntry stats = _serverQueryStatsMap.get(server);
+    ServerRoutingStatsEntry stats = getStatsMap().get(server);
     if (stats == null) {
       return null;
     }
@@ -319,7 +359,7 @@ public class ServerRoutingStatsManager {
       return response;
     }
 
-    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
_serverQueryStatsMap.entrySet()) {
+    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
getStatsMap().entrySet()) {
       String server = entry.getKey();
       Preconditions.checkState(entry.getValue() != null, "Server stats is 
null");
       ServerRoutingStatsEntry stats = entry.getValue();
@@ -342,7 +382,7 @@ public class ServerRoutingStatsManager {
       return null;
     }
 
-    ServerRoutingStatsEntry stats = _serverQueryStatsMap.get(server);
+    ServerRoutingStatsEntry stats = getStatsMap().get(server);
     if (stats == null) {
       return null;
     }
@@ -365,7 +405,7 @@ public class ServerRoutingStatsManager {
       return response;
     }
 
-    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
_serverQueryStatsMap.entrySet()) {
+    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
getStatsMap().entrySet()) {
       String server = entry.getKey();
       Preconditions.checkState(entry.getValue() != null, "Server stats is 
null");
       ServerRoutingStatsEntry stats = entry.getValue();
@@ -388,7 +428,7 @@ public class ServerRoutingStatsManager {
       return null;
     }
 
-    ServerRoutingStatsEntry stats = _serverQueryStatsMap.get(server);
+    ServerRoutingStatsEntry stats = getStatsMap().get(server);
     if (stats == null) {
       return null;
     }
@@ -408,32 +448,48 @@ public class ServerRoutingStatsManager {
 
   private void exportStatsAsMetrics() {
     try {
-      for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
_serverQueryStatsMap.entrySet()) {
-        String serverInstanceId = entry.getKey();
-        ServerRoutingStatsEntry stats = entry.getValue();
-
-        int numInFlightRequests;
-        double latencyEma;
-        double hybridScore;
-
-        stats.getServerReadLock().lock();
-        try {
-          numInFlightRequests = stats.getNumInFlightRequests();
-          latencyEma = stats.getLatencyEMA();
-          hybridScore = stats.computeHybridScore();
-        } finally {
-          stats.getServerReadLock().unlock();
-        }
-
-        
_brokerMetrics.setValueOfGlobalGauge(BrokerGauge.ADAPTIVE_SERVER_NUM_IN_FLIGHT_REQUESTS,
-            "server." + serverInstanceId, numInFlightRequests);
-        
_brokerMetrics.setValueOfGlobalGauge(BrokerGauge.ADAPTIVE_SERVER_LATENCY_EMA, 
"server." + serverInstanceId,
-            (long) latencyEma);
-        
_brokerMetrics.setValueOfGlobalGauge(BrokerGauge.ADAPTIVE_SERVER_HYBRID_SCORE, 
"server." + serverInstanceId,
-            (long) hybridScore);
-      }
+      exportStatsForMap(_serverQueryStatsMap, "server.",
+          BrokerGauge.ADAPTIVE_SERVER_NUM_IN_FLIGHT_REQUESTS,
+          BrokerGauge.ADAPTIVE_SERVER_LATENCY_EMA,
+          BrokerGauge.ADAPTIVE_SERVER_HYBRID_SCORE);
+      // TODO: Export MSE latency stats once we support it
+      exportStatsForMap(_mseServerQueryStatsMap, "server.",
+          BrokerGauge.ADAPTIVE_SERVER_MSE_NUM_IN_FLIGHT_REQUESTS,
+          null,
+          null);
     } catch (Exception e) {
       LOGGER.error("Exception caught while exporting routing stats as 
metrics.", e);
     }
   }
+
+  private void exportStatsForMap(ConcurrentHashMap<String, 
ServerRoutingStatsEntry> statsMap, String tagPrefix,
+      BrokerGauge numInFlightGauge, @Nullable BrokerGauge latencyEmaGauge,
+      @Nullable BrokerGauge hybridScoreGauge) {
+    for (Map.Entry<String, ServerRoutingStatsEntry> entry : 
statsMap.entrySet()) {
+      String serverInstanceId = entry.getKey();
+      ServerRoutingStatsEntry stats = entry.getValue();
+
+      int numInFlightRequests;
+      double latencyEma;
+      double hybridScore;
+
+      stats.getServerReadLock().lock();
+      try {
+        numInFlightRequests = stats.getNumInFlightRequests();
+        latencyEma = stats.getLatencyEMA();
+        hybridScore = stats.computeHybridScore();
+      } finally {
+        stats.getServerReadLock().unlock();
+      }
+
+      String tag = tagPrefix + serverInstanceId;
+      _brokerMetrics.setValueOfGlobalGauge(numInFlightGauge, tag, 
numInFlightRequests);
+      if (latencyEmaGauge != null) {
+        _brokerMetrics.setValueOfGlobalGauge(latencyEmaGauge, tag, (long) 
latencyEma);
+      }
+      if (hybridScoreGauge != null) {
+        _brokerMetrics.setValueOfGlobalGauge(hybridScoreGauge, tag, (long) 
hybridScore);
+      }
+    }
+  }
 }
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManagerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManagerTest.java
index 988236d557f..17b536091be 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManagerTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/transport/server/routing/stats/ServerRoutingStatsManagerTest.java
@@ -28,6 +28,7 @@ import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.spi.env.PinotConfiguration;
 import org.apache.pinot.spi.metrics.PinotMetricUtils;
 import org.apache.pinot.spi.metrics.PinotMetricsRegistry;
+import org.apache.pinot.spi.query.QueryThreadContext;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.util.TestUtils;
 import org.testng.annotations.BeforeTest;
@@ -455,6 +456,111 @@ public class ServerRoutingStatsManagerTest {
     assertTrue(fastScore < slowScore, "Idle servers should be ranked by 
latency");
   }
 
+  @Test
+  public void testMseAndSseStatsIsolation() {
+    Map<String, Object> properties = new HashMap<>();
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_ENABLE_STATS_COLLECTION,
 true);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_EWMA_ALPHA,
 1.0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_AUTODECAY_WINDOW_MS,
 -1);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_WARMUP_DURATION_MS,
 0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_AVG_INITIALIZATION_VAL,
 0.0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_HYBRID_SCORE_EXPONENT,
 3);
+    ServerRoutingStatsManager manager = new ServerRoutingStatsManager(new 
PinotConfiguration(properties),
+        _brokerMetrics);
+    manager.init();
+
+    int requestId = 0;
+
+    // Record 1 SSE request and 2 MSE requests to the same server.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      manager.recordStatsForQuerySubmission(requestId++, "server1");
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      manager.recordStatsForQuerySubmission(requestId++, "server1");
+      manager.recordStatsForQuerySubmission(requestId++, "server1");
+    }
+    waitForStatsUpdate(manager, requestId);
+
+    // SSE should see 1 in-flight, MSE should see 2 — stats are isolated.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
1);
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
2);
+    }
+
+    // Complete the SSE request with 50ms latency.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      manager.recordStatsUponResponseArrival(requestId++, "server1", 50);
+    }
+    waitForStatsUpdate(manager, requestId);
+
+    // SSE in-flight drops to 0, MSE remains at 2.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
0);
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
2);
+    }
+
+    // SSE latency updated, MSE latency still at init value.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      assertEquals(manager.fetchEMALatencyForServer("server1"), 50.0);
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      assertEquals(manager.fetchEMALatencyForServer("server1"), 0.0);
+    }
+
+    // Complete one MSE request with 200ms latency.
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      manager.recordStatsUponResponseArrival(requestId++, "server1", 200);
+    }
+    waitForStatsUpdate(manager, requestId);
+
+    // MSE in-flight drops to 1, SSE still 0.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
0);
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      
assertEquals(manager.fetchNumInFlightRequestsForServer("server1").intValue(), 
1);
+    }
+
+    // MSE latency updated, SSE latency unchanged.
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      assertEquals(manager.fetchEMALatencyForServer("server1"), 50.0);
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      assertEquals(manager.fetchEMALatencyForServer("server1"), 200.0);
+    }
+
+    // Hybrid scores are independent.
+    Double sseScore;
+    Double mseScore;
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      sseScore = manager.fetchHybridScoreForServer("server1");
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      mseScore = manager.fetchHybridScoreForServer("server1");
+    }
+    assertNotNull(sseScore);
+    assertNotNull(mseScore);
+    assertNotEquals(sseScore, mseScore);
+
+    // fetchAllServers lists are also isolated.
+    List<Pair<String, Integer>> sseList;
+    List<Pair<String, Integer>> mseList;
+    try (QueryThreadContext ignored = QueryThreadContext.openForSseTest()) {
+      sseList = manager.fetchNumInFlightRequestsForAllServers();
+    }
+    try (QueryThreadContext ignored = QueryThreadContext.openForMseTest()) {
+      mseList = manager.fetchNumInFlightRequestsForAllServers();
+    }
+    assertEquals(sseList.size(), 1);
+    assertEquals(mseList.size(), 1);
+    assertEquals(sseList.get(0).getRight().intValue(), 0);
+    assertEquals(mseList.get(0).getRight().intValue(), 1);
+  }
+
   private void assertStatsNullForInstance(ServerRoutingStatsManager manager, 
String instanceId) {
     Integer numInFlightReq = 
manager.fetchNumInFlightRequestsForServer(instanceId);
     assertNull(numInFlightReq);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 0ee05dcfe1d..38860780853 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -63,6 +63,7 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
 import org.apache.pinot.common.utils.grpc.ServerGrpcQueryClient;
 import org.apache.pinot.core.instance.context.BrokerContext;
 import org.apache.pinot.core.transport.ServerInstance;
+import 
org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager;
 import org.apache.pinot.core.util.DataBlockExtractUtils;
 import org.apache.pinot.core.util.trace.TracedThreadFactory;
 import org.apache.pinot.query.mailbox.MailboxService;
@@ -167,10 +168,40 @@ public class QueryDispatcher {
   public QueryResult submitAndReduce(RequestContext context, 
DispatchableSubPlan dispatchableSubPlan, long timeoutMs,
       Map<String, String> queryOptions)
       throws Exception {
+    return submitAndReduce(context, dispatchableSubPlan, timeoutMs, 
queryOptions, null);
+  }
+
+  /// Same as {@link #submitAndReduce(RequestContext, DispatchableSubPlan, 
long, Map)} but records per-server
+  /// in-flight request statistics into {@code statsManager} for use by the 
adaptive query router.
+  /// When {@code statsManager} is non-null:
+  /// <ul>
+  ///   <li>Each leaf server is registered as having one more in-flight 
request via
+  ///       {@link ServerRoutingStatsManager#recordStatsForQuerySubmission} 
after the fan-out begins.</li>
+  ///   <li>After the full fan-out completes (or fails), each server is 
decremented via
+  ///       {@link ServerRoutingStatsManager#recordStatsUponResponseArrival} 
with {@code latency = -1}
+  ///       (no latency is recorded at this stage).</li>
+  /// </ul>
+  /// TODO: Replace the coarse end-of-fanout decrement with per-sender arrival 
once per-sender EOS
+  ///       interception is in place, and record real leaf-stage latency at 
that point.
+  public QueryResult submitAndReduce(RequestContext context, 
DispatchableSubPlan dispatchableSubPlan, long timeoutMs,
+      Map<String, String> queryOptions, @Nullable ServerRoutingStatsManager 
statsManager)
+      throws Exception {
     long requestId = context.getRequestId();
     Set<QueryServerInstance> servers = new HashSet<>();
+    // Tracks servers where recordStatsForQuerySubmission was actually called, 
so the finally block only
+    // decrements servers that were incremented — guarding against a partial 
failure in submit().
+    Set<QueryServerInstance> incrementedServers = new HashSet<>();
     try {
       submit(requestId, dispatchableSubPlan, timeoutMs, servers, queryOptions);
+      // The SSE engine increments before `submit`, but here we increment 
after because `submit` populates
+      // the list of servers. Getting the list of servers before calling 
`submit` would expose
+      // implementation details of `submit`.
+      if (statsManager != null) {
+        for (QueryServerInstance server : servers) {
+          statsManager.recordStatsForQuerySubmission(requestId, 
server.getInstanceId());
+          incrementedServers.add(server);
+        }
+      }
       QueryResult result = runReducer(dispatchableSubPlan, queryOptions, 
_mailboxService);
       if (result.getProcessingException() != null) {
         cancel(requestId);
@@ -183,6 +214,11 @@ public class QueryDispatcher {
       cancel(requestId);
       throw e;
     } finally {
+      if (statsManager != null) {
+        for (QueryServerInstance server : incrementedServers) {
+          statsManager.recordStatsUponResponseArrival(requestId, 
server.getInstanceId(), -1);
+        }
+      }
       if (isQueryCancellationEnabled()) {
         _serversByQuery.remove(requestId);
       }
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
index ad4a374ac14..889486043c9 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.service.dispatch;
 import io.grpc.stub.StreamObserver;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -30,18 +31,27 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
 import org.apache.pinot.common.failuredetector.FailureDetector;
+import org.apache.pinot.common.metrics.BrokerMetrics;
 import org.apache.pinot.common.proto.Worker;
+import 
org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager;
 import org.apache.pinot.query.QueryEnvironment;
 import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.mailbox.MailboxService;
+import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
+import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.runtime.QueryRunner;
 import org.apache.pinot.query.service.server.QueryServer;
 import org.apache.pinot.query.testutils.QueryTestUtils;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.metrics.PinotMetricUtils;
+import org.apache.pinot.spi.metrics.PinotMetricsRegistry;
 import org.apache.pinot.spi.query.QueryThreadContext;
 import org.apache.pinot.spi.trace.DefaultRequestContext;
 import org.apache.pinot.spi.trace.RequestContext;
+import org.apache.pinot.spi.utils.CommonConstants;
+import org.apache.pinot.util.TestUtils;
 import org.mockito.Mockito;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
@@ -221,4 +231,94 @@ public class QueryDispatcherTest extends QueryTestSet {
       _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), 
dispatchableSubPlan, 0L, new HashSet<>(), Map.of());
     }
   }
+
+  @Test
+  public void testStatsManagerNotCalledWhenSubmitFails()
+      throws Exception {
+    ServerRoutingStatsManager statsManager = 
Mockito.mock(ServerRoutingStatsManager.class);
+    String sql = "SELECT * FROM a WHERE col1 = 'foo'";
+    long requestId = REQUEST_ID_GEN.getAndIncrement();
+    RequestContext context = new DefaultRequestContext();
+    context.setRequestId(requestId);
+
+    QueryServer failingQueryServer = 
_queryServerMap.values().iterator().next();
+    Mockito.doThrow(new RuntimeException("partial dispatch failure"))
+        .when(failingQueryServer).submit(Mockito.any(), Mockito.any());
+
+    DispatchableSubPlan plan = _queryEnvironment.planQuery(sql);
+    try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) {
+      _queryDispatcher.submitAndReduce(context, plan, 10_000L, Map.of(), 
statsManager);
+      Assert.fail("Should have thrown");
+    } catch (Exception e) {
+      Assert.assertTrue(e.getMessage().contains("Error dispatching query"));
+    }
+
+    Mockito.verifyNoInteractions(statsManager);
+    Mockito.reset(failingQueryServer);
+  }
+
+  @Test
+  public void testRealStatsManagerInflightReturnsToZero()
+      throws Exception {
+    Map<String, Object> properties = new HashMap<>();
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_ENABLE_STATS_COLLECTION,
 true);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_EWMA_ALPHA,
 1.0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_AUTODECAY_WINDOW_MS,
 -1);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_WARMUP_DURATION_MS,
 0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_AVG_INITIALIZATION_VAL,
 0.0);
+    
properties.put(CommonConstants.Broker.AdaptiveServerSelector.CONFIG_OF_HYBRID_SCORE_EXPONENT,
 3);
+
+    PinotConfiguration brokerConfig = new PinotConfiguration();
+    PinotMetricsRegistry metricsRegistry = 
PinotMetricUtils.getPinotMetricsRegistry(
+        brokerConfig.subset(CommonConstants.Broker.METRICS_CONFIG_PREFIX));
+    BrokerMetrics brokerMetrics = new BrokerMetrics(
+        CommonConstants.Broker.DEFAULT_METRICS_NAME_PREFIX,
+        metricsRegistry,
+        CommonConstants.Broker.DEFAULT_ENABLE_TABLE_LEVEL_METRICS,
+        Collections.emptyList());
+    brokerMetrics.initializeGlobalMeters();
+    BrokerMetrics.register(brokerMetrics);
+
+    ServerRoutingStatsManager statsManager = new ServerRoutingStatsManager(
+        new PinotConfiguration(properties), brokerMetrics);
+    statsManager.init();
+
+    String sql = "SELECT * FROM a";
+    long requestId = REQUEST_ID_GEN.getAndIncrement();
+    RequestContext context = new DefaultRequestContext();
+    context.setRequestId(requestId);
+    DispatchableSubPlan plan = _queryEnvironment.planQuery(sql);
+
+    Set<String> expectedInstanceIds = new HashSet<>();
+    for (DispatchablePlanFragment fragment : plan.getQueryStagesWithoutRoot()) 
{
+      for (QueryServerInstance server : 
fragment.getServerInstanceToWorkerIdMap().keySet()) {
+        expectedInstanceIds.add(server.getInstanceId());
+      }
+    }
+    Assert.assertFalse(expectedInstanceIds.isEmpty());
+
+    try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) {
+      _queryDispatcher.submitAndReduce(context, plan, 10_000L, Map.of(), 
statsManager);
+    } catch (NullPointerException e) {
+      // expected: reduce phase fails with mocked MailboxService
+    }
+
+    // Wait for the async executor to process all stats tasks (1 submission + 1 arrival per server).
+    int expectedTasks = expectedInstanceIds.size() * 2;
+    TestUtils.waitForCondition(
+        aVoid -> statsManager.getCompletedTaskCount() >= expectedTasks,
+        10L, 5000,
+        "Timed out waiting for stats manager to process all tasks");
+
+    try (QueryThreadContext ignore = QueryThreadContext.openForMseTest()) {
+      for (String instanceId : expectedInstanceIds) {
+        Integer numInFlight = 
statsManager.fetchNumInFlightRequestsForServer(instanceId);
+        Assert.assertNotNull(numInFlight, "Expected stats entry for " + 
instanceId);
+        Assert.assertEquals(numInFlight.intValue(), 0,
+            "Expected 0 in-flight requests for " + instanceId + " after 
submitAndReduce returns");
+      }
+    }
+
+    statsManager.shutDown();
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to