gortiz commented on code in PR #18458:
URL: https://github.com/apache/pinot/pull/18458#discussion_r3241865983


##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/streaming/StreamingQuerySession.java:
##########
@@ -0,0 +1,317 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.service.dispatch.streaming;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantLock;
+import javax.annotation.Nullable;
+import org.apache.pinot.common.proto.Worker;
+import org.apache.pinot.query.runtime.plan.MultiStageStatsTreeDecoder;
+import org.apache.pinot.query.runtime.plan.StageStatsTreeNode;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Broker-side per-query state for the {@code SubmitWithStream} dispatch path. 
Owns the per-stage tree accumulator,
+ * the outstanding-opchain count, the per-stage coverage counters, and the set 
of open server streams.
+ *
+ * <p>Concurrency model — all mutating methods acquire the per-session lock, 
so the accumulator and counters need no
+ * additional internal synchronization. gRPC client {@code onNext} callbacks 
land on I/O threads and call into this
+ * session directly; the work per call is short (decode + merge + decrement) 
so doing it on the I/O thread is fine.
+ *
+ * <p>Completion semantics — {@link #awaitCompletion(long, TimeUnit)} returns 
{@code true} as soon as every expected
+ * opchain has reported (early completion), and {@code false} if the timeout 
fires first. The dispatcher should call
+ * it <strong>only after</strong> the broker receiving mailbox has finished, 
so that a successful return means both
+ * "data done" and "stats fully accounted for". When it returns {@code false} 
the per-stage coverage exposes which
+ * stages are missing.
+ */
+public class StreamingQuerySession {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(StreamingQuerySession.class);
+
+  private final long _requestId;
+  private final int _expectedOpChains;
+  private final CountDownLatch _completionLatch;
+  private final ReentrantLock _lock = new ReentrantLock();
+
+  /** Per-stage merged accumulator. Mutated under {@link #_lock}. */
+  private final Map<Integer, StageStatsTreeNode> _stageAccumulator = new 
HashMap<>();
+  /** Per-stage count of opchains that have responded successfully and merged 
cleanly. */
+  private final Map<Integer, Integer> _respondedByStage = new HashMap<>();
+  /** Per-stage count of opchains that responded but the broker couldn't merge 
their payload. */
+  private final Map<Integer, Integer> _mergeFailedByStage = new HashMap<>();
+
+  /** Set of open server streams. Iteration order is insertion order so cancel 
fan-out is deterministic. */
+  private final Set<StreamingServerHandle> _openStreams = new 
LinkedHashSet<>();
+
+  /** True after the first peer error (success=false OpChainComplete or stream 
onError). Used to trigger fan-out
+   * cancel idempotently. */
+  private boolean _peerErrorObserved = false;
+
+  public StreamingQuerySession(long requestId, int expectedOpChains) {
+    _requestId = requestId;
+    _expectedOpChains = expectedOpChains;
+    _completionLatch = new CountDownLatch(expectedOpChains);
+  }
+
+  public long getRequestId() {
+    return _requestId;
+  }
+
+  public int getExpectedOpChains() {
+    return _expectedOpChains;
+  }
+
+  /**
+   * Registers an open server stream so the session can iterate them later for 
fan-out cancel. Must be called by the
+   * dispatcher when the {@code SubmitWithStream} call is opened.
+   */
+  public void registerStream(StreamingServerHandle stream) {
+    _lock.lock();
+    try {
+      _openStreams.add(stream);
+    } finally {
+      _lock.unlock();
+    }
+  }
+
+  /**
+   * Removes a stream from the open-streams set. Called when the server emits 
{@code ServerDone} (clean close) or the
+   * stream errors. Idempotent.
+   */
+  public void unregisterStream(StreamingServerHandle stream) {
+    _lock.lock();
+    try {
+      _openStreams.remove(stream);
+    } finally {
+      _lock.unlock();
+    }
+  }
+
+  /**
+   * Records an {@link Worker.OpChainComplete} message decoded from a server 
stream. Decrements the outstanding count
+   * and merges the contained tree into the per-stage accumulator (or marks 
the stage {@code mergeFailed} on a shape
+   * mismatch / decode failure). Also records {@code success=false} reports as 
peer errors so fan-out cancel can fire.
+   */
+  public void recordOpChainComplete(Worker.OpChainComplete message) {
+    int stageId = message.getStageId();
+    boolean isSuccess = message.getSuccess();
+    Worker.MultiStageStatsTree statsTree = message.getStats();
+
+    boolean shouldFanOutCancel = false;
+    _lock.lock();
+    try {
+      if (!isSuccess) {
+        if (!_peerErrorObserved) {
+          _peerErrorObserved = true;
+          shouldFanOutCancel = true;
+        }
+      }
+      if (statsTree.hasCurrentStage()) {
+        try {
+          MultiStageStatsTreeDecoder.Decoded decoded = 
MultiStageStatsTreeDecoder.decode(statsTree);

Review Comment:
   This is a valid concern. I've modified the code to decode outside the lock, 
but we still need the lock to protect the merge code. I think that, now that we 
have virtual threads, the best way to resolve this is to have a virtual thread 
handle decoding and merging the stats for each query. This virtual thread would 
read the stats in raw format from a queue, then decode and merge them (without 
having to use locks), similar to an actor. I added a comment suggesting that, 
but didn't want to introduce it here, to avoid making this PR more 
conflict-prone.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to